aguitauwu committed
Commit 9884bce · 1 Parent(s): 319f742
Files changed (1)
  1. app.py +54 -12
app.py CHANGED
@@ -5,7 +5,12 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
 import time
 
-MODEL_ID = "OpceanAI/Yuuki-best"
+# Define all available models
+MODELS = {
+    "yuuki-best": "OpceanAI/Yuuki-best",
+    "yuuki-3.7": "OpceanAI/Yuuki-3.7",
+    "yuuki-v0.1": "OpceanAI/Yuuki-v0.1"
+}
 
 app = FastAPI(
     title="Yuuki API",
@@ -20,21 +25,31 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
-print(f"Loading tokenizer from {MODEL_ID}...")
-tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+# Cache of loaded models
+loaded_models = {}
+loaded_tokenizers = {}
 
-print(f"Loading model from {MODEL_ID}...")
-model = AutoModelForCausalLM.from_pretrained(
-    MODEL_ID,
-    torch_dtype=torch.float32
-).to("cpu")
 
-model.eval()
-print("Model ready!")
+def load_model(model_key: str):
+    """Lazy load: only load the model when it is needed"""
+    if model_key not in loaded_models:
+        print(f"Loading {model_key}...")
+        model_id = MODELS[model_key]
+
+        loaded_tokenizers[model_key] = AutoTokenizer.from_pretrained(model_id)
+        loaded_models[model_key] = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            torch_dtype=torch.float32
+        ).to("cpu")
+        loaded_models[model_key].eval()
+        print(f"{model_key} ready!")
+
+    return loaded_models[model_key], loaded_tokenizers[model_key]
 
 
 class GenerateRequest(BaseModel):
     prompt: str = Field(..., min_length=1, max_length=4000)
+    model: str = Field(default="yuuki-best", description="Model to use")
     max_new_tokens: int = Field(default=120, ge=1, le=512)
     temperature: float = Field(default=0.7, ge=0.1, le=2.0)
     top_p: float = Field(default=0.95, ge=0.0, le=1.0)
@@ -42,6 +57,7 @@ class GenerateRequest(BaseModel):
 
 class GenerateResponse(BaseModel):
     response: str
+    model: str
     tokens_generated: int
     time_ms: int
 
@@ -50,9 +66,10 @@ class GenerateResponse(BaseModel):
 def root():
     return {
         "message": "Yuuki Local Inference API",
-        "model": MODEL_ID,
+        "models": list(MODELS.keys()),
         "endpoints": {
             "health": "GET /health",
+            "models": "GET /models",
             "generate": "POST /generate",
             "docs": "GET /docs"
         }
@@ -61,14 +78,38 @@ def root():
 
 @app.get("/health")
 def health():
-    return {"status": "ok", "model": MODEL_ID}
+    return {
+        "status": "ok",
+        "available_models": list(MODELS.keys()),
+        "loaded_models": list(loaded_models.keys())
+    }
+
+
+@app.get("/models")
+def list_models():
+    return {
+        "models": [
+            {"id": key, "name": value}
+            for key, value in MODELS.items()
+        ]
+    }
 
 
 @app.post("/generate", response_model=GenerateResponse)
 def generate(req: GenerateRequest):
+    # Validate that the requested model exists
+    if req.model not in MODELS:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Invalid model. Available: {list(MODELS.keys())}"
+        )
+
     try:
         start = time.time()
 
+        # Load the model (lazy load)
+        model, tokenizer = load_model(req.model)
+
        inputs = tokenizer(
             req.prompt,
             return_tensors="pt",
@@ -96,6 +137,7 @@ def generate(req: GenerateRequest):
 
         return GenerateResponse(
             response=response_text.strip(),
+            model=req.model,
             tokens_generated=len(new_tokens),
             time_ms=elapsed_ms
        )
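For reference, a minimal client-side sketch of the new behavior introduced by this commit. It assumes the app is served at http://localhost:8000 (uvicorn's default port; the host and port are not part of this diff) and uses the third-party requests package; "yuuki-3.7" is one of the keys defined in MODELS above.

# Client sketch, not part of the commit. Assumes a local server on port 8000.
import requests

BASE = "http://localhost:8000"

# New GET /models endpoint: lists each model key and its repo id.
print(requests.get(f"{BASE}/models").json())

# POST /generate now accepts a "model" key. The first request for a given
# key triggers the lazy load in load_model(), so it is much slower than
# later requests against the same key.
resp = requests.post(
    f"{BASE}/generate",
    json={
        "prompt": "Hello, Yuuki!",
        "model": "yuuki-3.7",
        "max_new_tokens": 64,
    },
    timeout=300,  # generous: covers the one-time CPU model load
)
resp.raise_for_status()
print(resp.json())  # includes "model": "yuuki-3.7" in the response body

Because loaded_models acts as a process-level cache, only the first call per model key pays the load cost; repeated calls against the same key reuse the cached model and tokenizer.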