File size: 4,266 Bytes
99e15e7
f97ce08
 
f5c37d3
992fbe7
 
f97ce08
 
 
99e15e7
 
 
562d032
99e15e7
f5c37d3
992fbe7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f97ce08
 
 
 
 
 
562d032
f5c37d3
 
 
 
99e15e7
 
 
 
 
992fbe7
 
99e15e7
992fbe7
99e15e7
 
 
992fbe7
99e15e7
992fbe7
99e15e7
992fbe7
 
 
99e15e7
992fbe7
 
99e15e7
992fbe7
 
99e15e7
 
992fbe7
 
 
 
 
 
99e15e7
992fbe7
 
 
 
 
 
 
562d032
992fbe7
562d032
f5c37d3
992fbe7
f5c37d3
992fbe7
562d032
f97ce08
 
 
992fbe7
 
f97ce08
992fbe7
f5c37d3
f97ce08
992fbe7
 
f97ce08
99e15e7
992fbe7
 
99e15e7
992fbe7
 
 
 
 
 
 
99e15e7
992fbe7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
from fastapi import FastAPI, HTTPException, Security, Header
from pydantic import BaseModel
from llama_cpp import Llama
from typing import List, Optional
import httpx
import os

app = FastAPI()

# Primary (hosted) completion service configuration. If SERVICE_API_KEY is
# unset, call_service_api raises and requests fall through to the local model.
SERVICE_API_KEY = os.environ.get("SERVICE_API_KEY")
SERVICE_API_URL = "https://api.groq.com/openai/v1/chat/completions"
SERVICE_MODEL = "llama-3.3-70b-versatile"

# Shared secret expected in the X-Edyx-Token header; when unset/empty,
# verify_token performs no check (fail-open).
EDYX_ACCESS_TOKEN = os.environ.get("EDYX_ACCESS_TOKEN")

# Default system prompt, prepended to every primary request and used as the
# transcript header for the local fallback model.
SYSTEM_PROMPT = """You are a helpful, harmless, and honest AI assistant.
Provide clear and conversational responses."""

# Lazily-initialized llama_cpp model instance; populated by get_local_llm().
local_llm = None

def get_local_llm():
    """Return the shared llama_cpp model, loading it on first use.

    The instance is cached in the module-level ``local_llm`` so repeated
    fallback calls reuse one loaded model instead of reloading the GGUF file.
    """
    global local_llm
    if local_llm is not None:
        return local_llm
    print("Loading local fallback model...")
    local_llm = Llama(
        model_path="/models/model.gguf",
        n_ctx=4096,
        n_threads=2,
        n_batch=128,
        verbose=False,
    )
    return local_llm

class Message(BaseModel):
    """A single chat turn: a role (e.g. "system"/"user"/"assistant") and its text."""
    role: str
    content: str

class ChatRequest(BaseModel):
    """Request body for POST /v1/chat.

    Note: repetition_penalty is only honored by the local fallback model;
    the primary service call does not forward it.
    """
    messages: List[Message]
    max_tokens: Optional[int] = 1024
    temperature: Optional[float] = 0.7
    repetition_penalty: Optional[float] = 1.1

async def verify_token(x_edyx_token: str = Header(None)):
    """FastAPI dependency validating the X-Edyx-Token request header.

    NOTE(review): fail-open — when EDYX_ACCESS_TOKEN is unset or empty,
    every request is accepted without any check. Confirm this is intended
    for production deployments.

    Raises:
        HTTPException: 403 when a token is configured and the header mismatches.
    """
    if EDYX_ACCESS_TOKEN and x_edyx_token != EDYX_ACCESS_TOKEN:
        raise HTTPException(status_code=403, detail="Unauthorized: Invalid Access Token")
    return x_edyx_token

@app.get("/")
def root():
    """Liveness endpoint: report that the service is up and in which mode."""
    payload = {
        "status": "edyx convo model running",
        "mode": "accelerated-primary",
    }
    return payload

async def call_service_api(messages: List[Message], max_tokens: int, temperature: float):
    """Call the hosted chat-completions API and return (text, total_tokens).

    Args:
        messages: Conversation turns; the default system prompt is prepended.
        max_tokens: Completion token budget forwarded to the service.
        temperature: Sampling temperature forwarded to the service.

    Raises:
        RuntimeError: if the API key is missing or the service answers with a
            non-200 status. (Callers catch ``Exception`` and fall back to the
            local model, so this remains backward-compatible.)
        httpx.HTTPError: on transport-level failures (timeout, DNS, ...).
    """
    if not SERVICE_API_KEY:
        # Specific exception type instead of bare Exception (still a subclass
        # of Exception, so chat()'s fallback path is unaffected).
        raise RuntimeError("Service API key not configured")

    # Prepend the system prompt, then forward the caller's turns verbatim.
    service_messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    service_messages.extend({"role": m.role, "content": m.content} for m in messages)

    async with httpx.AsyncClient(timeout=45.0) as client:
        response = await client.post(
            SERVICE_API_URL,
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {SERVICE_API_KEY}"
            },
            json={
                "model": SERVICE_MODEL,
                "messages": service_messages,
                "max_tokens": max_tokens,
                "temperature": temperature
            }
        )

        if response.status_code != 200:
            raise RuntimeError(f"Service API error: {response.status_code} - {response.text}")

        data = response.json()
        return data["choices"][0]["message"]["content"], data["usage"]["total_tokens"]

def call_local_model(messages: List[Message], max_tokens: int, temperature: float, repetition_penalty: float):
    """Generate a completion with the local GGUF fallback model.

    Builds a simple ``role: content`` transcript prompt headed by the system
    prompt and returns ``(generated_text, total_tokens)``.

    Bug fix: a "system" message in the middle of the list used to RESET the
    entire prompt (``prompt = f"{m.content}\\n\\n"``), silently discarding every
    earlier conversation turn. A system message now only replaces the system
    prompt; accumulated turns are preserved.
    """
    llm = get_local_llm()

    system_text = SYSTEM_PROMPT
    convo_lines = []
    for m in messages:
        role = m.role.lower()
        if role == "system":
            # Override the default system prompt without dropping history.
            system_text = m.content
        else:
            convo_lines.append(f"{role}: {m.content}\n")
    prompt = system_text + "\n\n" + "".join(convo_lines) + "assistant:"

    output = llm(
        prompt,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=0.9,
        repeat_penalty=repetition_penalty,
        stop=["user:", "assistant:", "<|end|>", "User:"]
    )

    return output["choices"][0]["text"].strip(), output["usage"]["total_tokens"]

@app.post("/v1/chat", dependencies=[Security(verify_token)])
async def chat(req: ChatRequest):
    """Chat endpoint: try the hosted service first, then the local model.

    Always answers 200 with a payload whose "source" field identifies which
    path produced it: "primary", "fallback", or "error".
    """
    try:
        answer, used = await call_service_api(req.messages, req.max_tokens, req.temperature)
    except Exception as e:
        print(f"Service API failed: {e}, falling back to local model...")
    else:
        return {
            "model": "edyx-convo",
            "text": answer,
            "tokens": used,
            "source": "primary"
        }

    try:
        answer, used = call_local_model(
            req.messages,
            req.max_tokens,
            req.temperature,
            req.repetition_penalty
        )
    except Exception as e:
        return {
            "model": "edyx-convo",
            "text": f"Error: Both primary and fallback failed. {str(e)}",
            "tokens": 0,
            "source": "error"
        }
    return {
        "model": "edyx-convo",
        "text": answer,
        "tokens": used,
        "source": "fallback"
    }