import os
import json
from typing import AsyncGenerator, List, Dict

import httpx

from config import logger

# ===== OpenAI =====
async def ask_openai(query: str, history: List[Dict[str, str]]) -> AsyncGenerator[str, None]:
    openai_api_key = os.getenv("OPENAI_API_KEY")
    if not openai_api_key:
        logger.error("OpenAI API key not provided")
        yield "Error: OpenAI API key not provided."
        return

    # Rebuild the conversation in the Chat Completions message format.
    messages = []
    for msg in history:
        if msg.get("role") == "user":
            messages.append({"role": "user", "content": msg["content"]})
        elif msg.get("role") == "assistant":
            messages.append({"role": "assistant", "content": msg["content"]})
    messages.append({"role": "user", "content": query})

    headers = {
        "Authorization": f"Bearer {openai_api_key}",
        "Content-Type": "application/json"
    }
    payload = {
        "model": "gpt-3.5-turbo",
        "messages": messages,
        "stream": True
    }
    try:
        async with httpx.AsyncClient() as client:
            async with client.stream("POST", "https://api.openai.com/v1/chat/completions", headers=headers, json=payload) as response:
                response.raise_for_status()
                buffer = ""
                async for chunk in response.aiter_text():
                    if chunk:
                        buffer += chunk
                        while "\n" in buffer:
                            line, buffer = buffer.split("\n", 1)
                            line = line.strip()  # drop any trailing "\r" from SSE lines
                            if line.startswith("data: "):
                                data = line[6:]
                                if data.strip() == "[DONE]":
                                    return  # end of the whole stream, not just this chunk
                                if not data.strip():
                                    continue
                                try:
                                    json_data = json.loads(data)
                                    delta = json_data["choices"][0].get("delta", {})
                                    if "content" in delta:
                                        yield delta["content"]
                                except Exception as e:
                                    logger.error(f"OpenAI parse error: {e}")
                                    yield f"[OpenAI Error]: {e}"
    except Exception as e:
        logger.error(f"OpenAI API error: {e}")
        yield f"[OpenAI Error]: {e}"

# ===== Anthropic =====
async def ask_anthropic(query: str, history: List[Dict[str, str]]) -> AsyncGenerator[str, None]:
    anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
    if not anthropic_api_key:
        logger.error("Anthropic API key not provided")
        yield "Error: Anthropic API key not provided."
        return

    # --- Start: Message cleaning for Anthropic ---
    # Anthropic requires messages to alternate roles, starting with 'user',
    # so clean the history into that shape.
    cleaned_messages = []
    last_role = None
    for msg in history:
        role = msg.get("role")
        content = msg.get("content")
        if not role or not content:
            continue  # skip invalid messages
        # Drop a message that repeats the previous role (merging the contents
        # would also work, but is more complex).
        if role == last_role:
            logger.warning(f"Skipping consecutive message with role: {role}")
            continue
        # Drop a leading assistant message; the list must start with 'user'.
        if not cleaned_messages and role == "assistant":
            logger.warning("Skipping initial assistant message in history for Anthropic.")
            continue
        cleaned_messages.append({"role": role, "content": content})
        last_role = role

    # Ideally the cleaned history now ends with an 'assistant' turn. If it ends
    # with 'user', appending the new query produces consecutive user messages
    # and the API will reject the request; we rely on the upstream core.py logic
    # to keep turns alternating, so the cleaning here only handles leading
    # assistant messages and consecutive same-role messages.
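    # Illustrative example: a history whose roles run [assistant, user,
    # assistant] is cleaned to [user, assistant] (the leading assistant turn
    # is dropped); appending the new query below then yields the valid
    # alternation [user, assistant, user].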
    # Append the current user query.
    cleaned_messages.append({"role": "user", "content": query})

    # Final check: if the list is empty or starts with 'assistant' after
    # cleaning, something went wrong upstream.
    if not cleaned_messages or cleaned_messages[0].get("role") != "user":
        logger.error("Anthropic message cleaning resulted in invalid format.")
        yield "Error: Internal message formatting issue for Anthropic."
        return
    # --- End: Message cleaning ---
    headers = {
        "x-api-key": anthropic_api_key,
        "anthropic-version": "2023-06-01",  # a valid API version
        "Content-Type": "application/json"
    }
    payload = {
        "model": "claude-3-5-sonnet-20241022",
        "max_tokens": 4096,  # allow room for long responses
        "messages": cleaned_messages,  # use the cleaned history
        "stream": True
    }
    try:
        async with httpx.AsyncClient() as client:
            async with client.stream("POST", "https://api.anthropic.com/v1/messages", headers=headers, json=payload) as response:
                if response.status_code >= 400:
                    # Read the error body while the stream is still open;
                    # response.text is unavailable on an unread streaming response.
                    error_body = (await response.aread()).decode("utf-8", errors="replace")
                    logger.error(f"Anthropic API HTTP error: {response.status_code} - {error_body}")
                    yield f"[Anthropic API Error {response.status_code}]: {error_body}"
                    return
                buffer = ""
                async for chunk in response.aiter_text():
                    if chunk:
                        buffer += chunk
                        # Anthropic streams SSE events: each complete "data:" line
                        # carries one JSON object, and several lines may arrive in
                        # a single chunk.
                        while "\n" in buffer:
                            line, buffer = buffer.split("\n", 1)
                            if line.startswith("data: "):
                                data = line[6:].strip()
                                if not data:
                                    continue
                                try:
                                    json_data = json.loads(data)
                                except json.JSONDecodeError:
                                    # A complete line that fails to parse is malformed;
                                    # skip it. (Re-buffering it would retry the same
                                    # line forever.)
                                    logger.warning(f"Skipping malformed Anthropic event: {line}")
                                    continue
                                event_type = json_data.get("type")
                                if event_type == "message_stop":
                                    # Anthropic signals end-of-stream with a
                                    # message_stop event, not an OpenAI-style [DONE].
                                    return
                                # Text arrives in content_block_delta events.
                                if event_type == "content_block_delta" and "delta" in json_data:
                                    yield json_data["delta"].get("text", "")
    except Exception as e:
        logger.error(f"Anthropic API error: {e}")
        yield f"[Anthropic Error]: {e}"

# ===== Gemini =====
async def ask_gemini(query: str, history: List[Dict[str, str]]) -> AsyncGenerator[str, None]:
    gemini_api_key = os.getenv("GEMINI_API_KEY")
    if not gemini_api_key:
        logger.error("Gemini API key not provided")
        yield "Error: Gemini API key not provided."
        return

    # Flatten the history into a single text transcript for the prompt.
    history_text = ""
    for msg in history:
        if msg.get("role") == "user":
            history_text += f"User: {msg['content']}\n"
        elif msg.get("role") == "assistant":
            history_text += f"Assistant: {msg['content']}\n"
    full_prompt = f"{history_text}User: {query}\n"

    headers = {"Content-Type": "application/json"}
    payload = {
        "contents": [{"parts": [{"text": full_prompt}]}]
    }
    try:
        async with httpx.AsyncClient() as client:
            async with client.stream(
                "POST",
                f"https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:streamGenerateContent?key={gemini_api_key}",
                headers=headers,
                json=payload
            ) as response:
                response.raise_for_status()
                # streamGenerateContent returns a JSON array of response objects;
                # accumulate chunks until the buffer parses as valid JSON, then
                # emit the text parts of the first candidate of each object.
                buffer = ""
                async for chunk in response.aiter_text():
                    if not chunk.strip():
                        continue
                    buffer += chunk
                    try:
                        json_data = json.loads(buffer.strip(", \n"))
                    except json.JSONDecodeError:
                        continue  # incomplete JSON; keep buffering
                    buffer = ""
                    objects = json_data if isinstance(json_data, list) else [json_data]
                    for obj in objects:
                        candidates = obj.get("candidates", [])
                        if candidates:
                            parts = candidates[0].get("content", {}).get("parts", [])
                            for part in parts:
                                text = part.get("text", "")
                                if text:
                                    yield text
    except Exception as e:
        logger.error(f"Gemini API error: {e}")
        yield f"[Gemini Error]: {e}"