# Streaming chat adapters for OpenAI, Anthropic, and Gemini.
import os
import httpx
import json
from typing import AsyncGenerator, List, Dict
from config import logger
# ===== OpenAI =====
async def ask_openai(query: str, history: List[Dict[str, str]]) -> AsyncGenerator[str, None]:
    """Stream a chat completion from the OpenAI API.

    Args:
        query: The new user message.
        history: Prior turns as dicts with keys "role" ("user"/"assistant")
            and "content". Other roles are ignored.

    Yields:
        Incremental text fragments of the assistant reply, or a single
        error string when the key is missing or the request/parsing fails.
    """
    openai_api_key = os.getenv("OPENAI_API_KEY")
    if not openai_api_key:
        logger.error("OpenAI API key not provided")
        yield "Error: OpenAI API key not provided."
        return
    # Keep only well-formed user/assistant turns, then append the new query.
    messages = [
        {"role": msg["role"], "content": msg["content"]}
        for msg in history
        if msg.get("role") in ("user", "assistant")
    ]
    messages.append({"role": "user", "content": query})
    headers = {
        "Authorization": f"Bearer {openai_api_key}",
        "Content-Type": "application/json"
    }
    payload = {
        "model": "gpt-3.5-turbo",
        "messages": messages,
        "stream": True
    }
    try:
        # timeout=None: httpx's default 5s read timeout would abort long SSE streams.
        async with httpx.AsyncClient(timeout=None) as client:
            async with client.stream("POST", "https://api.openai.com/v1/chat/completions", headers=headers, json=payload) as response:
                response.raise_for_status()
                buffer = ""
                async for chunk in response.aiter_text():
                    if not chunk:
                        continue
                    buffer += chunk
                    # SSE frames arrive as newline-delimited "data: {...}" lines;
                    # a chunk may contain several lines or end mid-line.
                    while "\n" in buffer:
                        line, buffer = buffer.split("\n", 1)
                        if not line.startswith("data: "):
                            continue
                        data = line[6:].strip()
                        if data == "[DONE]":
                            # BUG FIX: was `break`, which only exited the inner
                            # while-loop and kept reading the finished stream.
                            return
                        if not data:
                            continue
                        try:
                            delta = json.loads(data)["choices"][0].get("delta", {})
                            if "content" in delta:
                                yield delta["content"]
                        except Exception as e:
                            logger.error(f"OpenAI parse error: {e}")
                            yield f"[OpenAI Error]: {e}"
    except Exception as e:
        logger.error(f"OpenAI API error: {e}")
        yield f"[OpenAI Error]: {e}"
