"""Minimal OpenAI-compatible chat-completions API backed by a local seq2seq model.

The model named by the ``MODEL_NAME`` environment variable (default
``google/flan-t5-small``) is loaded lazily on first use and served through
``/v1/chat/completions``; ``/v1/models`` and ``/health`` expose metadata.
"""

import json
import os
import threading
import time
import uuid
from typing import List, Optional

from fastapi import FastAPI
from pydantic import BaseModel, Field
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

MODEL_NAME = os.getenv("MODEL_NAME", "google/flan-t5-small")

_tokenizer = None
_model = None
# FastAPI runs plain ``def`` endpoints in a worker thread pool, so several
# first requests can enter load_model() concurrently; the lock makes sure the
# weights are loaded only once.
_load_lock = threading.Lock()


def load_model():
    """Load and cache the tokenizer and model for MODEL_NAME (idempotent)."""
    global _tokenizer, _model
    if _tokenizer is not None and _model is not None:
        return  # fast path once both are loaded
    with _load_lock:
        if _tokenizer is None or _model is None:
            tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
            # The prompt ends with the newest turn and the "Assistant:" cue;
            # if it exceeds the model's input limit, drop the oldest tokens
            # rather than the ones the answer depends on.
            tokenizer.truncation_side = "left"
            model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
            # Publish only after both have loaded successfully.
            _tokenizer, _model = tokenizer, model


def messages_to_prompt(messages: List[dict]) -> str:
    """Flatten chat messages into a single text prompt.

    All "system" messages are concatenated into a preamble. "user" turns are
    rendered as ``User: ...`` and every other role as ``Assistant: ...``.
    The prompt ends with ``Assistant:`` so the model continues as the assistant.
    """
    system = ""
    convo = []
    for m in messages:
        role = (m.get("role") or "").lower()
        content = (m.get("content") or "").strip()
        if role == "system":
            system += content + "\n"
        elif role == "user":
            convo.append(f"User: {content}")
        else:
            convo.append(f"Assistant: {content}")
    return (
        "You are a strict instruction follower.\n"
        "If the user requests JSON, return ONLY valid JSON with no extra text.\n"
        f"{system}\n"
        + "\n".join(convo)
        + "\nAssistant:"
    )


def _run_generation(prompt: str, max_new_tokens: int, temperature: float):
    """Run the model on *prompt*.

    Decoding is greedy when ``temperature <= 0`` and sampled at
    ``temperature`` otherwise.

    Returns:
        ``(text, prompt_tokens, completion_tokens, finish_reason)`` where
        ``finish_reason`` is ``"length"`` if generation hit
        ``max_new_tokens`` without emitting EOS, else ``"stop"``.
    """
    load_model()
    inputs = _tokenizer(prompt, return_tensors="pt", truncation=True)
    if temperature > 0:
        gen_kwargs = {"do_sample": True, "temperature": temperature}
    else:
        gen_kwargs = {"do_sample": False}
    out = _model.generate(**inputs, max_new_tokens=max_new_tokens, **gen_kwargs)
    seq = out[0]
    text = _tokenizer.decode(seq, skip_special_tokens=True).strip()

    prompt_tokens = int(inputs["input_ids"].shape[-1])
    # Encoder-decoder output begins with the decoder start token, which the
    # model did not generate, so it is excluded from the completion count.
    completion_tokens = max(int(seq.shape[-1]) - 1, 0)
    eos_id = _tokenizer.eos_token_id
    hit_eos = eos_id is not None and int(seq[-1]) == eos_id
    finish_reason = (
        "length" if completion_tokens >= max_new_tokens and not hit_eos else "stop"
    )
    return text, prompt_tokens, completion_tokens, finish_reason


def generate(prompt: str, max_new_tokens: int = 256, temperature: float = 0.0) -> str:
    """Return the model's completion text for *prompt*.

    Args:
        prompt: Full text prompt (see ``messages_to_prompt``).
        max_new_tokens: Upper bound on generated tokens.
        temperature: 0 (default) for deterministic greedy decoding; a positive
            value enables sampling at that temperature.
    """
    return _run_generation(prompt, max_new_tokens, temperature)[0]


def _extract_json(text: str) -> str:
    """Return the JSON document embedded in *text*, or *text* unchanged.

    The whole string is tried first. After that come the widest ``{...}`` and
    ``[...]`` spans, ordered by where they open so that an outer container
    wins over a nested one. The first candidate that parses is returned.
    """
    spans = []
    for open_ch, close_ch in (("{", "}"), ("[", "]")):
        start = text.find(open_ch)
        end = text.rfind(close_ch)
        if start != -1 and end > start:
            spans.append((start, text[start:end + 1]))
    candidates = [text] + [span for _, span in sorted(spans)]
    for candidate in candidates:
        try:
            json.loads(candidate)
        except (ValueError, RecursionError):
            # Not valid JSON (JSONDecodeError is a ValueError) or absurdly
            # deep nesting: extraction is best-effort, try the next span.
            continue
        return candidate
    return text


app = FastAPI(title="My AI API (OpenAI-ish)")


class ChatMessage(BaseModel):
    """One chat turn as sent by OpenAI-style clients."""

    role: str
    content: str


class ChatReq(BaseModel):
    """Request body for ``/v1/chat/completions`` (subset of the OpenAI schema)."""

    model: Optional[str] = None
    messages: List[ChatMessage]
    max_tokens: int = Field(default=256, ge=1, le=1024)
    temperature: float = Field(default=0.0, ge=0.0, le=2.0)


@app.get("/health")
def health():
    """Liveness probe; also reports the configured model name."""
    return {"status": "ok", "model": MODEL_NAME}


@app.get("/v1/models")
def models():
    """List the single served model in OpenAI ``/v1/models`` format."""
    return {"object": "list", "data": [{"id": MODEL_NAME, "object": "model", "owned_by": "me"}]}


@app.post("/v1/chat/completions")
def chat_completions(req: ChatReq):
    """Generate one assistant reply for the conversation in *req*.

    If any user message mentions "json", the reply is trimmed to the embedded
    JSON document when one can be parsed out of it.
    """
    t0 = time.perf_counter()
    prompt = messages_to_prompt([m.model_dump() for m in req.messages])
    text, prompt_tokens, completion_tokens, finish_reason = _run_generation(
        prompt, req.max_tokens, req.temperature
    )

    user_text = "\n".join(m.content.lower() for m in req.messages if m.role.lower() == "user")
    if "json" in user_text:
        text = _extract_json(text)

    return {
        "id": f"chatcmpl-{uuid.uuid4().hex[:24]}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": req.model or MODEL_NAME,
        "choices": [
            {"index": 0, "message": {"role": "assistant", "content": text}, "finish_reason": finish_reason}
        ],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
        "latency_ms": int((time.perf_counter() - t0) * 1000),
    }