# app.py — edyx balanced chat service (Hugging Face Space, commit 710c75c)
from fastapi import FastAPI, HTTPException, Security, Header
from pydantic import BaseModel
from llama_cpp import Llama
from typing import List, Optional
import httpx
import os
app = FastAPI()
# Primary (hosted) backend: Groq's OpenAI-compatible chat-completions endpoint.
SERVICE_API_KEY = os.environ.get("SERVICE_API_KEY")
SERVICE_API_URL = "https://api.groq.com/openai/v1/chat/completions"
SERVICE_MODEL = "llama-3.3-70b-versatile"
# Shared secret checked against the X-Edyx-Token request header by verify_token().
EDYX_ACCESS_TOKEN = os.environ.get("EDYX_ACCESS_TOKEN")
# System prompt prepended to every conversation on both the primary and fallback paths.
SYSTEM_PROMPT = """You are a helpful, accurate, and thorough AI assistant.
Provide detailed, well-reasoned responses with clear explanations.
Focus on accuracy and depth in your answers."""
# Lazily-loaded llama.cpp fallback model; populated on first use by get_local_llm().
local_llm = None
def get_local_llm():
    """Return the local llama.cpp fallback model, loading it on first call.

    The loaded model is cached in the module-level ``local_llm`` so the
    expensive GGUF load happens at most once per process.
    """
    global local_llm
    if local_llm is not None:
        return local_llm
    print("Loading local fallback model...")
    local_llm = Llama(
        model_path="/models/model.gguf",
        n_ctx=4096,
        n_threads=2,
        n_batch=128,
        verbose=False,
    )
    return local_llm
class Message(BaseModel):
    """A single chat turn: a role (e.g. "user" / "assistant") and its text."""
    role: str
    content: str
class ChatRequest(BaseModel):
    """Request body for POST /v1/chat."""
    messages: List[Message]  # conversation to complete
    max_tokens: Optional[int] = 2048  # capped to 4096 by the endpoint
    temperature: Optional[float] = 0.3
    top_p: Optional[float] = 0.9  # NOTE(review): currently unused by the endpoint
async def verify_token(x_edyx_token: str = Header(None)):
    """FastAPI dependency: validate the X-Edyx-Token header.

    Raises HTTP 403 when the header does not match EDYX_ACCESS_TOKEN.
    NOTE(review): when EDYX_ACCESS_TOKEN is unset (or empty), the check is
    skipped entirely and every request is accepted — confirm this open-by-
    default behavior is intended for production.
    """
    if EDYX_ACCESS_TOKEN and x_edyx_token != EDYX_ACCESS_TOKEN:
        raise HTTPException(status_code=403, detail="Unauthorized: Invalid Access Token")
    return x_edyx_token
@app.get("/")
def root():
    """Health-check endpoint: reports that the service is up and its mode."""
    payload = {
        "status": "edyx balanced model running",
        "mode": "accelerated-primary",
    }
    return payload
async def call_service_api(messages: List[Message], max_tokens: int, temperature: float):
    """Send the conversation to the hosted Groq chat-completions API.

    Prepends SYSTEM_PROMPT as a system message, then forwards the request.
    Returns a ``(completion_text, total_tokens)`` tuple.
    Raises ``Exception`` when the API key is missing or the upstream call
    returns a non-200 status (the /v1/chat handler catches this to fall back).
    """
    if not SERVICE_API_KEY:
        raise Exception("Service API key not configured")
    request_headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {SERVICE_API_KEY}",
    }
    payload = {
        "model": SERVICE_MODEL,
        "messages": (
            [{"role": "system", "content": SYSTEM_PROMPT}]
            + [{"role": m.role, "content": m.content} for m in messages]
        ),
        "max_tokens": max_tokens,
        "temperature": temperature,
    }
    async with httpx.AsyncClient(timeout=60.0) as client:
        response = await client.post(SERVICE_API_URL, headers=request_headers, json=payload)
        if response.status_code != 200:
            raise Exception(f"Service API error: {response.status_code} - {response.text}")
        data = response.json()
    return data["choices"][0]["message"]["content"], data["usage"]["total_tokens"]
def call_local_model(messages: List[Message], max_tokens: int, temperature: float):
    """Run the conversation through the local GGUF fallback model.

    Builds a plain-text transcript ("ROLE: content" lines) prefixed with
    SYSTEM_PROMPT and ending with an "ASSISTANT:" cue, then samples a
    completion. Returns a ``(completion_text, total_tokens)`` tuple.
    """
    llm = get_local_llm()
    transcript = "".join(f"{m.role.upper()}: {m.content}\n" for m in messages)
    prompt = f"{SYSTEM_PROMPT}\n\n" + transcript + "ASSISTANT:\n"
    completion = llm(
        prompt,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=0.9,
        # Stop as soon as the model starts hallucinating the next turn.
        stop=["USER:", "SYSTEM:"],
    )
    return completion["choices"][0]["text"].strip(), completion["usage"]["total_tokens"]
@app.post("/v1/chat", dependencies=[Security(verify_token)])
async def chat(req: ChatRequest):
    """Chat completion with fallback.

    Tries the hosted Groq model first; on any failure, retries on the local
    GGUF model with a reduced token budget. Returns a dict with keys
    "model", "text", "tokens", and "source" ("primary" / "fallback" /
    "error"); the "error" shape is returned instead of raising so the client
    always receives a well-formed response.
    """
    # BUGFIX: the previous `req.max_tokens or 2048` / `req.temperature or 0.3`
    # silently replaced legitimate falsy values with the defaults — e.g. an
    # explicit temperature=0.0 (greedy decoding) became 0.3. Only substitute
    # the default when the field is actually absent (None).
    max_tokens = min(req.max_tokens if req.max_tokens is not None else 2048, 4096)
    temperature = req.temperature if req.temperature is not None else 0.3
    try:
        text, tokens = await call_service_api(req.messages, max_tokens, temperature)
        return {
            "model": "edyx-balanced",
            "text": text,
            "tokens": tokens,
            "source": "primary"
        }
    except Exception as primary_err:
        print(f"Service API failed: {primary_err}, falling back to local model...")
        try:
            # The local model is much slower, so cap its generation length.
            text, tokens = call_local_model(req.messages, min(max_tokens, 512), temperature)
            return {
                "model": "edyx-balanced",
                "text": text,
                "tokens": tokens,
                "source": "fallback"
            }
        except Exception as fallback_err:
            return {
                "model": "edyx-balanced",
                "text": f"Error: Both primary and fallback failed. {str(fallback_err)}",
                "tokens": 0,
                "source": "error"
            }