|
|
import socket |
|
|
from uuid import uuid4 |
|
|
from fastapi import FastAPI, HTTPException, UploadFile, File, Form |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from fastapi.responses import HTMLResponse |
|
|
from pydantic import BaseModel |
|
|
from fastapi import FastAPI, HTTPException, Request |
|
|
from asyncio import TimeoutError |
|
|
import asyncio |
|
|
from typing import Optional |
|
|
import requests |
|
|
import uvicorn |
|
|
import shutil |
|
|
import datetime |
|
|
import logging |
|
|
from logging.handlers import RotatingFileHandler |
|
|
import time |
|
|
from typing import List, Dict, Optional |
|
|
import json |
|
|
import os |
|
|
import psutil |
|
|
import sys |
|
|
from typing import Dict |
|
|
import tempfile |
|
|
import re |
|
|
import requests |
|
|
import random |
|
|
import aiohttp |
|
|
# FastAPI application instance; every route and middleware in this module registers on it.
app = FastAPI()
|
|
|
|
|
|
|
|
# Desktop Chrome User-Agent strings; /proxy picks one at random for each outgoing request.
USER_AGENTS = [
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/122.0.0.0 Safari/537.36",
    "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/122.0.0.0 Safari/537.36",
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/121.0.0.0 Safari/537.36"
]
|
|
|
|
|
|
|
|
# Root logging setup: INFO and above, rendered as "<time> - <LEVEL> - <message>".
# basicConfig is a no-op if the root logger already has handlers.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger used by every handler in this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
# In-process request counters (not persisted; they start at 0 on every launch).
# /ask bumps only "total"; /analyze and /compareAnalyze bump their own key and "total".
# /status logs them and includes them in its fallback payload.
request_counter = {
    "analyze": 0,
    "compareAnalyze": 0,
    "total": 0
}
|
|
|
|
|
|
|
|
|
|
|
# Allow cross-origin calls from any http(s) origin.
# Starlette's ``allow_origins`` only understands the literal "*" or exact origin
# strings; glob-like entries such as "http://*" / "https://*" match nothing, so
# the previous configuration never allowed a single cross-origin request.
# ``allow_origin_regex`` expresses the intended "any http/https origin"
# (Starlette matches it against the whole Origin header).
# NOTE(security): together with allow_credentials=True this lets any website
# make credentialed requests; narrow the pattern if cookies/auth ever matter.
app.add_middleware(
    CORSMiddleware,
    allow_origin_regex=r"https?://.*",
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
# Request body for /ask; translate_word also builds one internally.
# Field order matters: forward_ask serializes the model with .dict() and sends it upstream.
class AskRequest(BaseModel):
    # Prompt text forwarded to the upstream /ask service.
    prompt: str
    # Upstream model identifier; "GEMINI" unless the caller overrides it.
    model: str = "GEMINI"
|
|
|
|
|
|
|
|
@app.get("/")
async def health_check():
    """Liveness probe.

    Returns a fixed service description plus the current local time in ISO-8601
    form, so callers can see both that the process answers and when it did.
    """
    now_iso = datetime.datetime.now().isoformat()
    return dict(
        health="ok",
        timestamp=now_iso,
        service="AI API Forwarding Service",
        version="1.0",
    )
|
|
|
|
|
|
|
|
@app.post("/ask")
async def forward_ask(request: AskRequest):
    """Forward a prompt to the upstream ``/ask`` service and relay its JSON reply.

    The blocking ``requests`` call runs in the event loop's default thread-pool
    executor, so one slow upstream model call no longer freezes every other
    request handled by this async app. A timeout bounds how long we wait.

    Args:
        request: prompt and model name; sent upstream as a JSON object.

    Returns:
        Whatever JSON the upstream service returned.

    Raises:
        HTTPException: 500 with the error text if the upstream call fails,
            times out, or returns a body that is not JSON.
    """
    request_counter["total"] += 1

    try:
        # pydantic v2 renamed .dict() to .model_dump(); support both versions.
        if hasattr(request, "model_dump"):
            payload = request.model_dump()
        else:
            payload = request.dict()

        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None,
            lambda: requests.post(
                "http://s5.serv00.com:9081/ask",
                headers={'Content-Type': 'application/json'},
                json=payload,
                # AI generation can be slow, but never wait forever.
                timeout=120,
            ),
        )
        return response.json()
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
@app.post("/analyze")
async def forward_analyze(image: UploadFile = File(...), model: str = Form(...)):
    """Relay an uploaded image and a model name to the upstream ``/analyze`` service.

    The upload is re-sent as a multipart ``image`` file with ``model`` as a form
    field. The blocking upload runs in the default thread-pool executor with a
    timeout, so it neither stalls the event loop nor hangs indefinitely.

    Returns:
        The upstream JSON reply, unchanged.

    Raises:
        HTTPException: 500 with the error text on any upstream/transport failure
            or a non-JSON reply.
    """
    request_counter["analyze"] += 1
    request_counter["total"] += 1

    try:
        files = {'image': (image.filename, image.file, image.content_type)}
        data = {'model': model}

        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None,
            lambda: requests.post(
                "http://s5.serv00.com:9081/analyze",
                files=files,
                data=data,
                timeout=120,
            ),
        )
        return response.json()
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
@app.post("/compareAnalyze")
async def forward_compare_analyze(image: UploadFile = File(...)):
    """Relay an uploaded image to the upstream ``/compareAnalyze`` service.

    The upload is re-sent as a multipart ``image`` file. The blocking call runs
    in the default thread-pool executor with a timeout, so it neither stalls the
    event loop nor hangs indefinitely.

    Returns:
        The upstream JSON reply, unchanged.

    Raises:
        HTTPException: 500 with the error text on any upstream/transport failure
            or a non-JSON reply.
    """
    request_counter["compareAnalyze"] += 1
    request_counter["total"] += 1

    try:
        files = {'image': (image.filename, image.file, image.content_type)}

        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None,
            lambda: requests.post(
                "http://s5.serv00.com:9081/compareAnalyze",
                files=files,
                timeout=120,
            ),
        )
        return response.json()
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
@app.get("/status")
async def forward_status():
    """Return upstream status, falling back to local counters when upstream fails.

    On success the upstream ``/status`` JSON is returned as-is. On any failure
    (connection error, timeout, non-JSON body) a local payload is returned
    instead: ``status="running"``, the in-process ``request_counter``, the error
    text, and a timestamp. Errors are reported in that payload, not raised.

    The upstream call runs in the default thread-pool executor with a short
    timeout, so a dead upstream cannot stall the event loop or this endpoint.
    """
    start_time = time.time()
    logger.info(f"Received status request at {datetime.datetime.now()}")
    logger.info(f"Current request counter: {request_counter}")

    try:
        logger.info("Attempting to contact upstream server...")
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None,
            lambda: requests.get("http://s5.serv00.com:9081/status", timeout=10),
        )
        elapsed_time = time.time() - start_time

        logger.info(f"Upstream server responded in {elapsed_time:.2f} seconds")
        logger.info(f"Response status code: {response.status_code}")
        logger.info(f"Response content: {response.text[:200]}...")

        return response.json()
    except Exception as e:
        logger.error(f"Error occurred: {str(e)}")
        logger.error(f"Error type: {type(e).__name__}")
        return {
            "status": "running",
            "requests": request_counter,
            "error": str(e),
            "timestamp": datetime.datetime.now().isoformat()
        }
