Spaces:
Sleeping
Sleeping
| import asyncio | |
| from typing import AsyncGenerator, List, Dict, Tuple | |
| from config import logger | |
| from api import ask_openai, ask_anthropic, ask_gemini | |
async def _next_chunk(iterator: AsyncGenerator[str, None]) -> str:
    """Await and return the next item of *iterator*.

    Wrapping ``__anext__`` in a plain coroutine guarantees that
    ``asyncio.create_task`` accepts it, whatever awaitable type the provider's
    iterator returns. Raises StopAsyncIteration once the stream is exhausted.
    """
    return await iterator.__anext__()


async def query_model(
    query: str,
    providers: List[str],
    history: List[Dict[str, str]]
) -> AsyncGenerator[
    Tuple[str, List[Dict[str, str]], List[Dict[str, str]], List[Dict[str, str]], List[Dict[str, str]]],
    None
]:
    """Stream answers to *query* from the selected providers concurrently.

    Args:
        query: The new user message.
        providers: Display names of the providers to query; recognised values
            are "OpenAI", "Anthropic" and "Gemini".
        history: Prior turns, each a dict with an optional "user" message and
            optional "openai" / "anthropic" / "gemini" replies.

    Yields:
        ``("", openai_messages, anthropic_messages, gemini_messages, history)``
        tuples:
        - The first tuple has the rebuilt history plus the new query and an
          empty assistant placeholder for every selected provider.
        - One tuple follows per received chunk.
        - The last tuple carries ``history`` extended with this turn, with
          replies stripped of surrounding whitespace.

    The ``ask_*`` helpers are called as ``ask(query, history)`` and consumed as
    async iterables; they are assumed to yield text chunks.
    A provider that raises is reported in its own pane. The other providers
    keep streaming.
    """
    logger.info(f"Processing query with providers: {providers}")

    # (display name, history key, streaming API call)
    specs = (
        ("OpenAI", "openai", ask_openai),
        ("Anthropic", "anthropic", ask_anthropic),
        ("Gemini", "gemini", ask_gemini),
    )
    messages: Dict[str, List[Dict[str, str]]] = {name: [] for name, _, _ in specs}
    responses: Dict[str, str] = {name: "" for name, _, _ in specs}

    # Rebuild each provider's chat from history: every pane shows every user
    # turn, but only that provider's own replies.
    for turn in history:
        for name, key, _ in specs:
            if "user" in turn:
                messages[name].append({"role": "user", "content": turn["user"]})
            if turn.get(key):
                messages[name].append({"role": "assistant", "content": turn[key]})

    selected = [(name, ask) for name, _, ask in specs if name in providers]
    for name, _ in selected:
        messages[name].append({"role": "user", "content": query})
        # Placeholder that is overwritten as chunks stream in.
        messages[name].append({"role": "assistant", "content": ""})

    def snapshot(hist: List[Dict[str, str]]):
        return "", messages["OpenAI"], messages["Anthropic"], messages["Gemini"], hist

    logger.info(f"Yielding initial state with user query: {query}")
    yield snapshot(history)

    # Maps each in-flight "fetch next chunk" task to its provider and stream.
    # A task leaves this map only after its result has been consumed, so chunks
    # that finish while the caller is handling an update are never dropped.
    pending: Dict[asyncio.Task, Tuple[str, AsyncGenerator[str, None]]] = {}
    for name, ask in selected:
        stream = ask(query, history).__aiter__()
        pending[asyncio.create_task(_next_chunk(stream))] = (name, stream)

    try:
        while pending:
            done, _ = await asyncio.wait(pending.keys(), return_when=asyncio.FIRST_COMPLETED)
            # Snapshot in insertion order: deterministic handling, and safe to
            # mutate `pending` inside the loop.
            for task in [t for t in pending if t in done]:
                name, stream = pending.pop(task)
                try:
                    chunk = task.result()
                except StopAsyncIteration:
                    logger.info(f"Generator for {name} completed")
                    continue
                except Exception as exc:
                    # One failing provider must not abort the others. Show the
                    # error in its pane, but keep it out of the saved history.
                    logger.exception(f"Generator for {name} failed")
                    note = f"[{name} error: {exc}]"
                    partial = responses[name]
                    messages[name][-1] = {
                        "role": "assistant",
                        "content": f"{partial}\n\n{note}" if partial else note,
                    }
                    yield snapshot(history)
                    continue
                responses[name] += chunk
                messages[name][-1] = {"role": "assistant", "content": responses[name]}
                # Re-arm before yielding so this stream keeps downloading while
                # the consumer processes the update.
                pending[asyncio.create_task(_next_chunk(stream))] = (name, stream)
                logger.info(f"Yielding update for {name}: {responses[name][:50]}...")
                yield snapshot(history)
    finally:
        # The consumer stopped early or an error escaped: cancel orphaned tasks.
        for task in pending:
            task.cancel()

    updated_history = history + [{
        "user": query,
        "openai": responses["OpenAI"].strip(),
        "anthropic": responses["Anthropic"].strip(),
        "gemini": responses["Gemini"].strip(),
    }]
    logger.info(f"Updated history: {updated_history}")
    yield snapshot(updated_history)
async def submit_query(
    query: str,
    providers: List[str],
    history: List[Dict[str, str]]
) -> AsyncGenerator[
    Tuple[str, List[Dict[str, str]], List[Dict[str, str]], List[Dict[str, str]], List[Dict[str, str]]],
    None
]:
    """Validate the input, then relay every update produced by ``query_model``.

    A blank query takes precedence over an empty provider list. For either, a
    single tuple is yielded: an empty string, the same one-message notice in
    all three chat lists, and the unchanged *history*. Otherwise each
    ``query_model`` update is logged (first 50 characters of each pane's last
    message) and re-yielded with an empty string in the first slot.
    """
    problem = None
    if not query.strip():
        problem = "Please enter a query."
    elif not providers:
        problem = "Please select at least one provider."

    if problem is not None:
        notice = {"role": "assistant", "content": problem}
        yield "", [notice], [notice], [notice], history
        return

    def preview(chat):
        # Last message text, truncated for the log; '' for an empty pane.
        return chat[-1]['content'][:50] if chat else ''

    async for update in query_model(query, providers, history):
        _, openai_chat, anthropic_chat, gemini_chat, latest_history = update
        logger.info(
            f"Submitting update to UI: OpenAI: {preview(openai_chat)}, "
            f"Anthropic: {preview(anthropic_chat)}, "
            f"Gemini: {preview(gemini_chat)}"
        )
        yield "", openai_chat, anthropic_chat, gemini_chat, latest_history
def clear_history():
    """Log the reset and return a tuple of four new, independent empty lists.

    NOTE(review): presumably these feed the three provider chat panes and the
    stored history — confirm against the UI event wiring.
    """
    logger.info("Clearing history")
    return tuple([] for _ in range(4))