# ===== Anthropic =====
def _clean_anthropic_history(history: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """Normalize chat history for the Anthropic Messages API.

    Anthropic requires strictly alternating roles starting with "user", so
    this drops turns with a missing role/content, consecutive same-role
    turns, and any leading "assistant" turn.
    """
    cleaned: List[Dict[str, str]] = []
    last_role = None
    for msg in history:
        role = msg.get("role")
        content = msg.get("content")
        if not role or not content:
            continue  # Skip invalid messages
        if role == last_role:
            # Combining consecutive turns would be more faithful but is more
            # complex; skipping keeps the alternation invariant.
            logger.warning(f"Skipping consecutive message with role: {role}")
            continue
        if not cleaned and role == "assistant":
            logger.warning("Skipping initial assistant message in history for Anthropic.")
            continue
        cleaned.append({"role": role, "content": content})
        last_role = role
    return cleaned


async def ask_anthropic(query: str, history: List[Dict[str, str]]) -> AsyncGenerator[str, None]:
    """Stream a reply from the Anthropic Messages API.

    Args:
        query: The new user message.
        history: Prior turns as dicts with "role" and "content"; cleaned to
            satisfy Anthropic's alternating-role requirement.

    Yields:
        Incremental text fragments from content_block_delta events, or a
        single error string on missing key / HTTP error / parse failure.
    """
    anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
    if not anthropic_api_key:
        logger.error("Anthropic API key not provided")
        yield "Error: Anthropic API key not provided."
        return
    cleaned_messages = _clean_anthropic_history(history)
    # Append the current user query. core.py is expected to keep history
    # alternating, so no assistant-padding is attempted here.
    cleaned_messages.append({"role": "user", "content": query})
    # Final check: the list must start with 'user' after cleaning.
    if not cleaned_messages or cleaned_messages[0].get("role") != "user":
        logger.error("Anthropic message cleaning resulted in invalid format.")
        yield "Error: Internal message formatting issue for Anthropic."
        return
    headers = {
        "x-api-key": anthropic_api_key,
        "anthropic-version": "2023-06-01",  # Use a valid API version
        "Content-Type": "application/json"
    }
    payload = {
        "model": "claude-3-5-sonnet-20241022",  # Ensure a valid model name
        "max_tokens": 4096,  # Allow for long responses
        "messages": cleaned_messages,
        "stream": True
    }
    try:
        # timeout=None: httpx's default 5s read timeout would abort long SSE streams.
        async with httpx.AsyncClient(timeout=None) as client:
            async with client.stream("POST", "https://api.anthropic.com/v1/messages", headers=headers, json=payload) as response:
                response.raise_for_status()  # Raise HTTPError for bad responses (like 400)
                buffer = ""
                async for chunk in response.aiter_text():
                    if not chunk:
                        continue
                    buffer += chunk
                    # Anthropic streams newline-delimited "data: {...}" events;
                    # a chunk may contain several events or end mid-line.
                    while "\n" in buffer:
                        line, buffer = buffer.split("\n", 1)
                        if not line.startswith("data: "):
                            continue
                        data = line[6:].strip()
                        if data == "[DONE]":
                            # Anthropic ends with a message_stop event rather
                            # than [DONE], but finish cleanly if one appears.
                            return
                        if not data:
                            continue
                        try:
                            event = json.loads(data)
                            # Only content_block_delta events carry reply text;
                            # message_start/message_delta/message_stop are skipped.
                            if event.get("type") == "content_block_delta" and "delta" in event:
                                text = event["delta"].get("text", "")
                                if text:
                                    yield text
                        except json.JSONDecodeError as e:
                            # BUG FIX: the old handler pushed the line back into
                            # the buffer, which re-split and re-failed the same
                            # complete line forever (infinite loop). Lines only
                            # reach here when newline-terminated, so a decode
                            # failure means malformed data: log it and move on.
                            logger.error(f"Anthropic parse error: {e}")
                        except Exception as e:
                            logger.error(f"Anthropic parse error: {e}")
                            yield f"[Anthropic Parse Error]: {e}"
    except httpx.HTTPStatusError as e:
        logger.error(f"Anthropic API HTTP error: {e.response.status_code} - {e.response.text}")
        yield f"[Anthropic API Error {e.response.status_code}]: {e.response.text}"
    except Exception as e:
        logger.error(f"Anthropic API error: {e}")
        yield f"[Anthropic Error]: {e}"
# ===== Gemini =====
async def ask_gemini(query: str, history: List[Dict[str, str]]) -> AsyncGenerator[str, None]:
    """Stream a reply from the Gemini API (generativelanguage).

    Args:
        query: The new user message.
        history: Prior turns as dicts with "role" and "content"; flattened
            into a single "User:/Assistant:" transcript prompt.

    Yields:
        Incremental text fragments of the reply, or a single error string
        when the key is missing or the request fails.
    """
    gemini_api_key = os.getenv("GEMINI_API_KEY")
    if not gemini_api_key:
        logger.error("Gemini API key not provided")
        yield "Error: Gemini API key not provided."
        return
    # Flatten history into a plain-text transcript prompt.
    turns = []
    for msg in history:
        role = msg.get("role")
        if role == "user":
            turns.append(f"User: {msg['content']}\n")
        elif role == "assistant":
            turns.append(f"Assistant: {msg['content']}\n")
    full_prompt = "".join(turns) + f"User: {query}\n"
    headers = {"Content-Type": "application/json"}
    payload = {
        "contents": [{"parts": [{"text": full_prompt}]}]
    }
    try:
        # BUG FIX: without alt=sse the endpoint streams one big JSON array,
        # and the old code re-parsed the whole accumulated buffer on every
        # chunk (quadratic) — succeeding only once the closing bracket
        # arrived, so nothing was actually yielded incrementally. alt=sse
        # switches to newline-delimited "data: {...}" events parseable
        # per-chunk. timeout=None avoids httpx's 5s read timeout on streams.
        async with httpx.AsyncClient(timeout=None) as client:
            async with client.stream(
                "POST",
                f"https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:streamGenerateContent?alt=sse&key={gemini_api_key}",
                headers=headers,
                json=payload
            ) as response:
                response.raise_for_status()
                buffer = ""
                async for chunk in response.aiter_text():
                    if not chunk:
                        continue
                    buffer += chunk
                    while "\n" in buffer:
                        line, buffer = buffer.split("\n", 1)
                        if not line.startswith("data: "):
                            continue
                        data = line[6:].strip()
                        if not data:
                            continue
                        try:
                            obj = json.loads(data)
                        except json.JSONDecodeError:
                            continue
                        candidates = obj.get("candidates", [])
                        if candidates:
                            parts = candidates[0].get("content", {}).get("parts", [])
                            for part in parts:
                                text = part.get("text", "")
                                if text:
                                    yield text
    except Exception as e:
        logger.error(f"Gemini API error: {e}")
        yield f"[Gemini Error]: {e}"