maxime-antoine-dev committed on
Commit
1f23e23
·
1 Parent(s): 712c34b

added light mode

Browse files
Files changed (1) hide show
  1. main.py +239 -99
main.py CHANGED
@@ -2,8 +2,9 @@
2
  import os
3
  import json
4
  import time
 
5
  import asyncio
6
- from typing import Any, Dict, Optional
7
  from functools import lru_cache
8
 
9
  from fastapi import FastAPI
@@ -12,26 +13,51 @@ from huggingface_hub import hf_hub_download
12
  from llama_cpp import Llama
13
 
14
  # ----------------------------
15
- # Config
16
  # ----------------------------
17
  GGUF_REPO_ID = os.getenv("GGUF_REPO_ID", "maxime-antoine-dev/fades-mistral-v02-gguf")
18
  GGUF_FILENAME = os.getenv("GGUF_FILENAME", "mistral_v02_fades.Q4_K_M.gguf")
19
 
20
- # llama.cpp params (CPU Space)
21
- N_CTX = int(os.getenv("N_CTX", "2048"))
22
- N_THREADS = int(os.getenv("N_THREADS", str(max(1, (os.cpu_count() or 2) - 1))))
 
 
23
  N_BATCH = int(os.getenv("N_BATCH", "256"))
24
 
25
- # generation defaults
26
  MAX_NEW_TOKENS_DEFAULT = int(os.getenv("MAX_NEW_TOKENS", "180"))
27
  TEMPERATURE_DEFAULT = float(os.getenv("TEMPERATURE", "0.0"))
28
  TOP_P_DEFAULT = float(os.getenv("TOP_P", "0.95"))
29
 
30
- # One request at a time on CPU (prevents stalls / extreme latency)
 
 
 
 
 
 
 
 
31
  GEN_LOCK = asyncio.Lock()
32
 
 
 
33
  # ----------------------------
