# google-t5-base / app /main.py
# caarleexx's picture
# Update app/main.py
# f00ca36 verified
import os, uuid, time, logging
from typing import List, Optional
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
import torch
# Process-wide logging: timestamp, level, message.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Runtime settings, each overridable through an environment variable.
MODEL_ID = os.environ.get("MODEL_ID", "google/flan-t5-large")
MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", "1024"))
MAX_INPUT_LEN = int(os.environ.get("MAX_INPUT_LEN", "512"))

app = FastAPI(title="T2T OpenAI-Compatible API", version="3.0.0")

# CORS is left fully open: every origin, method and header is allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Both stay None until the startup hook has assigned them.
_pipe = None
_tokenizer = None
@app.on_event("startup")
def load_model():
    """Load the tokenizer and model once when the application starts.

    Assigns the module-level ``_tokenizer`` and ``_pipe`` globals that the
    request handlers read; until this has run, ``_pipe`` is ``None``.
    """
    global _pipe, _tokenizer

    logger.info(f"⏳ Carregando {MODEL_ID} …")

    _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    )
    model.eval()

    # NOTE(review): a Seq2SeqLM model is handed to the "text-generation" task,
    # and an earlier "text2text-generation" task name had been commented out
    # here. Confirm the installed transformers version supports this pairing
    # before changing it; the request handler passes `return_full_text`.
    _pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=_tokenizer,
        device=-1,  # force CPU
    )
    logger.info(f"✅ {MODEL_ID} pronto!")
# ── Schemas (OpenAI-compatible) ───────────────────────────────────────────
class Message(BaseModel):
    """One chat message: the author's role and its text content."""

    # Role label; messages_to_prompt handles "system", "user" and "assistant".
    role: str
    # Text of the message.
    content: str
class ResponseFormat(BaseModel):
    """Value of the request's optional ``response_format`` field."""

    # Format name; defaults to "text".
    type: str = "text"
class ChatCompletionRequest(BaseModel):
    """Request body accepted by POST /v1/chat/completions."""

    # Model name supplied by the client; defaults to the configured MODEL_ID.
    model: str = Field(default=MODEL_ID)
    # Conversation messages, flattened into the prompt in this order.
    messages: List[Message]
    # Sampling temperature; the handler samples only when it is above 0.05.
    temperature: float = 0.7
    # Forwarded to the pipeline as top_p when sampling is enabled.
    top_p: float = 0.9
    # Token budget for the reply; the handler checks this before max_tokens.
    max_completion_tokens: Optional[int] = None
    # Alternative field for the same token budget.
    max_tokens: Optional[int] = None
    # Accepted in the body; the handler in this file does not read it.
    response_format: Optional[ResponseFormat] = None
    # Streaming flag; the handler answers 501 when it is true.
    stream: bool = False

    class Config:
        # Pydantic configuration flag; no field in this model declares an alias.
        populate_by_name = True
class ChoiceMessage(BaseModel):
    """The generated message carried inside a completion choice."""

    # Author of the reply; defaults to "assistant".
    role: str = "assistant"
    # Generated text.
    content: str
class Choice(BaseModel):
    """One entry of the response's ``choices`` list."""

    # Position of this choice in the list.
    index: int
    # The generated message.
    message: ChoiceMessage
    # Why generation ended; defaults to "stop".
    finish_reason: str = "stop"
class Usage(BaseModel):
    """Token counts reported alongside a completion."""

    # Tokens counted for the prompt text.
    prompt_tokens: int
    # Tokens counted for the generated text.
    completion_tokens: int
    # The handler fills this with prompt_tokens + completion_tokens.
    total_tokens: int
class ChatCompletionResponse(BaseModel):
    """Response body returned by POST /v1/chat/completions."""

    # Response identifier; the handler builds it as "chatcmpl-" plus 12 hex chars.
    id: str
    # Object type tag; defaults to "chat.completion".
    object: str = "chat.completion"
    # Creation time as integer seconds from time.time().
    created: int
    # Model name, echoed from the request.
    model: str
    # Generated choices; the handler always returns exactly one.
    choices: List[Choice]
    # Token accounting for this request.
    usage: Usage
