from fastapi import FastAPI, HTTPException, Header
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import openai
from typing import List, Optional
import logging
from itertools import cycle
import asyncio

import uvicorn

from app import config
import requests
from datetime import datetime, timezone

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

app = FastAPI()

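# CORS is left wide open so any frontend can reach the proxy; tighten
# allow_origins for anything beyond a personal deployment.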
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

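# Upstream API keys come from the app/config.py settings object.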
API_KEYS = config.settings.API_KEYS

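# Requests rotate through the key pool round-robin; the asyncio.Lock keeps
# concurrent requests from advancing the cycle at the same time.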
key_cycle = cycle(API_KEYS)
key_lock = asyncio.Lock()


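# The request body mirrors OpenAI's /v1/chat/completions schema, so existing
# OpenAI clients can point at this proxy unchanged.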
class ChatRequest(BaseModel):
    messages: List[dict]
    model: str = "llama-3.2-90b-text-preview"
    temperature: Optional[float] = 0.7
    stream: Optional[bool] = False


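# Proxy-side auth: clients authenticate with a bearer token from
# ALLOWED_TOKENS, which is independent of the upstream keys rotated above.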
async def verify_authorization(authorization: str = Header(None)):
    if not authorization:
        logger.error("Missing Authorization header")
        raise HTTPException(status_code=401, detail="Missing Authorization header")
    if not authorization.startswith("Bearer "):
        logger.error("Invalid Authorization header format")
        raise HTTPException(
            status_code=401, detail="Invalid Authorization header format"
        )
    token = authorization.replace("Bearer ", "")
    if token not in config.settings.ALLOWED_TOKENS:
        logger.error("Invalid token")
        raise HTTPException(status_code=401, detail="Invalid token")
    return token


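# Fetch the upstream model list from the Gemini REST API.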
def get_gemini_models(api_key):
    base_url = "https://generativelanguage.googleapis.com/v1beta"
    url = f"{base_url}/models?key={api_key}"

    try:
        # requests has no default timeout; 10 s is an arbitrary safeguard.
        response = requests.get(url, timeout=10)
        if response.status_code == 200:
            gemini_models = response.json()
            return convert_to_openai_format(gemini_models)
        logger.error(f"Error: {response.status_code} - {response.text}")
        return None
    except requests.RequestException as e:
        logger.error(f"Request failed: {e}")
        return None


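# Reshape the Gemini models payload into OpenAI's GET /v1/models list format.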
def convert_to_openai_format(gemini_models):
    openai_format = {"object": "list", "data": []}

    for model in gemini_models.get("models", []):
        openai_model = {
            # Gemini names look like "models/gemini-1.5-pro"; keep the last segment.
            "id": model["name"].split("/")[-1],
            "object": "model",
            "created": int(datetime.now(timezone.utc).timestamp()),
            "owned_by": "google",
            "permission": [],
            "root": model["name"],
            "parent": None,
        }
        openai_format["data"].append(openai_model)

    return openai_format


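# Model listing, exposed at /v1 and at /hf/v1 (the extra prefix suggests a
# Hugging Face Spaces deployment).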
@app.get("/v1/models")
@app.get("/hf/v1/models")
async def list_models(authorization: str = Header(None)):
    await verify_authorization(authorization)
    async with key_lock:
        api_key = next(key_cycle)
        logger.info(f"Using API key: {api_key[:8]}...")
    try:
        response = get_gemini_models(api_key)
        if response is None:
            # get_gemini_models logs the upstream failure and returns None.
            raise HTTPException(status_code=500, detail="Failed to fetch models")
        logger.info("Successfully retrieved models list")
        return response
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error listing models: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


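# Chat completions are forwarded through the OpenAI SDK pointed at
# config.settings.BASE_URL, i.e. any OpenAI-compatible upstream.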
@app.post("/v1/chat/completions")
@app.post("/hf/v1/chat/completions")
async def chat_completion(request: ChatRequest, authorization: str = Header(None)):
    await verify_authorization(authorization)
    async with key_lock:
        api_key = next(key_cycle)
        logger.info(f"Using API key: {api_key[:8]}...")

    try:
        logger.info(f"Chat completion request - Model: {request.model}")
        client = openai.OpenAI(api_key=api_key, base_url=config.settings.BASE_URL)
        response = client.chat.completions.create(
            model=request.model,
            messages=request.messages,
            temperature=request.temperature,
            # stream is always defined on ChatRequest (defaults to False).
            stream=request.stream,
        )

        if request.stream:
            logger.info("Streaming response enabled")

            async def generate():
                # The OpenAI client yields chunks synchronously; each chunk is
                # re-serialized as a server-sent event.
                for chunk in response:
                    yield f"data: {chunk.model_dump_json()}\n\n"
                # OpenAI-compatible streams terminate with a [DONE] sentinel.
                yield "data: [DONE]\n\n"

            return StreamingResponse(
                content=generate(), media_type="text/event-stream"
            )

        logger.info("Chat completion successful")
        return response

    except Exception as e:
        logger.error(f"Error in chat completion: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/health")
@app.get("/")
async def health_check():
    logger.info("Health check endpoint called")
    return {"status": "healthy"}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
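
# Example call against a local run (the token must appear in ALLOWED_TOKENS):
#   curl http://localhost:8000/v1/models \
#        -H "Authorization: Bearer <token>"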