34
- # Prompt (aligned with your training target)
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  # ----------------------------
36
  ALLOWED_LABELS = [
37
  "none",
@@ -51,66 +77,52 @@ ALLOWED_LABELS = [
51
  "intentional",
52
  ]
53
 
54
- def labels_block_compact() -> str:
55
- return "\n".join([f'- "{k}"' for k in ALLOWED_LABELS])
56
 
57
- INSTRUCTION = """You are a logical fallacy detection assistant.
58
 
59
  You MUST choose labels ONLY from this list (use the exact string):
60
- {labels_list}
61
 
62
- Return ONLY ONE valid JSON object with this schema:
63
  {{
64
  "has_fallacy": boolean,
65
  "fallacies": [
66
  {{
67
  "type": string,
68
- "confidence": number, // 0.0..1.0
69
- "evidence_quotes": [string], // exact substring(s) copied from the input text
70
- "rationale": string // specific to this fallacy + quote
71
  }}
72
  ],
73
- "overall_explanation": string // short summary across the whole input
74
  }}
75
 
76
- Hard rules:
77
- - Output ONLY the JSON object. No markdown. No extra text.
78
- - Produce exactly ONE JSON object, then STOP.
79
- - evidence_quotes MUST be exact substrings from the input text.
80
- - If has_fallacy=false:
81
- - fallacies MUST be []
82
- - overall_explanation MUST explicitly say there is no fallacy
83
- - overall_explanation MUST NOT mention any fallacy label/category names.
84
- - If has_fallacy=true:
85
- - fallacies MUST contain at least 1 item
86
- - EACH fallacies[i].type MUST be one of the allowed labels (NOT a synonym)
87
- """
88
-
89
- SYSTEM_PROMPT = "You are a careful JSON-only assistant. Output only JSON."
90
 
91
  def build_messages(text: str) -> list[dict]:
92
- instruction = INSTRUCTION.format(labels_list=labels_block_compact())
93
  return [
94
- {"role": "system", "content": SYSTEM_PROMPT},
95
- {"role": "user", "content": f"{instruction}\n\nTEXT:\n{text}\n\nJSON:"},
96
  ]
97
 
98
  # ----------------------------
99
- # Robust JSON extraction
100
  # ----------------------------
101
- def extract_first_json_obj(s: str) -> Optional[Dict[str, Any]]:
102
- start = s.find("{")
103
- if start == -1:
104
- return None
105
- end = s.rfind("}")
106
- if end == -1 or end <= start:
107
- return None
108
- cand = s[start : end + 1].strip()
109
- try:
110
- return json.loads(cand)
111
- except Exception:
112
- return None
113
 
 
 
 
114
  def stop_at_complete_json(text: str) -> Optional[str]:
115
  start = text.find("{")
116
  if start == -1:
@@ -142,73 +154,145 @@ def stop_at_complete_json(text: str) -> Optional[str]:
142
  return text[start : i + 1]
143
  return None
144
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  # ----------------------------
146
- # Load GGUF model (global)
147
  # ----------------------------
148
  llm: Optional[Llama] = None
149
  model_path: Optional[str] = None
 
 
150
 
151
- def load_llama() -> tuple[str, Llama]:
152
- global model_path
153
 
154
- t0 = time.time()
155
- mp = hf_hub_download(
156
- repo_id=GGUF_REPO_ID,
157
- filename=GGUF_FILENAME,
158
- token=os.getenv("HF_TOKEN"), # optional (only if repo is private)
159
- )
160
- t1 = time.time()
161
-
162
- # CPU Space -> n_gpu_layers = 0
163
- llama = Llama(
164
- model_path=mp,
165
- n_ctx=N_CTX,
166
- n_threads=N_THREADS,
167
- n_batch=N_BATCH,
168
- n_gpu_layers=0,
169
- verbose=True,
170
- )
171
- t2 = time.time()
172
 
173
- print(f"✅ GGUF downloaded: {mp} ({t1 - t0:.1f}s)")
174
- print(f"✅ Model loaded: ({t2 - t1:.1f}s) n_ctx={N_CTX} threads={N_THREADS} batch={N_BATCH}")
175
- model_path = mp
176
- return mp, llama
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
- # ----------------------------
179
- # FastAPI
180
- # ----------------------------
181
- app = FastAPI(title="FADES Fallacy Detector (GGUF / llama.cpp)")
182
 
183
- class AnalyzeRequest(BaseModel):
184
- text: str
185
- max_new_tokens: int = MAX_NEW_TOKENS_DEFAULT
186
- temperature: float = TEMPERATURE_DEFAULT
187
- top_p: float = TOP_P_DEFAULT
188
 
189
  @app.get("/health")
190
  def health():
191
  return {
192
- "ok": True,
193
- "engine": "llama.cpp (llama-cpp-python)",
 
194
  "gguf_repo": GGUF_REPO_ID,
195
  "gguf_filename": GGUF_FILENAME,
196
- "model_loaded": llm is not None,
197
  "model_path": model_path,
198
  "n_ctx": N_CTX,
199
  "n_threads": N_THREADS,
200
  "n_batch": N_BATCH,
 
201
  }
202
 
203
- @app.on_event("startup")
204
- def _startup():
205
- global llm
206
- _, llm_loaded = load_llama()
207
- llm = llm_loaded
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
 
 
 
 
209
  @lru_cache(maxsize=256)
210
- def _cached_generate(text: str, max_new_tokens: int, temperature: float, top_p: float) -> Dict[str, Any]:
211
- assert llm is not None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
 
213
  messages = build_messages(text)
214
 
@@ -221,18 +305,74 @@ def _cached_generate(text: str, max_new_tokens: int, temperature: float, top_p:
221
  )
222
 
223
  raw = out["choices"][0]["message"]["content"]
224
-
225
- cut = stop_at_complete_json(raw)
226
- raw_cut = cut if cut is not None else raw
227
-
228
- obj = extract_first_json_obj(raw_cut)
229
  if obj is None:
230
- return {"ok": False, "raw": raw_cut}
231
 
232
  return {"ok": True, "result": obj}
233
 
234
  @app.post("/analyze")
235
  async def analyze(req: AnalyzeRequest) -> Dict[str, Any]:
236
- # CPU: serialize requests to keep stable latency
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
  async with GEN_LOCK:
238
- return _cached_generate(req.text, int(req.max_new_tokens), float(req.temperature), float(req.top_p))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import os
3
  import json
4
  import time
5
+ import uuid
6
  import asyncio
7
+ from typing import Any, Dict, Optional, Tuple
8
  from functools import lru_cache
9
 
10
  from fastapi import FastAPI
 
13
  from llama_cpp import Llama
14
 
15
  # ----------------------------
16
+ # Config (model)
17
  # ----------------------------
18
  GGUF_REPO_ID = os.getenv("GGUF_REPO_ID", "maxime-antoine-dev/fades-mistral-v02-gguf")
19
  GGUF_FILENAME = os.getenv("GGUF_FILENAME", "mistral_v02_fades.Q4_K_M.gguf")
20
 
21
+ # Model load params (fixed once at startup)
22
+ # Keep these conservative for HF CPU
23
+ N_CTX = int(os.getenv("N_CTX", "1536"))
24
+ CPU_COUNT = os.cpu_count() or 4
25
+ N_THREADS = int(os.getenv("N_THREADS", str(min(8, max(1, CPU_COUNT - 1)))))
26
  N_BATCH = int(os.getenv("N_BATCH", "256"))
27
 
28
+ # Default generation params ("normal")
29
  MAX_NEW_TOKENS_DEFAULT = int(os.getenv("MAX_NEW_TOKENS", "180"))
30
  TEMPERATURE_DEFAULT = float(os.getenv("TEMPERATURE", "0.0"))
31
  TOP_P_DEFAULT = float(os.getenv("TOP_P", "0.95"))
32
 
33
+ # "Light" generation params (fastest / most stable)
34
+ LIGHT_MAX_NEW_TOKENS = int(os.getenv("LIGHT_MAX_NEW_TOKENS", "60"))
35
+ LIGHT_TEMPERATURE = float(os.getenv("LIGHT_TEMPERATURE", "0.0"))
36
+ LIGHT_TOP_P = float(os.getenv("LIGHT_TOP_P", "0.9"))
37
+
38
+ # "Light" runtime knobs (do NOT reload model, just reduce work)
39
+ LIGHT_N_BATCH = int(os.getenv("LIGHT_N_BATCH", "64"))
40
+
41
+ # One request at a time on CPU
42
  GEN_LOCK = asyncio.Lock()
43
 
44
+ app = FastAPI(title="FADES Fallacy Detector (GGUF / llama.cpp)")
45
+
46
  # ----------------------------
47
+ # Request model
48
+ # ----------------------------
49
+ class AnalyzeRequest(BaseModel):
50
+ text: str
51
+ # if True => use "light" parameters
52
+ light: bool = False
53
+
54
+ # optional overrides (applied after picking light/normal defaults)
55
+ max_new_tokens: Optional[int] = None
56
+ temperature: Optional[float] = None
57
+ top_p: Optional[float] = None
58
+
59
+ # ----------------------------
60
+ # Prompt
61
  # ----------------------------
62
  ALLOWED_LABELS = [
63
  "none",
 
77
  "intentional",
78
  ]
79
 
80
+ LABELS_STR = ", ".join([f'"{x}"' for x in ALLOWED_LABELS])
 
81
 
82
+ PROMPT_TEMPLATE = f"""You are a logical fallacy detection assistant.
83
 
84
  You MUST choose labels ONLY from this list (use the exact string):
85
+ {LABELS_STR}
86
 
87
+ Return ONLY valid JSON with this schema:
88
  {{
89
  "has_fallacy": boolean,
90
  "fallacies": [
91
  {{
92
  "type": string,
93
+ "confidence": number,
94
+ "evidence_quotes": [string],
95
+ "rationale": string
96
  }}
97
  ],
98
+ "overall_explanation": string
99
  }}
100
 
101
+ Rules:
102
+ Output ONLY JSON. No markdown.
103
+ If no fallacy: has_fallacy=false and fallacies=[].
104
+
105
+ INPUT:
106
+ {{text}}
107
+
108
+ OUTPUT:"""
 
 
 
 
 
 
109
 
110
  def build_messages(text: str) -> list[dict]:
 
111
  return [
112
+ {"role": "system", "content": "Output only JSON. Produce exactly one JSON object and stop."},
113
+ {"role": "user", "content": PROMPT_TEMPLATE.replace("{text}", text)},
114
  ]
115
 
116
  # ----------------------------
117
+ # Logging helpers
118
  # ----------------------------
119
+ def _log(rid: str, msg: str):
120
+ # rid = request id to correlate logs
121
+ print(f"[{rid}] {msg}", flush=True)
 
 
 
 
 
 
 
 
 
122
 
123
+ # ----------------------------
124
+ # JSON extraction helpers
125
+ # ----------------------------
126
  def stop_at_complete_json(text: str) -> Optional[str]:
127
  start = text.find("{")
128
  if start == -1:
 
154
  return text[start : i + 1]
155
  return None
156
 
157
+ def extract_first_json_obj(s: str) -> Optional[Dict[str, Any]]:
158
+ cut = stop_at_complete_json(s) or s
159
+ start = cut.find("{")
160
+ end = cut.rfind("}")
161
+ if start == -1 or end == -1 or end <= start:
162
+ return None
163
+ cand = cut[start : end + 1].strip()
164
+ try:
165
+ return json.loads(cand)
166
+ except Exception:
167
+ return None
168
+
169
  # ----------------------------
170
+ # Model load
171
  # ----------------------------
172
  llm: Optional[Llama] = None
173
  model_path: Optional[str] = None
174
+ load_error: Optional[str] = None
175
+ loaded_at_ts: Optional[float] = None
176
 
177
+ def load_llama() -> None:
178
+ global llm, model_path, load_error, loaded_at_ts
179
 
180
+ print("=== FADES startup ===", flush=True)
181
+ print(f"GGUF_REPO_ID={GGUF_REPO_ID}", flush=True)
182
+ print(f"GGUF_FILENAME={GGUF_FILENAME}", flush=True)
183
+ print(f"N_CTX={N_CTX} N_THREADS={N_THREADS} N_BATCH={N_BATCH}", flush=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
 
185
+ try:
186
+ t0 = time.time()
187
+ mp = hf_hub_download(
188
+ repo_id=GGUF_REPO_ID,
189
+ filename=GGUF_FILENAME,
190
+ token=os.getenv("HF_TOKEN"),
191
+ )
192
+ t1 = time.time()
193
+ print(f"✅ GGUF downloaded: {mp} ({t1 - t0:.1f}s)", flush=True)
194
+
195
+ t2 = time.time()
196
+ llm_local = Llama(
197
+ model_path=mp,
198
+ n_ctx=N_CTX,
199
+ n_threads=N_THREADS,
200
+ n_batch=N_BATCH,
201
+ n_gpu_layers=0,
202
+ verbose=False,
203
+ )
204
+ t3 = time.time()
205
+ print(f"✅ Model loaded: ({t3 - t2:.1f}s) n_ctx={N_CTX} threads={N_THREADS} batch={N_BATCH}", flush=True)
206
+
207
+ llm = llm_local
208
+ model_path = mp
209
+ load_error = None
210
+ loaded_at_ts = time.time()
211
+ print("=== Startup OK ===", flush=True)
212
+
213
+ except Exception as e:
214
+ load_error = repr(e)
215
+ print(f"❌ Startup FAILED: {load_error}", flush=True)
216
 
217
+ @app.on_event("startup")
218
+ def _startup():
219
+ load_llama()
 
220
 
221
+ @app.get("/")
222
+ def root():
223
+ return {"ok": True, "hint": "Use GET /health or POST /analyze"}
 
 
224
 
225
  @app.get("/health")
226
  def health():
227
  return {
228
+ "ok": llm is not None and load_error is None,
229
+ "model_loaded": llm is not None,
230
+ "load_error": load_error,
231
  "gguf_repo": GGUF_REPO_ID,
232
  "gguf_filename": GGUF_FILENAME,
 
233
  "model_path": model_path,
234
  "n_ctx": N_CTX,
235
  "n_threads": N_THREADS,
236
  "n_batch": N_BATCH,
237
+ "loaded_at_ts": loaded_at_ts,
238
  }
239
 
240
+ # ----------------------------
241
+ # Param selection (light vs normal)
242
+ # ----------------------------
243
+ def pick_params(req: AnalyzeRequest) -> Dict[str, Any]:
244
+ if req.light:
245
+ params = {
246
+ "max_new_tokens": LIGHT_MAX_NEW_TOKENS,
247
+ "temperature": LIGHT_TEMPERATURE,
248
+ "top_p": LIGHT_TOP_P,
249
+ "n_batch": LIGHT_N_BATCH,
250
+ }
251
+ else:
252
+ params = {
253
+ "max_new_tokens": MAX_NEW_TOKENS_DEFAULT,
254
+ "temperature": TEMPERATURE_DEFAULT,
255
+ "top_p": TOP_P_DEFAULT,
256
+ "n_batch": N_BATCH, # keep default
257
+ }
258
+
259
+ # Apply per-request overrides (if provided)
260
+ if req.max_new_tokens is not None:
261
+ params["max_new_tokens"] = int(req.max_new_tokens)
262
+ if req.temperature is not None:
263
+ params["temperature"] = float(req.temperature)
264
+ if req.top_p is not None:
265
+ params["top_p"] = float(req.top_p)
266
+
267
+ # Hard safety caps on CPU
268
+ params["max_new_tokens"] = max(1, min(int(params["max_new_tokens"]), 300))
269
+ params["temperature"] = max(0.0, min(float(params["temperature"]), 1.5))
270
+ params["top_p"] = max(0.05, min(float(params["top_p"]), 1.0))
271
+ params["n_batch"] = max(16, min(int(params["n_batch"]), 512))
272
+
273
+ return params
274
 
275
+ # ----------------------------
276
+ # Cached generate - separated by mode + params
277
+ # ----------------------------
278
  @lru_cache(maxsize=256)
279
+ def _cached_generate(
280
+ text: str,
281
+ light: bool,
282
+ max_new_tokens: int,
283
+ temperature: float,
284
+ top_p: float,
285
+ n_batch: int,
286
+ ) -> Dict[str, Any]:
287
+ if llm is None:
288
+ return {"ok": False, "error": "model_not_loaded", "detail": load_error}
289
+
290
+ # Change batch for this call (llama-cpp-python supports runtime override)
291
+ # Some versions accept it; if yours doesn't, it will be ignored harmlessly.
292
+ try:
293
+ llm.n_batch = int(n_batch) # type: ignore[attr-defined]
294
+ except Exception:
295
+ pass
296
 
297
  messages = build_messages(text)
298
 
 
305
  )
306
 
307
  raw = out["choices"][0]["message"]["content"]
308
+ obj = extract_first_json_obj(raw)
 
 
 
 
309
  if obj is None:
310
+ return {"ok": False, "error": "json_parse_error", "raw": raw}
311
 
312
  return {"ok": True, "result": obj}
313
 
314
  @app.post("/analyze")
315
  async def analyze(req: AnalyzeRequest) -> Dict[str, Any]:
316
+ rid = uuid.uuid4().hex[:10]
317
+ t0 = time.time()
318
+
319
+ _log(rid, f"📩 Request received (light={req.light}) chars={len(req.text)}")
320
+
321
+ if not req.text or not req.text.strip():
322
+ _log(rid, "⚠️ Empty text")
323
+ return {"ok": False, "error": "empty_text"}
324
+
325
+ params = pick_params(req)
326
+ _log(
327
+ rid,
328
+ f"βš™οΈ Params: max_new_tokens={params['max_new_tokens']} temp={params['temperature']} top_p={params['top_p']} n_batch={params['n_batch']}",
329
+ )
330
+
331
+ # serialize requests on CPU
332
  async with GEN_LOCK:
333
+ _log(rid, "🔒 Acquired GEN_LOCK")
334
+ t_lock = time.time()
335
+
336
+ _log(rid, "🧱 Building prompt/messages")
337
+ t1 = time.time()
338
+
339
+ # Generate
340
+ _log(rid, "🧠 Generating...")
341
+ t2 = time.time()
342
+ res = _cached_generate(
343
+ req.text,
344
+ bool(req.light),
345
+ int(params["max_new_tokens"]),
346
+ float(params["temperature"]),
347
+ float(params["top_p"]),
348
+ int(params["n_batch"]),
349
+ )
350
+ t3 = time.time()
351
+
352
+ if not res.get("ok"):
353
+ _log(rid, f"❌ Generation failed: {res.get('error')}")
354
+ else:
355
+ _log(rid, "✅ JSON parsed OK")
356
+
357
+ elapsed_total = time.time() - t0
358
+ elapsed_lock = time.time() - t_lock
359
+ _log(rid, f"⏱ Done. gen_time={t3 - t2:.2f}s total={elapsed_total:.2f}s (under lock {elapsed_lock:.2f}s)")
360
+
361
+ # return with timings
362
+ return {
363
+ **res,
364
+ "meta": {
365
+ "request_id": rid,
366
+ "light": bool(req.light),
367
+ "params": {
368
+ "max_new_tokens": int(params["max_new_tokens"]),
369
+ "temperature": float(params["temperature"]),
370
+ "top_p": float(params["top_p"]),
371
+ "n_batch": int(params["n_batch"]),
372
+ },
373
+ "timings_s": {
374
+ "total": round(elapsed_total, 3),
375
+ "gen": round(t3 - t2, 3),
376
+ },
377
+ },
378
+ }