# ── Helpers ───────────────────────────────────────────────────────────────
def messages_to_prompt(messages: List[Message]) -> str:
    """Flatten chat messages into a single space-separated prompt string.

    Each system, user or assistant message becomes ``"<Label>: <content>"``;
    messages with any other role contribute nothing to the prompt.
    """
    label_for_role = {
        "system": "Instructions",
        "user": "User",
        "assistant": "Assistant",
    }
    return " ".join(
        f"{label_for_role[msg.role]}: {msg.content}"
        for msg in messages
        if msg.role in label_for_role
    )
def token_count(text: str) -> int:
    """Return how many token ids the loaded tokenizer yields for *text*.

    Special tokens are excluded from the count.
    """
    encoded = _tokenizer(text, add_special_tokens=False)
    token_ids = encoded["input_ids"]
    return len(token_ids)
# ── Endpoints ─────────────────────────────────────────────────────────────
@app.get("/")
def root():
    """Answer with a fixed "ok" status and the configured model id."""
    return dict(status="ok", model=MODEL_ID)
@app.get("/health")
def health():
    """Report service health and whether the pipeline has been loaded."""
    pipeline_loaded = _pipe is not None
    return dict(status="healthy", model=MODEL_ID, ready=pipeline_loaded)
@app.get("/v1/models")
def list_models():
    """Return a one-element model list describing the configured model."""
    model_entry = {
        "id": MODEL_ID,
        "object": "model",
        "owned_by": "huggingface",
    }
    return {"object": "list", "data": [model_entry]}
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
def chat_completions(req: ChatCompletionRequest):
    """Generate a single completion for an OpenAI-style chat request.

    The messages are flattened with ``messages_to_prompt`` and passed to the
    module-level pipeline; the stripped output becomes the only choice, and
    token usage is counted with ``token_count``.

    Args:
        req: Parsed request body.

    Returns:
        A ``ChatCompletionResponse`` with one choice and its usage counts.

    Raises:
        HTTPException: 501 when ``stream`` is set; 503 while the pipeline is
            not loaded; 400 for a negative token budget or when no message
            contributes to the prompt; 500 when the pipeline call fails.
    """
    if req.stream:
        raise HTTPException(501, "Streaming não suportado.")
    if _pipe is None:
        raise HTTPException(503, "Modelo não carregado.")

    # None and 0 both fall through to the next source, ending at the default.
    max_tokens = req.max_completion_tokens or req.max_tokens or MAX_NEW_TOKENS
    if max_tokens < 0:
        # Reject up front instead of handing an invalid budget to the pipeline.
        raise HTTPException(400, "max_tokens deve ser um inteiro positivo.")

    prompt = messages_to_prompt(req.messages)
    if not prompt:
        # Only system/user/assistant messages reach the prompt; without any of
        # them there is nothing for the model to answer.
        raise HTTPException(
            400, "Nenhuma mensagem com role system, user ou assistant."
        )

    # Near-zero temperatures switch to greedy decoding; the sampling knobs are
    # then reset to neutral values.
    do_sample = req.temperature > 0.05

    # NOTE(review): truncation=True is passed, but MAX_INPUT_LEN is not
    # referenced in this function — confirm what bounds the input length.
    try:
        output = _pipe(
            prompt,
            max_new_tokens=max_tokens,
            truncation=True,
            temperature=float(req.temperature) if do_sample else 1.0,
            top_p=float(req.top_p) if do_sample else 1.0,
            do_sample=do_sample,
            repetition_penalty=1.2,
            return_full_text=False,
        )
    except Exception as e:
        # Keep the traceback in the log and chain the cause for debugging.
        logger.exception("Inference error: %s", e)
        raise HTTPException(500, str(e)) from e

    text = output[0]["generated_text"].strip()
    p_tok = token_count(prompt)
    c_tok = token_count(text)

    return ChatCompletionResponse(
        id=f"chatcmpl-{uuid.uuid4().hex[:12]}",
        created=int(time.time()),
        model=req.model,
        choices=[Choice(index=0, message=ChoiceMessage(content=text))],
        usage=Usage(
            prompt_tokens=p_tok,
            completion_tokens=c_tok,
            total_tokens=p_tok + c_tok,
        ),
    )