from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import time


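# Served models: public alias -> Hugging Face Hub repo id.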
MODELS = {
    "yuuki-best": "OpceanAI/Yuuki-best",
    "yuuki-3.7": "OpceanAI/Yuuki-3.7",
    "yuuki-v0.1": "OpceanAI/Yuuki-v0.1",
}


app = FastAPI(
    title="Yuuki API",
    description="Local inference API for Yuuki models",
    version="1.0.0",
)


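# CORS is left wide open ("*"), which is convenient for local development;
# restrict allow_origins before exposing the API beyond localhost.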
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


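# In-process caches keyed by model alias, so each model and tokenizer is
# loaded at most once per worker process.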
loaded_models = {}
loaded_tokenizers = {}


def load_model(model_key: str):
    """Lazy load: only loads the model the first time it is requested."""
    if model_key not in loaded_models:
        print(f"Loading {model_key}...")
        model_id = MODELS[model_key]

        loaded_tokenizers[model_key] = AutoTokenizer.from_pretrained(model_id)
        loaded_models[model_key] = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float32
        ).to("cpu")
        loaded_models[model_key].eval()
        print(f"{model_key} ready!")

    return loaded_models[model_key], loaded_tokenizers[model_key]
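
# FastAPI serves sync endpoints from a threadpool, so two concurrent first
# requests for the same model can both miss the cache and load it twice.
# A minimal guard, sketched here under the assumption of a single-process
# deployment (the name _load_lock is hypothetical, not defined in this file):
#
#     import threading
#     _load_lock = threading.Lock()
#
#     def load_model(model_key: str):
#         with _load_lock:
#             ...  # same body as above
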
class GenerateRequest(BaseModel):
    prompt: str = Field(..., min_length=1, max_length=4000)
    model: str = Field(default="yuuki-best", description="Model to use")
    max_new_tokens: int = Field(default=120, ge=1, le=512)
    temperature: float = Field(default=0.7, ge=0.1, le=2.0)
    top_p: float = Field(default=0.95, ge=0.0, le=1.0)
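
    # Illustrative request body (the prompt is a made-up example; the other
    # values match the field defaults above):
    # {"prompt": "Once upon a time", "model": "yuuki-best",
    #  "max_new_tokens": 120, "temperature": 0.7, "top_p": 0.95}
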
class GenerateResponse(BaseModel):
    response: str
    model: str
    tokens_generated: int
    time_ms: int


| @app.get("/") |
| def root(): |
| return { |
| "message": "Yuuki Local Inference API", |
| "models": list(MODELS.keys()), |
| "endpoints": { |
| "health": "GET /health", |
| "models": "GET /models", |
| "generate": "POST /generate", |
| "docs": "GET /docs" |
| } |
| } |
|
|
|
|
| @app.get("/health") |
| def health(): |
| return { |
| "status": "ok", |
| "available_models": list(MODELS.keys()), |
| "loaded_models": list(loaded_models.keys()) |
| } |
|
|
|
|
| @app.get("/models") |
| def list_models(): |
| return { |
| "models": [ |
| {"id": key, "name": value} |
| for key, value in MODELS.items() |
| ] |
| } |
|
|
|
|
| @app.post("/generate", response_model=GenerateResponse) |
| def generate(req: GenerateRequest): |
| |
| if req.model not in MODELS: |
| raise HTTPException( |
| status_code=400, |
| detail=f"Invalid model. Available: {list(MODELS.keys())}" |
| ) |
| |
| try: |
| start = time.time() |
|
|
| |
| model, tokenizer = load_model(req.model) |
|
|
| inputs = tokenizer( |
| req.prompt, |
| return_tensors="pt", |
| truncation=True, |
| max_length=1024 |
| ) |
|
|
| input_length = inputs["input_ids"].shape[1] |
|
|
| with torch.no_grad(): |
| output = model.generate( |
| **inputs, |
| max_new_tokens=req.max_new_tokens, |
| temperature=req.temperature, |
| top_p=req.top_p, |
| do_sample=True, |
| pad_token_id=tokenizer.eos_token_id, |
| repetition_penalty=1.1, |
| ) |
|
|
| new_tokens = output[0][input_length:] |
| response_text = tokenizer.decode(new_tokens, skip_special_tokens=True) |
|
|
| elapsed_ms = int((time.time() - start) * 1000) |
|
|
| return GenerateResponse( |
| response=response_text.strip(), |
| model=req.model, |
| tokens_generated=len(new_tokens), |
| time_ms=elapsed_ms |
| ) |
|
|
| except Exception as e: |
| raise HTTPException(status_code=500, detail=str(e)) |
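
# A minimal way to run the server directly (a sketch; assumes uvicorn is
# installed, e.g. `pip install uvicorn`):
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="127.0.0.1", port=8000)

# Example request once the server is up (hypothetical prompt):
#
#   curl -X POST http://127.0.0.1:8000/generate \
#     -H "Content-Type: application/json" \
#     -d '{"prompt": "Hello", "model": "yuuki-best", "max_new_tokens": 64}'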