AlexGPT / main.py
MarneMorgan's picture
Create main.py
5639b62 verified
import os, time, uuid, json
from typing import List, Optional
from fastapi import FastAPI
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Model identifier handed to the transformers `from_pretrained` loaders;
# overridable through the MODEL_NAME environment variable.
MODEL_NAME = os.getenv("MODEL_NAME", "google/flan-t5-small")
# Module-level cache for the tokenizer and model. Both start as None and are
# populated by load_model() the first time it runs.
_tokenizer = None
_model = None
def load_model():
    """Populate the module-level tokenizer and model on first use.

    Does nothing when both globals are already set. Otherwise both are
    (re)loaded from ``MODEL_NAME`` via the transformers ``from_pretrained``
    loaders, tokenizer first.
    """
    global _tokenizer, _model
    already_loaded = _tokenizer is not None and _model is not None
    if already_loaded:
        return
    _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    _model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
def messages_to_prompt(messages: List[dict]) -> str:
    """Flatten a list of chat messages into a single text prompt.

    Each message is a dict read through ``.get("role")`` and
    ``.get("content")``; missing or falsy values are treated as empty
    strings. Roles are compared case-insensitively and content is stripped
    of surrounding whitespace.

    * ``system`` messages are appended, one per line, to the instruction
      preamble.
    * ``user`` messages become ``User: ...`` lines.
    * Every other role (including an empty one) becomes an
      ``Assistant: ...`` line.

    The prompt always ends with a bare ``Assistant:`` cue.
    """
    system_lines = []
    turns = []
    for message in messages:
        speaker = (message.get("role") or "").lower()
        body = (message.get("content") or "").strip()
        if speaker == "system":
            system_lines.append(body + "\n")
            continue
        label = "User" if speaker == "user" else "Assistant"
        turns.append(f"{label}: {body}")

    preamble = (
        "You are a strict instruction follower.\n"
        "If the user requests JSON, return ONLY valid JSON with no extra text.\n"
    )
    system_block = "".join(system_lines)
    dialogue = "\n".join(turns)
    return preamble + system_block + "\n" + dialogue + "\nAssistant:"
def generate(prompt: str, max_new_tokens: int = 256) -> str:
    """Run the model on *prompt* and return the decoded text.

    Ensures the model is loaded, tokenizes the prompt with
    ``truncation=True``, generates with sampling disabled
    (``do_sample=False``), and decodes only the first returned sequence.

    Args:
        prompt: Full prompt text to feed to the model.
        max_new_tokens: Passed through to the model's ``generate`` call.

    Returns:
        The decoded first sequence with special tokens skipped and
        surrounding whitespace stripped.
    """
    load_model()
    encoded = _tokenizer(prompt, return_tensors="pt", truncation=True)
    sequences = _model.generate(
        **encoded,
        max_new_tokens=max_new_tokens,
        do_sample=False,
    )
    first_sequence = sequences[0]
    decoded = _tokenizer.decode(first_sequence, skip_special_tokens=True)
    return decoded.strip()
# FastAPI application object; the @app.get / @app.post decorators in this file
# register their route handlers on it.
app = FastAPI(title="My AI API (OpenAI-ish)")
class ChatMessage(BaseModel):
    """One chat turn inside a ChatReq."""

    # Speaker label. messages_to_prompt() matches "system" and "user"
    # case-insensitively; any other value is rendered as an assistant turn.
    role: str
    # Text of the turn.
    content: str
class ChatReq(BaseModel):
    """Request body accepted by POST /v1/chat/completions."""

    # Optional model name from the client. It is only echoed back in the
    # response; the model actually used is always MODEL_NAME.
    model: Optional[str] = None
    # Conversation turns, in order.
    messages: List[ChatMessage]
    # Generation length cap (1..1024, default 256); forwarded to generate()
    # as max_new_tokens.
    max_tokens: int = Field(default=256, ge=1, le=1024)
    # Validated to the range 0.0..2.0, but no code visible in this file reads
    # it; generation runs with do_sample=False regardless.
    temperature: float = Field(default=0.0, ge=0.0, le=2.0)
@app.get("/health")
def health():
    """Report a fixed "ok" status together with the configured model name."""
    payload = {"status": "ok", "model": MODEL_NAME}
    return payload
@app.get("/v1/models")
def models():
    """List the single configured model as a list-style response."""
    model_entry = {"id": MODEL_NAME, "object": "model", "owned_by": "me"}
    return {"object": "list", "data": [model_entry]}
def _extract_json_object(text: str) -> str:
    """Return the outermost ``{...}`` span of *text* if it is valid JSON.

    The span runs from the first ``{`` to the last ``}``. If there is no such
    span, or it does not parse, *text* is returned unchanged (best effort).
    """
    start = text.find("{")
    end = text.rfind("}")
    if start == -1 or end <= start:
        return text
    candidate = text[start:end + 1]
    try:
        json.loads(candidate)
    except (ValueError, RecursionError):
        # ValueError covers json.JSONDecodeError; RecursionError covers
        # pathologically nested output. Either way, keep the raw model text.
        return text
    return candidate


@app.post("/v1/chat/completions")
def chat_completions(req: ChatReq):
    """Handle a chat-completion request and return a chat.completion-style dict.

    The messages are flattened into one prompt, passed to ``generate`` with
    ``req.max_tokens`` as the generation cap, and the result is wrapped in a
    single-choice response. If any user message mentions "json"
    (case-insensitively), the reply is trimmed to its outermost ``{...}`` span
    when that span is valid JSON.

    Args:
        req: Validated request body.

    Returns:
        A dict with ``id``, ``object``, ``created``, ``model``, ``choices``,
        ``usage`` and ``latency_ms`` keys. Token counts in ``usage`` are not
        computed and are reported as ``None``.
    """
    # Monotonic clock: elapsed-time measurement must not be affected by
    # wall-clock adjustments.
    started = time.perf_counter()

    # NOTE: req.temperature is validated but not forwarded; generate() takes
    # no sampling parameters.
    prompt = messages_to_prompt([m.model_dump() for m in req.messages])
    text = generate(prompt, max_new_tokens=req.max_tokens)

    wants_json = any(
        "json" in m.content.lower()
        for m in req.messages
        if m.role.lower() == "user"
    )
    if wants_json:
        text = _extract_json_object(text)

    return {
        "id": f"chatcmpl-{uuid.uuid4().hex[:24]}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": req.model or MODEL_NAME,
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": text},
                "finish_reason": "stop",
            }
        ],
        "usage": {"prompt_tokens": None, "completion_tokens": None, "total_tokens": None},
        "latency_ms": int((time.perf_counter() - started) * 1000),
    }