"""Reverse proxy that forwards chat-completion and embedding requests upstream.

Requests received on the ``/v1/...`` and ``/hf/v1/...`` routes are relayed to
``models.inference.ai.azure.com``.  The caller may supply several
comma-separated API keys in a single ``Authorization: Bearer k1,k2,...``
header; the proxy then tries them in random order until one succeeds.
"""
import json
import logging
import random
from typing import Dict, List, Optional

import httpx
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# One shared client so upstream connections are pooled across requests.
client = httpx.AsyncClient()

BASE_URL_CHAT = "https://models.inference.ai.azure.com/chat/completions"
BASE_URL_EMBEDDINGS = "https://models.inference.ai.azure.com/embeddings"

_UPSTREAM_HOST = "models.inference.ai.azure.com"
_MAX_ATTEMPTS = 3
_UPSTREAM_TIMEOUT = 600  # seconds, passed to httpx for every upstream call

_PROXY_METHODS = ["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH"]

# Model ids advertised by the /v1/models endpoints, in listing order.
_MODEL_IDS = (
    "AI21-Jamba-1.5-Large",
    "AI21-Jamba-1.5-Mini",
    "Cohere-command-r",
    "Cohere-command-r-08-2024",
    "Cohere-command-r-plus",
    "Cohere-command-r-plus-08-2024",
    "Cohere-embed-v3-english",
    "Cohere-embed-v3-multilingual",
    "Llama-3.2-90B-Vision-Instruct",
    "Llama-3.2-11B-Vision-Instruct",
    "Meta-Llama-3.1-405B-Instruct",
    "Meta-Llama-3.1-70B-Instruct",
    "Meta-Llama-3.1-8B-Instruct",
    "Meta-Llama-3-70B-Instruct",
    "Meta-Llama-3-8B-Instruct",
    "Mistral-large",
    "Mistral-large-2407",
    "Mistral-Nemo",
    "Mistral-small",
    "Ministral-3B",
    "gpt-4o",
    "gpt-4o-mini",
    "o1-preview",
    "o1-mini",
    "text-embedding-3-large",
    "text-embedding-3-small",
    "Phi-3.5-MoE-instruct",
    "Phi-3.5-vision-instruct",
    "Phi-3.5-mini-instruct",
    "Phi-3-medium-128k-instruct",
    "Phi-3-medium-4k-instruct",
    "Phi-3-mini-128k-instruct",
    "Phi-3-mini-4k-instruct",
    "Phi-3-small-128k-instruct",
    "Phi-3-small-8k-instruct",
    "jais-30b-chat",
    "Llama-3.3-70B-Instruct",
    "Mistral-large-2411",
)
_MODEL_CREATED = 1709266800
_MODEL_OWNER = "f-droid"


def _mask_key(key: str) -> str:
    """Return a log-safe rendering of *key* that exposes at most its last 4 chars."""
    if len(key) <= 8:
        return "****"
    return f"****{key[-4:]}"


def _extract_api_keys(headers: Dict[str, str]) -> Optional[List[str]]:
    """Remove the ``authorization`` entry from *headers* and parse its keys.

    Returns the list of comma-separated keys found after the ``Bearer ``
    prefix (possibly empty), or ``None`` when the header is absent or does
    not use the ``Bearer`` scheme.
    """
    auth_header = headers.pop("authorization", "")
    if not auth_header.startswith("Bearer "):
        return None
    raw_keys = auth_header.replace("Bearer ", "").strip()
    return [k.strip() for k in raw_keys.split(",") if k.strip()]


async def process_request_body(body: bytes) -> bytes:
    """Strip the top-level ``"store"`` field from a JSON request body.

    Bodies that are not valid JSON (including empty or non-UTF-8 payloads)
    are returned unchanged so they can still be forwarded as-is.
    """
    try:
        data = json.loads(body)
    except ValueError:
        # Covers json.JSONDecodeError and UnicodeDecodeError (both ValueError
        # subclasses); the latter previously escaped and became a 500.
        return body
    if isinstance(data, dict) and "store" in data:
        del data["store"]
    return json.dumps(data).encode()


