# Yuuki-api / app.py
# Author: aguitauwu — initial commit (849ec65)
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import time
# HuggingFace model repo to serve.
MODEL_ID = "OpceanAI/Yuuki-best"

app = FastAPI(
    title="Yuuki API",
    description="Local inference API for Yuuki models",
    version="1.0.0"
)

# CORS so Yuuki-chat can call the API from the browser.
# NOTE(review): allow_origins=["*"] is wide open — acceptable for a local
# API, but tighten to the chat UI's origin before exposing publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load tokenizer and model ONCE at module import (server startup),
# not per request — loading is by far the slowest step.
print(f"Loading tokenizer from {MODEL_ID}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
print(f"Loading model from {MODEL_ID}...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float32  # full precision; this service targets CPU
).to("cpu")
model.eval()  # inference mode (faster, less memory — disables dropout etc.)
print("Model ready!")
class GenerateRequest(BaseModel):
    """Validated request payload for POST /generate."""

    # Text to complete; length bounds keep requests sane.
    prompt: str = Field(..., min_length=1, max_length=4000)
    # Cap on freshly sampled tokens (1..512).
    max_new_tokens: int = Field(120, ge=1, le=512)
    # Sampling temperature (0.1..2.0); lower = more deterministic.
    temperature: float = Field(0.7, ge=0.1, le=2.0)
    # Nucleus-sampling cumulative-probability cutoff (0..1).
    top_p: float = Field(0.95, ge=0.0, le=1.0)
class GenerateResponse(BaseModel):
    """Shape of a successful POST /generate reply."""

    # Newly generated text only (prompt stripped, whitespace-trimmed).
    response: str
    # Count of tokens produced beyond the prompt.
    tokens_generated: int
    # Wall-clock generation latency in milliseconds.
    time_ms: int
@app.get("/health")
def health():
    """Liveness probe: reports service status and the loaded model id."""
    payload = {"status": "ok", "model": MODEL_ID}
    return payload
@app.post("/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest):
    """Run one sampled generation for `req.prompt`.

    Returns only the NEW tokens (the prompt is stripped from the output),
    plus the token count and wall-clock latency in milliseconds.
    Raises HTTPException(500) on any inference failure.
    """
    try:
        # FIX: use a monotonic clock. time.time() can jump (NTP/clock
        # adjustments), producing wrong or negative latencies.
        start = time.perf_counter()
        inputs = tokenizer(
            req.prompt,
            return_tensors="pt",
            truncation=True,   # bound prompt length to cap memory/latency
            max_length=1024
        )
        input_length = inputs["input_ids"].shape[1]
        with torch.no_grad():  # inference only: skip autograd bookkeeping
            output = model.generate(
                **inputs,
                max_new_tokens=req.max_new_tokens,
                temperature=req.temperature,
                top_p=req.top_p,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,  # silence missing-pad warning
                repetition_penalty=1.1,
            )
        # Keep only tokens generated past the prompt.
        new_tokens = output[0][input_length:]
        response_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
        elapsed_ms = int((time.perf_counter() - start) * 1000)
        return GenerateResponse(
            response=response_text.strip(),
            tokens_generated=len(new_tokens),
            time_ms=elapsed_ms
        )
    except Exception as e:
        # Boundary handler: surface inference failures as HTTP 500.
        # NOTE(review): str(e) may leak internals; acceptable for a local API.
        raise HTTPException(status_code=500, detail=str(e))