|
|
|
|
|
@app.get("/check", response_class=HTMLResponse)
async def forward_check():
    """Proxy the upstream ``/check`` page and return its body as HTML.

    The blocking fetch runs in the default thread-pool executor with a timeout,
    so it neither stalls the event loop nor hangs indefinitely.

    Raises:
        HTTPException: 500 with the error text if the upstream request fails.
    """
    try:
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None,
            lambda: requests.get("http://s5.serv00.com:9081/check", timeout=30),
        )
        return response.text
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# One translation of a word plus its type label.
# NOTE(review): not referenced elsewhere in the visible code; Word.translations uses plain dicts
# of the same {"translation", "type"} shape instead.
class Translation(BaseModel):
    # Translated text.
    translation: str
    # Type label for this translation — presumably the part of speech; confirm against the data files.
    type: str
|
|
|
|
|
# An example phrase and its translation; element type of Word.phrases.
class Phrase(BaseModel):
    # The phrase text.
    phrase: str
    # Its translation.
    translation: str
|
|
|
|
|
# One dictionary entry, as built by init_word_map and stored in word_map.
class Word(BaseModel):
    # Headword as it appears in the source file; word_map keys are its lowercase form.
    word: str
    # Translation dicts. When a source entry lacks this field, init_word_map builds a
    # one-element list from the entry's "translation"/"type" values.
    translations: List[dict]
    # Example phrases; init_word_map supplies [] when the source entry has none.
    phrases: List[Phrase] = []
    # Level label parsed from the source filename by get_level_from_filename ("" if unset).
    level: str = ""