async def make_request(method, url, headers, body, api_keys=None, retry_count=0):
    """Send the request upstream, retrying up to three attempts in total.

    Args:
        method: HTTP method to use for the upstream call.
        url: Upstream URL.
        headers: Headers to forward; any ``Authorization`` entry is replaced
            by the key selected for each attempt.
        body: Raw request body bytes.
        api_keys: Candidate API keys.  With more than one key, each attempt
            uses a different randomly chosen key; with exactly one key the
            same key is retried.  When empty/``None`` the key is taken from
            an ``authorization`` entry in *headers*, if present.
        retry_count: Number of attempts already consumed.

    Returns:
        The first upstream response whose status code is below 400.

    Raises:
        HTTPException: 401 when no API key is available; 500 when every
            attempt failed (transport error or upstream status >= 400).
    """
    keys = list(api_keys) if api_keys else []
    if not keys:
        fallback = headers.get("authorization", "").replace("Bearer ", "").strip()
        if fallback:
            keys = [fallback]
    if not keys:
        # Previously an empty "Bearer " header was sent three times before
        # giving up with a generic 500; fail fast with a meaningful status.
        raise HTTPException(status_code=401, detail="Missing API key")

    # Drop any pre-existing authorization entry (in any case) so that only
    # the key chosen for this attempt is sent.
    base_headers = {k: v for k, v in headers.items() if k.lower() != "authorization"}

    rotating = len(keys) > 1
    remaining_keys = list(keys)

    while retry_count < _MAX_ATTEMPTS and remaining_keys:
        if rotating:
            selected_key = random.choice(remaining_keys)
            remaining_keys.remove(selected_key)
        else:
            selected_key = remaining_keys[0]

        masked = _mask_key(selected_key)
        attempt_headers = {**base_headers, "Authorization": f"Bearer {selected_key}"}
        logger.info("Attempting request with API key: %s", masked)
        try:
            r = await client.request(
                method,
                url,
                headers=attempt_headers,
                content=body,
                timeout=_UPSTREAM_TIMEOUT,
            )
        except Exception as e:
            # Best-effort boundary: any transport failure counts as a failed
            # attempt and moves on to the next key / retry.
            logger.error("Request failed with key %s: %s", masked, e)
        else:
            if r.status_code < 400:
                return r
            logger.error(
                "Request failed with key %s, status code: %s", masked, r.status_code
            )
        retry_count += 1

    if rotating:
        raise HTTPException(status_code=500, detail="All API keys failed")
    raise HTTPException(
        status_code=500, detail=f"Request failed after {_MAX_ATTEMPTS} retries"
    )


async def _forward(request: Request, target_url: str) -> Response:
    """Relay *request* to *target_url* and wrap the upstream reply.

    ``HTTPException`` raised while forwarding is propagated unchanged; any
    other error is logged and reported to the caller as a 500.
    """
    try:
        headers = dict(request.headers)
        # The body may be rewritten, so let httpx recompute the length.
        headers.pop("content-length", None)
        # Let httpx advertise only the encodings it can decode itself; the
        # decoded body is what gets relayed back to the caller.
        headers.pop("accept-encoding", None)
        headers.pop("host", None)
        headers["Host"] = _UPSTREAM_HOST

        api_keys = _extract_api_keys(headers)
        request_body = await request.body()
        processed_body = await process_request_body(request_body)
        r = await make_request(
            request.method, target_url, headers, processed_body, api_keys
        )
    except HTTPException:
        # Already carries the intended status and detail; do not re-wrap.
        raise
    except Exception as e:
        logger.error("Forwarding request failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

    # r.content is already decoded by httpx, so upstream headers such as
    # content-encoding / content-length must not be copied to the reply.
    response_headers = {"content-type": "application/json"}
    return Response(
        content=r.content, status_code=r.status_code, headers=response_headers
    )


@app.api_route("/v1/chat/completions", methods=_PROXY_METHODS)
@app.api_route("/hf/v1/chat/completions", methods=_PROXY_METHODS)
async def chat_completions(request: Request):
    """Forward a chat-completions request to the upstream chat endpoint."""
    return await _forward(request, BASE_URL_CHAT)


@app.api_route("/v1/embeddings", methods=["POST", "OPTIONS"])
@app.api_route("/hf/v1/embeddings", methods=["POST", "OPTIONS"])
async def embeddings(request: Request):
    """Forward an embeddings request to the upstream embeddings endpoint."""
    return await _forward(request, BASE_URL_EMBEDDINGS)


@app.get("/v1/models")
@app.get("/hf/v1/models")
async def list_models():
    """Return the static catalogue of model ids as an ``object: list`` payload."""
    return {
        "object": "list",
        "data": [
            {
                "id": model_id,
                "object": "model",
                "created": _MODEL_CREATED,
                "owned_by": _MODEL_OWNER,
            }
            for model_id in _MODEL_IDS
        ],
    }


@app.get("/health")
@app.get("/")
async def health_check():
    """Liveness probe; always reports ``{"status": "healthy"}``."""
    logger.info("Health check endpoint called")
    return {"status": "healthy"}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8080)