| from fastapi import FastAPI, HTTPException |
| from fastapi.middleware.cors import CORSMiddleware |
| from pydantic import BaseModel, Field |
| from transformers import AutoTokenizer, AutoModelForCausalLM |
| import torch |
| import time |
|
|
# Hugging Face model repository the tokenizer and weights are loaded from.
MODEL_ID = "OpceanAI/Yuuki-best"


app = FastAPI(
    title="Yuuki API",
    description="Local inference API for Yuuki models",
    version="1.0.0"
)


# Allow browser clients from any origin to call the API.
# NOTE(review): "*" origins/methods/headers is wide open — acceptable for
# local development, but should be restricted before public exposure.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
| |
# Load the tokenizer and model once at import time so all requests share
# the same in-memory instances. The first run downloads the weights from
# the Hugging Face Hub; subsequent runs hit the local cache.
print(f"Loading tokenizer from {MODEL_ID}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)


print(f"Loading model from {MODEL_ID}...")
# float32 on CPU: most CPUs lack fast half-precision kernels, so fp32 is
# the safe default for CPU-only inference.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float32
).to("cpu")


# Inference-only mode: disables dropout and other train-time behavior.
model.eval()
print("Model ready!")
|
|
|
|
class GenerateRequest(BaseModel):
    """Request body for POST /generate, validated by pydantic."""

    # Text to complete; bounded to keep request sizes reasonable.
    prompt: str = Field(..., min_length=1, max_length=4000)
    # Upper bound on how many tokens to sample beyond the prompt.
    max_new_tokens: int = Field(default=120, ge=1, le=512)
    # Sampling temperature; higher values increase randomness.
    temperature: float = Field(default=0.7, ge=0.1, le=2.0)
    # Nucleus-sampling cutoff: sample from the smallest token set whose
    # cumulative probability reaches top_p.
    top_p: float = Field(default=0.95, ge=0.0, le=1.0)
|
|
|
|
class GenerateResponse(BaseModel):
    """Response body for POST /generate."""

    # Generated continuation (prompt not echoed), whitespace-stripped.
    response: str
    # Number of tokens the model produced beyond the prompt.
    tokens_generated: int
    # Wall-clock generation time in milliseconds.
    time_ms: int
|
|
|
|
@app.get("/health")
def health():
    """Simple liveness check reporting which model is being served."""
    payload = {"status": "ok", "model": MODEL_ID}
    return payload
|
|
|
|
@app.post("/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest):
    """Generate a sampled completion for the given prompt.

    Tokenizes the prompt (truncated to 1024 tokens), samples up to
    ``req.max_new_tokens`` continuation tokens on CPU, and returns only
    the newly generated text plus a token count and wall-clock latency.

    Raises:
        HTTPException: 500 with the underlying error message if
            tokenization or generation fails.
    """
    try:
        # perf_counter is monotonic; time.time() can jump when the system
        # clock is adjusted, yielding wrong (even negative) latencies.
        start = time.perf_counter()

        inputs = tokenizer(
            req.prompt,
            return_tensors="pt",
            truncation=True,
            max_length=1024,  # keep the prompt within the context budget
        )

        # Length of the (possibly truncated) prompt, used below to slice
        # the echoed input tokens off the model output.
        input_length = inputs["input_ids"].shape[1]

        # Inference only: skip autograd bookkeeping to save memory/time.
        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_new_tokens=req.max_new_tokens,
                temperature=req.temperature,
                top_p=req.top_p,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,  # silence missing-pad warning
                repetition_penalty=1.1,
            )

        # generate() returns prompt + continuation; keep only the new part.
        new_tokens = output[0][input_length:]
        response_text = tokenizer.decode(new_tokens, skip_special_tokens=True)

        elapsed_ms = int((time.perf_counter() - start) * 1000)

        return GenerateResponse(
            response=response_text.strip(),
            tokens_generated=len(new_tokens),
            time_ms=elapsed_ms
        )

    except Exception as e:
        # Chain the original exception so server logs keep the full
        # traceback; the client only receives the message string.
        # NOTE(review): str(e) can leak internal details — consider a
        # generic message in production.
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|