|
|
|
|
|
|
|
|
# Lowercased headword -> Word, populated by init_word_map().
# NOTE(review): nothing in the visible code calls init_word_map(); only get_memory_usage reads this.
word_map: Dict[str, Word] = {}
|
|
|
|
|
def get_level_from_filename(filename: str) -> str:
    """Extract the level label embedded in a dictionary filename.

    Filenames are expected to start with ``<digits>-<level>-顺序.json``; for
    example ``"3-高考-顺序.json"`` yields ``"高考"``. The label is the shortest
    text between the leading number and ``-顺序.json``. Any name that does not
    start with that pattern yields ``"unknown"``.
    """
    found = re.match(r'\d+-(.+?)-顺序\.json', filename)
    if found is None:
        return "unknown"
    return found.group(1)
|
|
|
|
|
def _normalize_word_entry(word_data: dict, level: str) -> dict:
    """Fill in fields a raw dictionary entry may lack so it validates as a ``Word``.

    An entry without ``translations`` gets a one-element list built from its
    ``translation``/``type`` values (empty strings when absent); an entry
    without ``phrases`` gets ``[]``. ``level`` is always overwritten. The dict
    is modified in place and also returned.
    """
    if 'translations' not in word_data:
        word_data['translations'] = [{
            'translation': word_data.get('translation', ''),
            'type': word_data.get('type', ''),
        }]
    if 'phrases' not in word_data:
        word_data['phrases'] = []
    word_data['level'] = level
    return word_data


def init_word_map():
    """Load every ``*.json`` file in ``<this module's dir>/json`` into ``word_map``.

    Each file is expected to hold a JSON list of word entries; the level label
    comes from the filename via ``get_level_from_filename``. Entries are keyed
    by their lowercased ``word``.

    Files are processed in sorted name order, so when a word appears in several
    files the entry from the file that sorts last deterministically wins (raw
    ``os.listdir`` order is arbitrary). Each file is all-or-nothing: if any
    entry fails to parse or validate, the error is logged, none of that file's
    words are added, and the file is left out of the stats.

    Returns:
        ``{"total_words": int, "total_files": int, "file_stats": {filename: count}}``
        covering only successfully loaded files. A missing or unreadable json
        directory is logged and yields the empty stats.
    """
    current_dir = os.path.dirname(os.path.abspath(__file__))
    json_dir = os.path.join(current_dir, "json")
    stats = {
        "total_words": 0,
        "total_files": 0,
        "file_stats": {}
    }

    try:
        filenames = sorted(os.listdir(json_dir))
    except OSError as e:
        logger.error(f"Fatal error in init_word_map: {str(e)}")
        return stats

    for filename in filenames:
        if not filename.endswith('.json'):
            continue
        try:
            level = get_level_from_filename(filename)
            with open(os.path.join(json_dir, filename), 'r', encoding='utf-8') as f:
                words = json.load(f)
            # Stage this file's words so a bad entry cannot leave a half-loaded file behind.
            loaded = {}
            for word_data in words:
                word = Word(**_normalize_word_entry(word_data, level))
                loaded[word.word.lower()] = word
        except Exception as e:
            logger.error(f"Error loading {filename}: {str(e)}")
            continue

        word_map.update(loaded)
        word_count = len(words)
        stats["total_words"] += word_count
        stats["total_files"] += 1
        stats["file_stats"][filename] = word_count
        logger.info(f"Loaded {filename}: {word_count} words")

    logger.info("Dictionary initialization complete:")
    logger.info(f"Total files processed: {stats['total_files']}")
    logger.info(f"Total words loaded: {stats['total_words']}")
    return stats
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# On-disk location of the AI translation cache, under the OS temp directory.
CACHE_DIR = os.path.join(tempfile.gettempdir(), "flash_api_cache")
CACHE_FILE = os.path.join(CACHE_DIR, "ai_translation_cache.json")

