from fastapi import FastAPI, HTTPException, Security, Header
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from llama_cpp import Llama
from typing import List, Optional
import httpx
import os

app = FastAPI()

# Primary (hosted) backend configuration.
SERVICE_API_KEY = os.environ.get("SERVICE_API_KEY")
SERVICE_API_URL = "https://api.groq.com/openai/v1/chat/completions"
SERVICE_MODEL = "llama-3.3-70b-versatile"

# Shared secret expected from callers of this service.
EDYX_ACCESS_TOKEN = os.environ.get("EDYX_ACCESS_TOKEN")

SYSTEM_PROMPT = """You are a helpful, harmless, and honest AI assistant. Provide clear and conversational responses."""

# The local fallback model is loaded lazily on first use, so startup stays
# fast as long as the primary API is healthy.
local_llm = None


def get_local_llm():
    global local_llm
    if local_llm is None:
        print("Loading local fallback model...")
        local_llm = Llama(
            model_path="/models/model.gguf",
            n_ctx=4096,
            n_threads=2,
            n_batch=128,
            verbose=False,
        )
    return local_llm


class Message(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    messages: List[Message]
    max_tokens: Optional[int] = 1024
    temperature: Optional[float] = 0.7
    repetition_penalty: Optional[float] = 1.1


async def verify_token(x_edyx_token: Optional[str] = Header(None)):
    # If EDYX_ACCESS_TOKEN is unset, authentication is effectively disabled
    # (convenient for local development; do not deploy this way).
    if EDYX_ACCESS_TOKEN and x_edyx_token != EDYX_ACCESS_TOKEN:
        raise HTTPException(status_code=403, detail="Unauthorized: Invalid Access Token")
    return x_edyx_token


@app.get("/")
def root():
    return {"status": "edyx convo model running", "mode": "accelerated-primary"}


async def call_service_api(messages: List[Message], max_tokens: int, temperature: float):
    if not SERVICE_API_KEY:
        raise Exception("Service API key not configured")

    # Prepend the system prompt, then forward the conversation verbatim.
    service_messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    for m in messages:
        service_messages.append({"role": m.role, "content": m.content})

    async with httpx.AsyncClient(timeout=45.0) as client:
        response = await client.post(
            SERVICE_API_URL,
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {SERVICE_API_KEY}",
            },
            json={
                "model": SERVICE_MODEL,
                "messages": service_messages,
                "max_tokens": max_tokens,
                "temperature": temperature,
            },
        )

    if response.status_code != 200:
        raise Exception(f"Service API error: {response.status_code} - {response.text}")

    data = response.json()
    # Guard against a missing "usage" block rather than crashing into the
    # fallback path over a bookkeeping field.
    return (
        data["choices"][0]["message"]["content"],
        data.get("usage", {}).get("total_tokens", 0),
    )


def call_local_model(messages: List[Message], max_tokens: int, temperature: float,
                     repetition_penalty: float):
    llm = get_local_llm()

    # Build a simple "role: content" prompt. Note that a system message in the
    # request replaces the default SYSTEM_PROMPT (and discards any messages
    # accumulated before it).
    prompt = SYSTEM_PROMPT + "\n\n"
    for m in messages:
        role = m.role.lower()
        if role == "system":
            prompt = f"{m.content}\n\n"
        else:
            prompt += f"{role}: {m.content}\n"
    prompt += "assistant:"

    output = llm(
        prompt,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=0.9,
        repeat_penalty=repetition_penalty,
        stop=["user:", "assistant:", "<|end|>", "User:"],
    )
    return output["choices"][0]["text"].strip(), output["usage"]["total_tokens"]


@app.post("/v1/chat", dependencies=[Security(verify_token)])
async def chat(req: ChatRequest):
    # Try the hosted API first; fall back to the local GGUF model on any error.
    try:
        text, tokens = await call_service_api(req.messages, req.max_tokens, req.temperature)
        return {
            "model": "edyx-convo",
            "text": text,
            "tokens": tokens,
            "source": "primary",
        }
    except Exception as e:
        print(f"Service API failed: {e}, falling back to local model...")
        try:
            # llama.cpp inference is CPU-bound and synchronous; run it in a
            # worker thread so it does not block the event loop.
            text, tokens = await run_in_threadpool(
                call_local_model,
                req.messages, req.max_tokens, req.temperature, req.repetition_penalty,
            )
            return {
                "model": "edyx-convo",
                "text": text,
                "tokens": tokens,
                "source": "fallback",
            }
        except Exception as fallback_error:
            # Deliberately return 200 with an error payload so callers always
            # receive a well-formed response body.
            return {
                "model": "edyx-convo",
                "text": f"Error: Both primary and fallback failed. {fallback_error}",
                "tokens": 0,
                "source": "error",
            }
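
# Local entry point. A minimal sketch, assuming uvicorn is installed alongside
# FastAPI; the host and port below are illustrative defaults, not taken from
# the original deployment.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)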
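
# Example request against the /v1/chat endpoint (hypothetical values; assumes
# the server is running on localhost:8000 and EDYX_ACCESS_TOKEN is set to
# "secret"). FastAPI maps the x_edyx_token parameter to the x-edyx-token header:
#
#   curl -X POST http://localhost:8000/v1/chat \
#        -H "Content-Type: application/json" \
#        -H "x-edyx-token: secret" \
#        -d '{"messages": [{"role": "user", "content": "Hello!"}]}'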