"""OpenAI-compatible chat-completions API serving a Hugging Face seq2seq model on CPU."""

import logging
import os
import time
import uuid
from typing import List, Optional

import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)

# ── Configuration (overridable via environment variables) ─────────────────
MODEL_ID = os.getenv("MODEL_ID", "google/flan-t5-large")
# Default number of generated tokens per request, and also the server-side
# upper bound on what a client may request.
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "1024"))
# Maximum prompt length in tokens; longer prompts are truncated by the tokenizer.
MAX_INPUT_LEN = int(os.getenv("MAX_INPUT_LEN", "512"))

app = FastAPI(title="T2T OpenAI-Compatible API", version="3.0.0")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Populated by load_model() at startup; both stay None until the model is ready.
_pipe = None
_tokenizer = None


@app.on_event("startup")
def load_model():
    """Load the tokenizer and seq2seq model on CPU and build the generation pipeline.

    Sets the module-level ``_tokenizer`` and ``_pipe`` globals, which the
    request handlers use (``_pipe is None`` means "not ready yet").
    """
    global _pipe, _tokenizer
    logger.info(f"⏳ Carregando {MODEL_ID} …")
    _tokenizer = AutoTokenizer.from_pretrained(
        MODEL_ID,
        use_fast=False,
        # Makes `truncation=True` at inference time cut prompts at MAX_INPUT_LEN.
        model_max_length=MAX_INPUT_LEN,
    )
    model = AutoModelForSeq2SeqLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    )
    model.eval()
    # An encoder-decoder model needs the text2text pipeline. "text-generation"
    # targets decoder-only models: with return_full_text=False it slices the
    # prompt's length off the decoder output, which never contained the prompt,
    # so completions came back truncated or empty.
    _pipe = pipeline(
        "text2text-generation",
        model=model,
        tokenizer=_tokenizer,
        device=-1,  # force CPU
    )
    logger.info(f"✅ {MODEL_ID} pronto!")


# ── Schemas (OpenAI-compatible) ───────────────────────────────────────────
class Message(BaseModel):
    """A single chat message as sent by an OpenAI-style client."""

    role: str
    content: str


class ResponseFormat(BaseModel):
    """Requested response format; accepted for compatibility but not enforced."""

    type: str = "text"


class ChatCompletionRequest(BaseModel):
    """Body of ``POST /v1/chat/completions``."""

    model: str = Field(default=MODEL_ID)
    messages: List[Message]
    temperature: float = 0.7
    # Out-of-range values would make generation fail with a 500; reject them
    # up front with a 422 instead.
    top_p: float = Field(default=0.9, ge=0.0, le=1.0)
    # Newer OpenAI name for the token limit; takes precedence over max_tokens.
    max_completion_tokens: Optional[int] = None
    max_tokens: Optional[int] = None
    response_format: Optional[ResponseFormat] = None  # accepted but ignored
    stream: bool = False

    class Config:
        populate_by_name = True


class ChoiceMessage(BaseModel):
    """The assistant message inside a completion choice."""

    role: str = "assistant"
    content: str


class Choice(BaseModel):
    """One completion alternative (this server always returns exactly one)."""

    index: int
    message: ChoiceMessage
    finish_reason: str = "stop"


class Usage(BaseModel):
    """Token accounting reported back to the client."""

    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class ChatCompletionResponse(BaseModel):
    """Response body of ``POST /v1/chat/completions``."""

    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[Choice]
    usage: Usage


# ── Helpers ───────────────────────────────────────────────────────────────
# Prompt prefix per chat role. "developer" is OpenAI's newer name for the
# system role. Messages with any other role are dropped from the prompt.
_ROLE_PREFIXES = {
    "system": "Instructions",
    "developer": "Instructions",
    "user": "User",
    "assistant": "Assistant",
}


def messages_to_prompt(messages: List[Message]) -> str:
    """Flatten chat messages into one plain-text prompt for the seq2seq model.

    Each message becomes ``"<Prefix>: <content>"`` (see ``_ROLE_PREFIXES``), and
    the parts are joined with single spaces in their original order.
    """
    return " ".join(
        f"{_ROLE_PREFIXES[m.role]}: {m.content}"
        for m in messages
        if m.role in _ROLE_PREFIXES
    )


def token_count(text: str) -> int:
    """Return the number of tokens in *text*, excluding special tokens.

    Requires ``load_model()`` to have run (uses the global ``_tokenizer``).
    """
    return len(_tokenizer(text, add_special_tokens=False)["input_ids"])


# ── Endpoints ─────────────────────────────────────────────────────────────
@app.get("/")
def root():
    """Basic liveness probe."""
    return {"status": "ok", "model": MODEL_ID}


@app.get("/health")
def health():
    """Health probe; ``ready`` is True once the model pipeline is loaded."""
    return {"status": "healthy", "model": MODEL_ID, "ready": _pipe is not None}


@app.get("/v1/models")
def list_models():
    """OpenAI-style model listing containing the single served model."""
    return {"object": "list", "data": [
        {"id": MODEL_ID, "object": "model", "owned_by": "huggingface"}
    ]}


@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
def chat_completions(req: ChatCompletionRequest):
    """Generate a single non-streaming chat completion.

    Raises:
        HTTPException(501): streaming was requested.
        HTTPException(503): the model has not finished loading.
        HTTPException(500): generation failed.
    """
    if req.stream:
        raise HTTPException(501, "Streaming não suportado.")
    if _pipe is None:
        raise HTTPException(503, "Modelo não carregado.")

    requested = req.max_completion_tokens or req.max_tokens or MAX_NEW_TOKENS
    # Clamp to [1, MAX_NEW_TOKENS]: non-positive values make generation fail,
    # and unbounded values would let one request monopolise the CPU.
    max_tokens = max(1, min(requested, MAX_NEW_TOKENS))
    prompt = messages_to_prompt(req.messages)
    # Near-zero temperature means deterministic (greedy) decoding; neutral
    # temperature/top_p are then passed so they have no effect.
    do_sample = req.temperature > 0.05

    try:
        output = _pipe(
            prompt,
            max_new_tokens=max_tokens,
            truncation=True,
            temperature=float(req.temperature) if do_sample else 1.0,
            top_p=float(req.top_p) if do_sample else 1.0,
            do_sample=do_sample,
            repetition_penalty=1.2,
        )
    except Exception as e:
        # logger.exception keeps the traceback for server-side debugging.
        logger.exception(f"Inference error: {e}")
        raise HTTPException(500, str(e)) from e

    text = output[0]["generated_text"].strip()
    # The model never sees more than MAX_INPUT_LEN prompt tokens (truncation).
    p_tok = min(token_count(prompt), MAX_INPUT_LEN)
    c_tok = token_count(text)

    return ChatCompletionResponse(
        id=f"chatcmpl-{uuid.uuid4().hex[:12]}",
        created=int(time.time()),
        model=req.model,
        choices=[Choice(index=0, message=ChoiceMessage(content=text))],
        usage=Usage(
            prompt_tokens=p_tok,
            completion_tokens=c_tok,
            total_tokens=p_tok + c_tok,
        ),
    )