# In-memory AI translation cache; written by save_cache(), replaced by load_cache(),
# reset by cleanup_cache(). NOTE(review): nothing in the visible code inserts entries.
ai_cache: Dict[str, dict] = {}
|
|
|
|
|
|
|
|
def save_cache():
    """Write ``ai_cache`` to ``CACHE_FILE`` as pretty-printed UTF-8 JSON.

    ``CACHE_DIR`` is created first when missing. This is best-effort: a
    permission problem and any other failure are each logged, never raised.
    """
    try:
        os.makedirs(CACHE_DIR, exist_ok=True)
        with open(CACHE_FILE, 'w', encoding='utf-8') as handle:
            json.dump(ai_cache, handle, ensure_ascii=False, indent=2)
        logger.info(f"Cache saved to: {CACHE_FILE}")
    except PermissionError as denied:
        logger.error(f"Permission denied writing to cache: {denied}")
    except Exception as err:
        logger.error(f"Error saving cache: {err}")
|
|
|
|
|
def load_cache():
    """Replace ``ai_cache`` with the JSON stored in ``CACHE_FILE``, if present.

    A missing file leaves ``ai_cache`` untouched. A read or parse failure is
    logged and resets ``ai_cache`` to an empty dict; nothing is raised.
    """
    global ai_cache
    try:
        if not os.path.exists(CACHE_FILE):
            return
        with open(CACHE_FILE, 'r', encoding='utf-8') as handle:
            ai_cache = json.load(handle)
        logger.info(f"Loaded {len(ai_cache)} cached translations from: {CACHE_FILE}")
    except Exception as err:
        logger.error(f"Error loading cache: {err}")
        ai_cache = {}
