khushalcodiste committed on
Commit e6db69c · 1 Parent(s): 0fe7743

feat: huh

Files changed (1)
  1. main.py +31 -31
main.py CHANGED
@@ -114,7 +114,35 @@ def load_model():
         logger.info("🧠 Loading ONNX decoder...")
         decoder_model_path = os.path.join(model_dir, "decoder_model_merged_q4.onnx")

-    decoder_session is None or tokenizer is None:
+        providers = ["CPUExecutionProvider"]
+        try:
+            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
+        except:
+            pass
+
+        decoder_session = ort.InferenceSession(decoder_model_path, providers=providers)
+        logger.info(f"✅ Model loaded successfully (using {decoder_session.get_providers()})")
+
+    except Exception as e:
+        logger.exception("❌ Failed to load model")
+        raise e
+
+
+# =========================
+# 📥 REQUEST MODEL
+# =========================
+class GenerateRequest(BaseModel):
+    prompt: str
+    max_tokens: Optional[int] = 100
+    temperature: Optional[float] = 0.7
+
+
+# =========================
+# 📤 RESPONSE ENDPOINT
+# =========================
+@app.post("/generate")
+async def generate(req: GenerateRequest):
+    if decoder_session is None or tokenizer is None:
         raise HTTPException(status_code=500, detail="Model not loaded")

     try:
@@ -182,7 +210,7 @@ def load_model():
             # Update past_key_values from outputs
             for j, output in enumerate(decoder_session.get_outputs()):
                 if output.name.startswith("past_key"):
-                    if jdecoder_sessionen(outputs):
+                    if j < len(outputs):
                         past_key_values[output.name] = outputs[j]

         # Check for EOS token
@@ -194,34 +222,6 @@ def load_model():

         logger.info("✅ Generation successful")

-
-# =========================
-# 📤 RESPONSE ENDPOINT
-# =========================
-@app.post("/generate")
-async def generate(req: GenerateRequest):
-    if model is None or tokenizer is None:
-        raise HTTPException(status_code=500, detail="Model not loaded")
-
-    try:
-        logger.info(f"🧠 Generating for prompt: {req.prompt[:50]}...")
-
-        # Tokenize input
-        inputs = tokenizer(req.prompt, return_tensors="pt")
-
-        # Generate
-        outputs = model.generate(
-            **inputs,
-            max_new_tokens=req.max_tokens,
-            temperature=req.temperature,
-            do_sample=True,
-        )
-
-        # Decode output
-        result = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-        logger.info("✅ Generation successful")
-
         return {
             "success": True,
             "response": result
@@ -240,7 +240,7 @@ async def generate(req: GenerateRequest):
 async def health():
     return {
         "status": "ok",
-        "model_loaded": pipe is not None
+        "model_loaded": decoder_session is not None
     }


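
A note on the provider change in the first hunk: the try/except there can never fall back to CPU, because assigning a list literal does not raise; a missing CUDA build would only surface later, inside ort.InferenceSession. A minimal sketch of a fallback that actually probes the installed runtime, using onnxruntime's documented get_available_providers() call; model_dir here is a stand-in, not the repo's real value:

import os

import onnxruntime as ort

model_dir = "."  # assumption: wherever the model files live locally
decoder_model_path = os.path.join(model_dir, "decoder_model_merged_q4.onnx")

# Ask the installed onnxruntime build what it actually supports, so
# CUDAExecutionProvider is only requested when it is really available.
providers = ["CPUExecutionProvider"]
if "CUDAExecutionProvider" in ort.get_available_providers():
    providers.insert(0, "CUDAExecutionProvider")

decoder_session = ort.InferenceSession(decoder_model_path, providers=providers)
print(decoder_session.get_providers())  # effective provider order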
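
And on the second hunk: the loop copies this step's cache tensors back into the feed for the next decode step, and the commit's j < len(outputs) guard is what keeps the indexing safe. A hedged sketch of that bookkeeping as a standalone helper; the past_key output prefix is taken from the diff and may be named differently (e.g. present.*) in other ONNX exports:

from typing import Dict, List

import numpy as np
import onnxruntime as ort

def update_past_key_values(
    session: ort.InferenceSession,
    outputs: List[np.ndarray],
    past_key_values: Dict[str, np.ndarray],
) -> None:
    # Walk the session's declared outputs in order; outputs[j] is the value
    # produced for the j-th declared output on this run() call.
    for j, output in enumerate(session.get_outputs()):
        # Only cache tensors are fed back; the j < len(outputs) guard (the
        # fix in this commit) protects against fewer values than declared.
        if output.name.startswith("past_key") and j < len(outputs):
            past_key_values[output.name] = outputs[j]

Each decode step would then look like outputs = session.run(None, {**model_inputs, **past_key_values}) followed by update_past_key_values(session, outputs, past_key_values).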