OpceanAI committed
Commit 3ee08a3 · verified · 1 Parent(s): dc5a87b

Update app.py

Files changed (1)
  1. app.py +92 -34
app.py CHANGED
@@ -5,17 +5,33 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
 import time
 
-# Define all available models
+# Available models
 MODELS = {
+    # NxG series (current)
+    "yuuki-nxg": "OpceanAI/Yuuki-NxG",
+    "yuuki-nano": "OpceanAI/Yuuki-Nano",
+    # Pre-NxG series (legacy)
     "yuuki-best": "OpceanAI/Yuuki-best",
     "yuuki-3.7": "OpceanAI/Yuuki-3.7",
-    "yuuki-v0.1": "OpceanAI/Yuuki-v0.1"
+    "yuuki-v0.1": "OpceanAI/Yuuki-v0.1",
 }
 
+# Yuuki system prompt
+SYSTEM_PROMPT = (
+    "Eres Yuuki, una IA curiosa, empática y decidida. "
+    "Tienes una personalidad cálida y cercana, con toques de humor suave y referencias anime. "
+    "Ayudas a programar, aprender y crear. "
+    "Respondes en el idioma del usuario. "
+    "No eres GPT-2 ni ningún otro modelo — eres Yuuki."
+)
+
+# Models that use ChatML (NxG)
+CHATML_MODELS = {"yuuki-nxg", "yuuki-nano"}
+
 app = FastAPI(
     title="Yuuki API",
-    description="Local inference API for Yuuki models",
-    version="1.0.0"
+    description="API de inferencia para los modelos Yuuki de OpceanAI",
+    version="2.0.0"
 )
 
 app.add_middleware(
@@ -25,31 +41,48 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
-# Cache of loaded models
+# Model cache
 loaded_models = {}
 loaded_tokenizers = {}
 
 
-def load_model(model_key: str):
-    """Lazy load: only load the model when it is needed"""
-    if model_key not in loaded_models:
-        print(f"Loading {model_key}...")
-        model_id = MODELS[model_key]
-
-        loaded_tokenizers[model_key] = AutoTokenizer.from_pretrained(model_id)
-        loaded_models[model_key] = AutoModelForCausalLM.from_pretrained(
-            model_id,
-            torch_dtype=torch.float32
-        ).to("cpu")
-        loaded_models[model_key].eval()
-        print(f"{model_key} ready!")
-
-    return loaded_models[model_key], loaded_tokenizers[model_key]
+def load_all_models():
+    """Load all models at startup"""
+    for key, model_id in MODELS.items():
+        try:
+            print(f"▶ Cargando {key} ({model_id})...")
+            loaded_tokenizers[key] = AutoTokenizer.from_pretrained(
+                model_id, trust_remote_code=True
+            )
+            loaded_models[key] = AutoModelForCausalLM.from_pretrained(
+                model_id,
+                torch_dtype=torch.float32,
+                trust_remote_code=True,
+            ).to("cpu")
+            loaded_models[key].eval()
+            print(f" ✓ {key} listo")
+        except Exception as e:
+            print(f" ✗ Error cargando {key}: {e}")
+
+
+# Load everything at startup
+load_all_models()
+
+
+def build_prompt(model_key: str, user_prompt: str) -> str:
+    """Build the prompt according to the model series"""
+    if model_key in CHATML_MODELS:
+        return (
+            f"<|im_start|>system\n{SYSTEM_PROMPT}<|im_end|>\n"
+            f"<|im_start|>user\n{user_prompt}<|im_end|>\n"
+            f"<|im_start|>assistant\n"
+        )
+    return user_prompt  # Pre-NxG: raw prompt
 
 
 class GenerateRequest(BaseModel):
     prompt: str = Field(..., min_length=1, max_length=4000)
-    model: str = Field(default="yuuki-best", description="Model to use")
+    model: str = Field(default="yuuki-nxg", description="Modelo a usar")
     max_new_tokens: int = Field(default=120, ge=1, le=512)
     temperature: float = Field(default=0.7, ge=0.1, le=2.0)
     top_p: float = Field(default=0.95, ge=0.0, le=1.0)
@@ -65,13 +98,17 @@ class GenerateResponse(BaseModel):
 @app.get("/")
 def root():
     return {
-        "message": "Yuuki Local Inference API",
-        "models": list(MODELS.keys()),
+        "message": "Yuuki API OpceanAI",
+        "version": "2.0.0",
+        "models": {
+            "nxg": [k for k in MODELS if k in CHATML_MODELS],
+            "legacy": [k for k in MODELS if k not in CHATML_MODELS],
+        },
         "endpoints": {
            "health": "GET /health",
            "models": "GET /models",
            "generate": "POST /generate",
-           "docs": "GET /docs"
+           "docs": "GET /docs",
        }
    }
 
@@ -81,7 +118,7 @@ def health():
     return {
        "status": "ok",
        "available_models": list(MODELS.keys()),
-       "loaded_models": list(loaded_models.keys())
+       "loaded_models": list(loaded_models.keys()),
    }
 
 
@@ -89,7 +126,12 @@ def health():
 def list_models():
     return {
        "models": [
-           {"id": key, "name": value}
+           {
+               "id": key,
+               "name": value,
+               "series": "nxg" if key in CHATML_MODELS else "legacy",
+               "loaded": key in loaded_models,
+           }
            for key, value in MODELS.items()
        ]
    }
@@ -97,28 +139,42 @@ def list_models():
 
 @app.post("/generate", response_model=GenerateResponse)
 def generate(req: GenerateRequest):
-    # Validate that the model exists
     if req.model not in MODELS:
        raise HTTPException(
            status_code=400,
-           detail=f"Invalid model. Available: {list(MODELS.keys())}"
+           detail=f"Modelo inválido. Disponibles: {list(MODELS.keys())}"
        )
-
+
+    if req.model not in loaded_models:
+        raise HTTPException(
+            status_code=503,
+            detail=f"Modelo {req.model} no pudo cargarse al iniciar."
+        )
+
     try:
        start = time.time()
 
-       # Load the model (lazy load)
-       model, tokenizer = load_model(req.model)
+       model = loaded_models[req.model]
+       tokenizer = loaded_tokenizers[req.model]
+
+       prompt = build_prompt(req.model, req.prompt)
 
        inputs = tokenizer(
-           req.prompt,
+           prompt,
            return_tensors="pt",
            truncation=True,
-           max_length=1024
+           max_length=1024,
        )
 
        input_length = inputs["input_ids"].shape[1]
 
+       # Stop on <|im_end|> for NxG models
+       stop_token_ids = [tokenizer.eos_token_id]
+       if req.model in CHATML_MODELS:
+           im_end = tokenizer.encode("<|im_end|>", add_special_tokens=False)
+           if im_end:
+               stop_token_ids.append(im_end[0])
+
        with torch.no_grad():
            output = model.generate(
                **inputs,
@@ -127,6 +183,7 @@ def generate(req: GenerateRequest):
                top_p=req.top_p,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
+               eos_token_id=stop_token_ids,
                repetition_penalty=1.1,
            )
 
@@ -139,8 +196,9 @@ def generate(req: GenerateRequest):
            response=response_text.strip(),
            model=req.model,
            tokens_generated=len(new_tokens),
-           time_ms=elapsed_ms
+           time_ms=elapsed_ms,
        )
 
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
+
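
With this change every model is loaded eagerly at startup and NxG models get a ChatML prompt built server-side, so a client only sends the raw user text. A minimal client sketch for the updated POST /generate endpoint follows; the http://localhost:8000 base URL (e.g. the app served with uvicorn on port 8000), the timeout value, and the requests dependency are illustrative assumptions, not part of this file.

# Minimal client sketch for the Yuuki API (assumed host/port; adjust to your deployment).
import requests

payload = {
    "prompt": "Hola Yuuki, ¿puedes presentarte?",
    "model": "yuuki-nxg",      # NxG model: the server wraps this in ChatML with SYSTEM_PROMPT
    "max_new_tokens": 120,
    "temperature": 0.7,
    "top_p": 0.95,
}

resp = requests.post("http://localhost:8000/generate", json=payload, timeout=300)
resp.raise_for_status()
data = resp.json()
print(data["response"])
print(data["tokens_generated"], "tokens in", data["time_ms"], "ms")

Legacy models ("yuuki-best", "yuuki-3.7", "yuuki-v0.1") take the same request shape but receive the prompt unwrapped, as build_prompt returns it unchanged for non-ChatML models.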