Spaces:
Sleeping
Sleeping
| import os, time, uuid, json | |
| from typing import List, Optional | |
| from fastapi import FastAPI | |
| from pydantic import BaseModel, Field | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
# Model identifier handed to the Hugging Face loaders. It is read from the
# MODEL_NAME environment variable and falls back to "google/flan-t5-small"
# when that variable is not set.
MODEL_NAME = os.environ.get("MODEL_NAME", "google/flan-t5-small")

# Module-level cache slots for the tokenizer and model. Both start out as
# None and are filled in by load_model() the first time it runs.
_tokenizer = _model = None
def load_model():
    """Populate the module-level tokenizer and model caches on first use.

    When both ``_tokenizer`` and ``_model`` are already set, the call returns
    immediately. If either one is still ``None``, both are loaded from
    ``MODEL_NAME`` and stored in the module globals.
    """
    global _tokenizer, _model
    already_loaded = _tokenizer is not None and _model is not None
    if already_loaded:
        return
    _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    _model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
def messages_to_prompt(messages: List[dict]) -> str:
    """Flatten a list of chat messages into a single text prompt.

    Args:
        messages: Dicts read via ``.get("role")`` and ``.get("content")``.
            A missing or ``None`` value is treated as an empty string; the
            role is lower-cased and the content is stripped.

    Returns:
        A prompt made of two fixed instruction lines, every system message
        (one per line) followed by a blank line, the conversation turns
        prefixed with ``User:`` or ``Assistant:``, and a trailing
        ``Assistant:`` cue. Any role other than "system" or "user" is
        rendered as an assistant turn.
    """
    system_lines = []
    turns = []
    for message in messages:
        role = (message.get("role") or "").lower()
        text = (message.get("content") or "").strip()
        if role == "system":
            # Each system message keeps its own trailing newline.
            system_lines.append(text + "\n")
            continue
        speaker = "User" if role == "user" else "Assistant"
        turns.append(f"{speaker}: {text}")

    preamble = (
        "You are a strict instruction follower.\n"
        "If the user requests JSON, return ONLY valid JSON with no extra text.\n"
    )
    pieces = [
        preamble,
        "".join(system_lines),
        "\n",
        "\n".join(turns),
        "\nAssistant:",
    ]
    return "".join(pieces)
def generate(prompt: str, max_new_tokens: int = 256) -> str:
    """Run the cached model on *prompt* and return the decoded text.

    Args:
        prompt: Text passed to the tokenizer.
        max_new_tokens: Forwarded to the model's ``generate`` call under the
            same name.

    Returns:
        The first generated sequence, decoded with special tokens skipped and
        with leading/trailing whitespace stripped.
    """
    # Make sure the module-level tokenizer/model are available before use.
    load_model()

    # truncation=True is requested without an explicit length; the effective
    # limit is whatever the tokenizer applies by default (not shown here).
    encoded = _tokenizer(prompt, truncation=True, return_tensors="pt")

    # Sampling is switched off for every call.
    sequences = _model.generate(
        **encoded,
        do_sample=False,
        max_new_tokens=max_new_tokens,
    )

    decoded = _tokenizer.decode(sequences[0], skip_special_tokens=True)
    return decoded.strip()
# FastAPI application object for this module. The title string is the only
# configuration passed to it here.
app = FastAPI(title="My AI API (OpenAI-ish)")
# One chat turn supplied by the client.
class ChatMessage(BaseModel):
    # Speaker label. messages_to_prompt() lower-cases it and treats "system"
    # and "user" specially; every other value is rendered as an assistant turn.
    role: str
    # Text of the turn; it is stripped before being placed into the prompt.
    content: str
# Request body consumed by chat_completions().
class ChatReq(BaseModel):
    # Optional model name. chat_completions() echoes it back in the response
    # and substitutes MODEL_NAME when it is missing or empty; it does not
    # select which model is run.
    model: Optional[str] = None
    # Conversation turns, in the order they should appear in the prompt.
    messages: List[ChatMessage]
    # Cap on newly generated tokens; validated to the range 1..1024.
    max_tokens: int = Field(default=256, ge=1, le=1024)
    # Validated to the range 0.0..2.0, but chat_completions() never reads it
    # and generate() always runs with do_sample=False.
    temperature: float = Field(default=0.0, ge=0.0, le=2.0)
def health():
    """Return a fixed "ok" status together with the configured model name."""
    # NOTE(review): no route decorator is visible on this function in the
    # reviewed source -- confirm that it is actually registered on `app`.
    return dict(status="ok", model=MODEL_NAME)
def models():
    """Return a list-style payload whose single entry describes MODEL_NAME."""
    # NOTE(review): no route decorator is visible on this function in the
    # reviewed source -- confirm that it is actually registered on `app`.
    entry = {"id": MODEL_NAME, "object": "model", "owned_by": "me"}
    return {"object": "list", "data": [entry]}
def _is_valid_json(candidate: str) -> bool:
    """Return True when *candidate* parses as JSON."""
    try:
        json.loads(candidate)
    except (ValueError, RecursionError):
        # json.JSONDecodeError is a ValueError; RecursionError covers
        # pathologically nested input. Anything else is a real bug and
        # should surface instead of being swallowed.
        return False
    return True


def _extract_json(text: str) -> str:
    """Return the JSON payload found in *text*, or *text* unchanged.

    Output that is already valid JSON as a whole (for example a top-level
    array) is returned untouched. Otherwise the span from the first "{" to
    the last "}" is tried, and used only if it parses.
    """
    if _is_valid_json(text):
        return text
    start = text.find("{")
    end = text.rfind("}")
    if start != -1 and end > start:
        candidate = text[start:end + 1]
        if _is_valid_json(candidate):
            return candidate
    return text


def chat_completions(req: ChatReq):
    """Generate one assistant reply for *req* in a chat-completion style dict.

    The request messages are flattened with ``messages_to_prompt`` and passed
    to ``generate`` with ``req.max_tokens``. If any user message mentions
    "json" (case-insensitive), the reply is reduced to its JSON payload when
    one can be parsed; otherwise the raw text is returned as-is.

    Args:
        req: Parsed request body. ``req.temperature`` is not used.

    Returns:
        A dict with ``id``, ``object``, ``created``, ``model``, a single
        entry in ``choices``, a ``usage`` section whose counts are all
        ``None``, and ``latency_ms``.
    """
    # NOTE(review): no route decorator is visible on this function in the
    # reviewed source -- confirm that it is actually registered on `app`.

    # perf_counter() is monotonic, so the measured latency cannot go negative
    # or jump when the system clock is adjusted.
    started = time.perf_counter()

    prompt = messages_to_prompt([m.model_dump() for m in req.messages])
    text = generate(prompt, max_new_tokens=req.max_tokens)

    wants_json = any(
        "json" in m.content.lower()
        for m in req.messages
        if m.role.lower() == "user"
    )
    if wants_json:
        text = _extract_json(text)

    return {
        "id": f"chatcmpl-{uuid.uuid4().hex[:24]}",
        "object": "chat.completion",
        # Wall-clock time is the right clock for a creation timestamp.
        "created": int(time.time()),
        "model": req.model or MODEL_NAME,
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": text},
                "finish_reason": "stop",
            }
        ],
        "usage": {
            "prompt_tokens": None,
            "completion_tokens": None,
            "total_tokens": None,
        },
        "latency_ms": int((time.perf_counter() - started) * 1000),
    }