|
|
|
|
|
|
|
|
@app.get("/translate/{word}")
async def translate_word(word: str):
    """Translate an English word by prompting the upstream AI model.

    The word is lower-cased and stripped, then embedded in a fixed Chinese
    prompt. The prompt asks for one line with phonetics, word class, Chinese
    meaning, English explanation, example sentence, synonyms, antonyms and part
    of speech, following a ``|``-separated sample. It is sent through
    ``forward_ask`` with the GEMINI model, and the upstream JSON reply is
    returned unchanged. Neither ``word_map`` nor ``ai_cache`` is consulted.

    Raises:
        HTTPException: an HTTPException from ``forward_ask`` is re-raised as-is
            (it already carries status and message); any other failure becomes
            a 500 whose detail is the error text.
    """
    start_time = time.time()
    logger.info(f"Translation request received for word: {word}")

    try:
        word = word.lower().strip()
        logger.debug(f"Processed word: {word}")

        # No dictionary or cache lookup happens here: every request goes to the AI.
        logger.info("Calling AI API for translation")
        request = AskRequest(
            prompt=f'''翻译以下英文
{word}
每行一个 格式参考,不要任何md格式,分别要有音标,单词属性(名词,动词,形容词),中文翻译,英文解析,例句,近义词,反义词,词性
格式参考:
hello:/həˈləʊ/| n. vt. int.|你好,问候语,|例句:Hello, how are you? 你好,你好吗?|近义词:hi, hey, |反义词:sick, bad.''',
            model="GEMINI",
        )
        logger.debug(f"AI Request: {request}")

        result = await forward_ask(request)
        logger.debug(f"AI Response: {result}")

        elapsed = time.time() - start_time
        logger.info(f"AI translation completed in {elapsed:.2f}s")
        return result
    except HTTPException as e:
        # Re-wrapping would nest the status into the detail ("500: 500: ...").
        logger.error(f"AI translation error: {e.detail}", exc_info=True)
        raise
    except Exception as e:
        logger.error(f"AI translation error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
|
|
def cleanup_temp_files():
    """Delete the on-disk cache directory ``CACHE_DIR`` and everything in it.

    Uses the shared ``CACHE_DIR`` constant instead of rebuilding the path, so
    the directory removed here is always the one save_cache/load_cache use.
    This is best-effort: errors are logged, never raised.
    """
    try:
        if os.path.exists(CACHE_DIR):
            shutil.rmtree(CACHE_DIR)
            logger.info(f"Cleaned up temp directory: {CACHE_DIR}")
    except Exception as e:
        logger.error(f"Error cleaning temp files: {e}")
|
|
|
|
|
def cleanup_cache():
    """Forget all in-memory AI translations by rebinding ``ai_cache`` to a fresh empty dict."""
    global ai_cache
    ai_cache = dict()
    logger.info("Cache cleared")
|
|
|
|
|
|
|
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """Startup hook: wipe the on-disk cache, clear the in-memory cache, log memory use.

    Note that neither init_word_map() nor load_cache() is invoked here.
    """
    cleanup_temp_files()
    cleanup_cache()
    usage = get_memory_usage()
    logger.info(f"Memory usage after init: {usage}")
|
|
|
|
|
@app.on_event("shutdown")
async def shutdown_event():
    """Shutdown hook: remove the on-disk cache directory, then clear the in-memory cache."""
    for step in (cleanup_temp_files, cleanup_cache):
        step()
    logger.info("Application shutdown cleanup complete")
|
|
def get_memory_usage():
    """Summarise memory use of this process, the host, and ``word_map``.

    Byte counts are rendered as strings in MB (1024 * 1024 bytes) with two
    decimals; percentages are relative to total system RAM. The ``word_map``
    figures come from ``sys.getsizeof``, which is shallow: it measures only
    the dict object itself, not the keys or Word values it references, so it
    understates the real footprint.
    """
    proc_mem = psutil.Process().memory_info()
    host = psutil.virtual_memory()
    dict_bytes = sys.getsizeof(word_map)

    def in_mb(num_bytes):
        return f"{num_bytes / 1024 / 1024:.2f} MB"

    def share_of_ram(num_bytes, places=2):
        return f"{num_bytes / host.total * 100:.{places}f}%"

    process_section = {
        "rss": in_mb(proc_mem.rss),
        "rss_percent": share_of_ram(proc_mem.rss),
        "vms": in_mb(proc_mem.vms),
        "vms_percent": share_of_ram(proc_mem.vms),
    }
    system_section = {
        "total": in_mb(host.total),
        "available": in_mb(host.available),
        "used_percent": f"{host.percent:.2f}%",
    }
    word_map_section = {
        "entries": len(word_map),
        "memory": in_mb(dict_bytes),
        "memory_percent": share_of_ram(dict_bytes, 4),
    }
    return {
        "process": process_section,
        "system": system_section,
        "word_map": word_map_section,
    }
|
|
|
|
|
@app.get("/memory")
async def memory_status():
    """Expose the get_memory_usage() report over HTTP."""
    report = get_memory_usage()
    return report
|
|
|
|
|
@app.get("/ip")
async def get_ip():
    """Report this host's name, its resolved internal IP, and its public IP.

    The public IP is fetched from api.ipify.org in the default thread-pool
    executor with a short timeout, so it cannot stall the event loop. Failures
    are reported as placeholder strings instead of errors:
    "Unable to fetch external IP" when the lookup fails or returns an HTTP
    error, and "Unable to resolve internal IP" when the hostname does not
    resolve.
    """
    hostname = socket.gethostname()
    try:
        internal_ip = socket.gethostbyname(hostname)
    except OSError:  # socket.gaierror is an OSError subclass
        internal_ip = "Unable to resolve internal IP"

    try:
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None,
            lambda: requests.get('https://api.ipify.org', timeout=5),
        )
        # Don't report an error page body as if it were an IP address.
        response.raise_for_status()
        external_ip = response.text
    except Exception:
        external_ip = "Unable to fetch external IP"

    return {
        "hostname": hostname,
        "internal_ip": internal_ip,
        "external_ip": external_ip,
        "timestamp": datetime.datetime.now().isoformat()
    }
