Spaces:
Paused
Paused
| import os, uuid, time, logging | |
| from typing import List, Optional | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel, Field | |
| from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM | |
| import torch | |
# Emit INFO-and-above records, each prefixed with a timestamp and its level.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Runtime settings; each one can be overridden through the environment
# variable of the same name.
MODEL_ID = os.environ.get("MODEL_ID", "google/flan-t5-large")
MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", "1024"))
MAX_INPUT_LEN = int(os.environ.get("MAX_INPUT_LEN", "512"))

app = FastAPI(title="T2T OpenAI-Compatible API", version="3.0.0")

# CORS is left fully open: every origin, method and header is allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Filled in by load_model(); both remain None until it has run.
_pipe = None
_tokenizer = None
def load_model():
    """Load the tokenizer and model for ``MODEL_ID`` and build the pipeline.

    The tokenizer is stored in the module-level ``_tokenizer`` and the
    finished pipeline in ``_pipe``. The model is loaded in float32, put into
    eval mode, and the pipeline is pinned to CPU.
    """
    global _pipe, _tokenizer

    logger.info("⏳ Carregando %s …", MODEL_ID)

    _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    )
    model.eval()

    # NOTE(review): the model is loaded through AutoModelForSeq2SeqLM, yet the
    # pipeline task is "text-generation"; the original left the alternative
    # "text2text-generation" commented out. Confirm which task is intended —
    # switching it also affects the keyword arguments the caller may pass.
    _pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=_tokenizer,
        device=-1,  # force CPU
    )

    logger.info("✅ %s pronto!", MODEL_ID)
# ── Schemas (OpenAI-compatible) ───────────────────────────────────────────
# One chat message: the role it was sent under and its text.
class Message(BaseModel):
    role: str
    content: str
# Requested output format; carries only a type tag, which defaults to "text".
class ResponseFormat(BaseModel):
    type: str = "text"
# Request body for a chat completion.
#
# Only `messages` is required; every other field has a default.
class ChatCompletionRequest(BaseModel):
    model: str = MODEL_ID  # defaults to the configured model id
    messages: List[Message]

    # Sampling controls.
    temperature: float = 0.7
    top_p: float = 0.9

    # Two alternative names for the generation length limit.
    max_completion_tokens: Optional[int] = None
    max_tokens: Optional[int] = None

    response_format: Optional[ResponseFormat] = None
    stream: bool = False

    class Config:
        # No field in this class declares an alias, so this setting has no
        # visible effect within the code shown here.
        populate_by_name = True
# The generated reply; the role defaults to "assistant".
class ChoiceMessage(BaseModel):
    role: str = "assistant"
    content: str
# One completion alternative: its position in the list, the reply message,
# and why generation ended (defaults to "stop").
class Choice(BaseModel):
    index: int
    message: ChoiceMessage
    finish_reason: str = "stop"
# Token accounting reported alongside a completion.
class Usage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
# Response body for a chat completion. The field names `id` and `object`
# are part of the response payload, so they are kept even though they
# shadow Python builtins inside the class body.
class ChatCompletionResponse(BaseModel):
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[Choice]
    usage: Usage
# ── Helpers ───────────────────────────────────────────────────────────────
def messages_to_prompt(messages: List[Message]) -> str:
    """Flatten a chat history into one space-separated prompt string.

    Each message is rendered as ``"<Label>: <content>"``, where the label is
    ``Instructions`` for the ``system`` role, ``User`` for ``user`` and
    ``Assistant`` for ``assistant``. Messages with any other role are left
    out of the prompt.
    """
    labels = {
        "system": "Instructions",
        "user": "User",
        "assistant": "Assistant",
    }
    return " ".join(
        f"{labels[msg.role]}: {msg.content}"
        for msg in messages
        if msg.role in labels
    )
def token_count(text: str) -> int:
    """Return the number of token ids ``_tokenizer`` produces for ``text``.

    Special tokens are not added, so only the text itself is counted.
    """
    encoded = _tokenizer(text, add_special_tokens=False)
    return len(encoded["input_ids"])
# ── Endpoints ─────────────────────────────────────────────────────────────
def root():
    """Return a minimal status payload naming the configured model."""
    # NOTE(review): no route decorator is visible on this function in this
    # view — confirm how it is registered on `app`.
    payload = {"status": "ok", "model": MODEL_ID}
    return payload
def health():
    """Report service health and whether the pipeline has been created."""
    # NOTE(review): no route decorator is visible on this function in this
    # view — confirm how it is registered on `app`.
    pipeline_ready = _pipe is not None
    return {"status": "healthy", "model": MODEL_ID, "ready": pipeline_ready}
def list_models():
    """Return a one-element model list describing the configured model."""
    # NOTE(review): no route decorator is visible on this function in this
    # view — confirm how it is registered on `app`.
    entry = {"id": MODEL_ID, "object": "model", "owned_by": "huggingface"}
    return {"object": "list", "data": [entry]}
def chat_completions(req: ChatCompletionRequest):
    """Generate one non-streaming chat completion for ``req``.

    The messages are flattened into a single prompt, run through the loaded
    pipeline, and the generated text is returned as a single choice together
    with token usage computed by ``token_count``.

    Args:
        req: The parsed request body.

    Returns:
        A ``ChatCompletionResponse`` with exactly one choice.

    Raises:
        HTTPException: 501 when streaming is requested; 503 while the
            pipeline has not been loaded; 400 when no message contributes
            to the prompt; 500 when the pipeline call raises.
    """
    # NOTE(review): no route decorator is visible on this function in this
    # view — confirm how it is registered on `app`.
    if req.stream:
        raise HTTPException(501, "Streaming não suportado.")
    if _pipe is None:
        raise HTTPException(503, "Modelo não carregado.")

    prompt = messages_to_prompt(req.messages)
    if not prompt:
        # An empty prompt means the list was empty or held only roles that
        # messages_to_prompt skips; reject it rather than run inference on
        # an empty string.
        raise HTTPException(
            400, "Nenhuma mensagem com role suportado (system, user, assistant)."
        )

    # First non-zero limit wins; fall back to the configured default.
    max_tokens = req.max_completion_tokens or req.max_tokens or MAX_NEW_TOKENS

    # Temperatures at or below 0.05 disable sampling; the sampling knobs are
    # then passed as neutral 1.0 values.
    do_sample = req.temperature > 0.05

    try:
        # truncation=True is passed as-is; MAX_INPUT_LEN is not applied here.
        output = _pipe(
            prompt,
            max_new_tokens=max_tokens,
            truncation=True,
            temperature=float(req.temperature) if do_sample else 1.0,
            top_p=float(req.top_p) if do_sample else 1.0,
            do_sample=do_sample,
            repetition_penalty=1.2,
            return_full_text=False,
        )
    except Exception as e:
        # Endpoint boundary: log with traceback, then surface as a 500 while
        # keeping the original exception as the cause.
        logger.exception("Inference error: %s", e)
        raise HTTPException(500, str(e)) from e

    text = output[0]["generated_text"].strip()
    p_tok = token_count(prompt)
    c_tok = token_count(text)

    return ChatCompletionResponse(
        id=f"chatcmpl-{uuid.uuid4().hex[:12]}",
        created=int(time.time()),
        model=req.model,  # echoes the requested name back to the caller
        choices=[Choice(index=0, message=ChoiceMessage(content=text))],
        usage=Usage(
            prompt_tokens=p_tok,
            completion_tokens=c_tok,
            total_tokens=p_tok + c_tok,
        ),
    )