|
|
|
|
|
@app.get("/proxy")
async def proxy_request(url: str, request: Request):
    """Fetch subtitle JSON from ``url`` with browser-like headers and return it.

    The request imitates a YouTube page fetch (random desktop User-Agent,
    youtube.com Origin/Referer) with a 10-second total timeout. The body must
    be a JSON object containing an ``events`` key.

    Raises:
        HTTPException: the upstream status code when it is not 200; 400 when
            the JSON is not an object with ``events``; 408 on timeout; 500 for
            any other failure (network error, invalid JSON, ...).
    """
    try:
        user_agent = random.choice(USER_AGENTS)
        logger.info(f"Proxy request received for: {url}")

        # NOTE(security): `url` is caller-supplied and fetched server-side with no
        # host allow-list, so this endpoint can be used to reach internal services
        # (SSRF). Consider restricting it to the expected subtitle hosts.
        headers = {
            'User-Agent': user_agent,
            'Accept': 'application/json, text/plain, */*',
            'Accept-Language': 'en-US,en;q=0.9',
            'Origin': 'https://www.youtube.com',
            'Referer': 'https://www.youtube.com/',
            'Sec-Fetch-Dest': 'empty',
            'Sec-Fetch-Mode': 'cors',
            'Sec-Fetch-Site': 'same-site',
            'Connection': 'keep-alive'
        }
        timeout = aiohttp.ClientTimeout(total=10)

        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.get(url, headers=headers) as response:
                if response.status != 200:
                    raise HTTPException(
                        status_code=response.status,
                        detail=f"HTTP error: {response.status}"
                    )
                body = await response.read()

        # Log the real payload size; len() of the parsed dict would count keys.
        logger.info(f"Received youtube subtitle data: {len(body)} bytes")
        # Parse ourselves so a JSON body with an unexpected Content-Type is still accepted.
        data = json.loads(body)

        if not isinstance(data, dict) or 'events' not in data:
            raise HTTPException(
                status_code=400,
                detail="Invalid subtitle data format"
            )
        return data
    except HTTPException:
        # Deliberate HTTP errors raised above keep their own status code.
        raise
    except TimeoutError:
        raise HTTPException(status_code=408, detail="Request timeout")
    except Exception as e:
        logger.error(f"Proxy error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
|
|
|
|
|
# Shared secret for the OpenAI-compatible chat endpoint, read from the environment.
# Importing this module fails immediately (ValueError) when it is unset or empty.
API_KEY = os.environ.get('API_KEY')
if not API_KEY:
    raise ValueError("API_KEY environment variable is required")

# Public model ids accepted by /hf/v1/chat/completions (and listed by /hf/v1/models)
# mapped to the model ids sent to the upstream gateway.
# NOTE(review): "deepseek-reasoner" lacks the provider prefix the others carry — confirm intended.
MODEL_MAPPING = {
    "deepseek": "deepseek/deepseek-chat",
    "gpt-4o-mini": "openai/gpt-4o-mini",
    "gemini-flash-1.5": "google/gemini-flash-1.5",
    "deepseek-reasoner": "deepseek-reasoner",
    "minimax-01": "minimax/minimax-01"
}
|
|
|
|
|
def verify_api_key(request=None, api_key=None):
    """Return True when ``request`` carries the expected API key.

    The original body read ``requests.request.headers`` (a Flask-style global
    that does not exist here), so it could only raise AttributeError. The
    request is now passed in explicitly.

    The ``Authorization`` header may be ``"Bearer <key>"`` or the bare key. The
    comparison is constant-time so the key cannot be probed via response timing.

    Args:
        request: object exposing ``headers.get`` (e.g. a Starlette ``Request``).
            When omitted, no key can be checked and False is returned.
        api_key: expected key; defaults to the module-level ``API_KEY``.

    Returns:
        True on an exact match, False otherwise (never raises on malformed input).
    """
    import hmac

    expected = API_KEY if api_key is None else api_key
    if request is None or not expected:
        return False
    try:
        auth_header = request.headers.get('Authorization')
        if not auth_header:
            return False
        if auth_header.startswith('Bearer '):
            token = auth_header[len('Bearer '):]
        else:
            token = auth_header
        return hmac.compare_digest(token.encode('utf-8'), expected.encode('utf-8'))
    except (AttributeError, TypeError, UnicodeError):
        return False
|
|
|
|
|
def make_heck_request(question, session_id, messages, actual_model):
    """Open a streaming chat request to the aiapilab gateway and return the response.

    The payload carries the current ``question`` and the ``sessionId``, and asks
    for a Chinese-language answer. It also includes one turn of context: the
    most recent earlier user message (the final message is skipped) and, if the
    message right after it is from the assistant, that reply.

    Args:
        question: text of the current user question.
        session_id: id echoed to the gateway.
        messages: list of ``{"role", "content"}`` dicts (content already a string).
        actual_model: gateway model id (a MODEL_MAPPING value).

    Returns:
        The open ``requests.Response`` with ``stream=True``. The caller must
        consume and close it (it supports ``with``).

    Raises:
        requests.HTTPError: on a 4xx/5xx reply. The response is closed first;
            without this check an error body would silently become an empty answer.
        requests.RequestException: on connection failure or timeout.
    """
    previous_question = previous_answer = None
    # Walk back from the second-to-last message; the last is the one being asked now.
    for i in range(len(messages) - 2, -1, -1):
        if messages[i]["role"] == "user":
            previous_question = messages[i]["content"]
            # i <= len(messages) - 2, so messages[i + 1] always exists.
            if messages[i + 1]["role"] == "assistant":
                previous_answer = messages[i + 1]["content"]
            break

    payload = {
        "model": actual_model,
        "question": question,
        "language": "Chinese",
        "sessionId": session_id,
        "previousQuestion": previous_question,
        "previousAnswer": previous_answer
    }
    headers = {
        "Content-Type": "application/json",
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
    }

    resp = requests.post(
        "https://gateway.aiapilab.com/api/ha/v1/chat",
        json=payload,
        headers=headers,
        stream=True,
        # (connect, read) — with stream=True the read timeout applies between chunks.
        timeout=(10, 120),
    )
    try:
        resp.raise_for_status()
    except requests.HTTPError:
        resp.close()
        raise
    return resp
|
|
|
|
|
def stream_response(question, session_id, messages, request_model, actual_model):
    """Yield OpenAI-style ``chat.completion.chunk`` server-sent events for one answer.

    The gateway stream is read line by line and only ``data: `` lines are used:

    * ``[ANSWER_START]`` emits a chunk whose delta sets the assistant role.
    * Each later non-empty line not starting with ``[RELATE_Q`` emits a content chunk.
    * ``[ANSWER_DONE]`` emits a ``finish_reason="stop"`` chunk and ends the read.

    After the read ends, ``data: [DONE]`` is sent, which OpenAI-compatible
    clients wait for as the end-of-stream marker. The upstream response is
    closed when the generator finishes or is closed early (client disconnect);
    previously it was never closed.

    Yields:
        ``"data: <json>\\n\\n"`` strings, ending with ``"data: [DONE]\\n\\n"``.
    """
    def _event(delta, finish_reason=None):
        # Build one SSE line; finish_reason is included only on the final chunk.
        choice = {"index": 0, "delta": delta}
        if finish_reason is not None:
            choice["finish_reason"] = finish_reason
        chunk = {
            "id": session_id,
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": request_model,
            "choices": [choice]
        }
        return f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"

    is_answering = False
    with make_heck_request(question, session_id, messages, actual_model) as resp:
        for raw in resp.iter_lines():
            if not raw:
                continue
            line = raw.decode('utf-8')
            if not line.startswith('data: '):
                continue

            content = line[6:].strip()
            if content == "[ANSWER_START]":
                is_answering = True
                yield _event({"role": "assistant"})
            elif content == "[ANSWER_DONE]":
                yield _event({}, "stop")
                break
            elif is_answering and content and not content.startswith("[RELATE_Q"):
                yield _event({"content": content})

    yield "data: [DONE]\n\n"
|
|
|
|
|
def normal_response(question, session_id, messages, request_model, actual_model):
    """Collect one full gateway answer into an OpenAI-style ``chat.completion`` dict.

    Lines between ``[ANSWER_START]`` and ``[ANSWER_DONE]`` are concatenated
    without separators. Empty lines and ``[RELATE_Q...`` lines are skipped,
    matching what stream_response emits. The upstream streamed response is
    closed once reading stops; previously it was leaked.

    Returns:
        ``{"id", "object": "chat.completion", "created", "model", "choices": [...]}``
        with a single assistant message and ``finish_reason="stop"``.
    """
    full_content = []
    is_answering = False

    with make_heck_request(question, session_id, messages, actual_model) as resp:
        for raw in resp.iter_lines():
            if not raw:
                continue
            line = raw.decode('utf-8')
            if not line.startswith('data: '):
                continue

            content = line[6:].strip()
            if content == "[ANSWER_START]":
                is_answering = True
            elif content == "[ANSWER_DONE]":
                break
            elif is_answering and content and not content.startswith("[RELATE_Q"):
                full_content.append(content)

    return {
        "id": session_id,
        "object": "chat.completion",
        "created": int(time.time()),
        "model": request_model,
        "choices": [{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": "".join(full_content)
            },
            "finish_reason": "stop"
        }]
    }
|
|
|
|
|
|
|
|
@app.get("/hf/v1/models")
def list_models():
    """List the public model ids from MODEL_MAPPING in OpenAI ``/models`` format.

    Every entry is reported as owned by "heck", with ``created`` set to the
    current Unix time.
    """
    return {
        "object": "list",
        "data": [
            {
                "id": public_id,
                "object": "model",
                "created": int(time.time()),
                "owned_by": "heck",
            }
            for public_id in MODEL_MAPPING
        ],
    }
|
|
|
|
|
|
|
|
@app.post("/hf/v1/chat/completions")
async def chat_completions(request: Request):
    """OpenAI-compatible chat completion endpoint backed by the aiapilab gateway.

    Expects an OpenAI-style JSON body with these fields:

    * ``model``: a key of MODEL_MAPPING.
    * ``messages``: a list of ``{"role", "content"}`` dicts. ``content`` is a
      string or a list of ``{"text": ...}`` parts, joined with spaces.
    * ``stream`` (optional): request a streamed reply.

    The last user message is the question.

    The previous body was written against Flask: ``requests.request.json``,
    ``(body, status)`` tuples and ``Response(..., mimetype=...)``. Under
    FastAPI it failed on every call, and the tuples would have been sent as
    JSON lists with HTTP 200. Errors are now real JSON responses with proper
    status codes.

    Returns:
        401/400/500 ``{"error": ...}`` JSON on failure; a text/event-stream of
        chunks when ``stream`` is truthy; otherwise a ``chat.completion`` dict.
    """
    import hmac
    from fastapi.responses import JSONResponse, StreamingResponse

    def _error(message, status_code):
        return JSONResponse({"error": message}, status_code=status_code)

    # Accept "Bearer <key>" or the bare key; compare in constant time.
    auth_header = request.headers.get("Authorization") or ""
    if auth_header.startswith("Bearer "):
        token = auth_header[len("Bearer "):]
    else:
        token = auth_header
    if not auth_header or not hmac.compare_digest(token.encode("utf-8"), API_KEY.encode("utf-8")):
        return _error("Invalid API Key", 401)

    try:
        data = await request.json()
    except Exception:
        data = None

    if not isinstance(data, dict) or "model" not in data:
        return _error("Invalid request - missing model", 400)

    messages = data.get("messages")
    if not messages:
        return _error("Invalid request - missing messages", 400)
    if not isinstance(messages, list):
        return _error("Invalid message format", 400)

    for msg in messages:
        if not isinstance(msg, dict) or "role" not in msg or "content" not in msg:
            return _error("Invalid message format", 400)
        content = msg["content"]
        if isinstance(content, list):
            # Multi-part content: every part must carry text; flatten to one string.
            if not all(isinstance(item, dict) and "text" in item for item in content):
                return _error("Invalid content format", 400)
            msg["content"] = " ".join(item["text"] for item in content)
        elif not isinstance(content, str):
            return _error("Invalid content type", 400)

    actual_model = MODEL_MAPPING.get(data["model"])
    if not actual_model:
        return _error("Unsupported Model", 400)

    question = next(
        (msg["content"] for msg in reversed(messages) if msg["role"] == "user"),
        None,
    )
    if not question:
        return _error("No user message found", 400)

    session_id = str(uuid4())

    try:
        if data.get("stream"):
            # Starlette iterates this synchronous generator in a worker thread.
            return StreamingResponse(
                stream_response(question, session_id, messages, data["model"], actual_model),
                media_type="text/event-stream",
            )
        # normal_response blocks on network I/O; keep it off the event loop.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None,
            lambda: normal_response(question, session_id, messages, data["model"], actual_model),
        )
    except Exception as e:
        return _error(f"Internal server error: {str(e)}", 500)
|
|
|
|
|
if __name__ == "__main__":
    # Serve the app directly with uvicorn on all interfaces, port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)