maxime-antoine-dev committed on
Commit
66ca5c9
·
1 Parent(s): 8d0988b

Added rewrite route, improved prompts

Browse files
Files changed (1) hide show
  1. main.py +450 -82
main.py CHANGED
@@ -4,23 +4,23 @@ import json
4
  import time
5
  import uuid
6
  import asyncio
7
- from typing import Any, Dict, Optional, Tuple
8
  from functools import lru_cache
9
 
10
  from fastapi import FastAPI
11
  from fastapi.middleware.cors import CORSMiddleware
12
- from pydantic import BaseModel
13
  from huggingface_hub import hf_hub_download
14
  from llama_cpp import Llama
15
 
16
- # ----------------------------
 
17
  # Config (model)
18
- # ----------------------------
19
  GGUF_REPO_ID = os.getenv("GGUF_REPO_ID", "maxime-antoine-dev/fades-mistral-v02-gguf")
20
  GGUF_FILENAME = os.getenv("GGUF_FILENAME", "mistral_v02_fades.Q4_K_M.gguf")
21
 
22
  # Model load params (fixed once at startup)
23
- # Keep these conservative for HF CPU
24
  N_CTX = int(os.getenv("N_CTX", "1536"))
25
  CPU_COUNT = os.cpu_count() or 4
26
  N_THREADS = int(os.getenv("N_THREADS", str(min(8, max(1, CPU_COUNT - 1)))))
@@ -31,12 +31,12 @@ MAX_NEW_TOKENS_DEFAULT = int(os.getenv("MAX_NEW_TOKENS", "180"))
31
  TEMPERATURE_DEFAULT = float(os.getenv("TEMPERATURE", "0.0"))
32
  TOP_P_DEFAULT = float(os.getenv("TOP_P", "0.95"))
33
 
34
- # "Light" generation params (fastest / most stable)
35
  LIGHT_MAX_NEW_TOKENS = int(os.getenv("LIGHT_MAX_NEW_TOKENS", "60"))
36
  LIGHT_TEMPERATURE = float(os.getenv("LIGHT_TEMPERATURE", "0.0"))
37
  LIGHT_TOP_P = float(os.getenv("LIGHT_TOP_P", "0.9"))
38
 
39
- # "Light" runtime knobs (do NOT reload model, just reduce work)
40
  LIGHT_N_BATCH = int(os.getenv("LIGHT_N_BATCH", "64"))
41
 
42
  # One request at a time on CPU
@@ -44,17 +44,16 @@ GEN_LOCK = asyncio.Lock()
44
 
45
  app = FastAPI(title="FADES Fallacy Detector (GGUF / llama.cpp)")
46
 
47
- # ----------------------------
 
48
  # CORS (for browser front-ends)
49
- # ----------------------------
50
- # Comma-separated list of allowed origins, or "*" to allow all.
51
  _CORS_ORIGINS = os.getenv("CORS_ALLOW_ORIGINS", "*").strip()
52
  if _CORS_ORIGINS == "*" or not _CORS_ORIGINS:
53
  allow_origins = ["*"]
54
  else:
55
  allow_origins = [o.strip() for o in _CORS_ORIGINS.split(",") if o.strip()]
56
 
57
- # Note: when allow_origins=["*"], allow_credentials must be False.
58
  app.add_middleware(
59
  CORSMiddleware,
60
  allow_origins=allow_origins,
@@ -63,22 +62,34 @@ app.add_middleware(
63
  allow_headers=["*"],
64
  )
65
 
66
- # ----------------------------
67
- # Request model
68
- # ----------------------------
69
- class AnalyzeRequest(BaseModel):
70
- text: str
71
  # if True => use "light" parameters
72
  light: bool = False
73
-
74
  # optional overrides (applied after picking light/normal defaults)
75
  max_new_tokens: Optional[int] = None
76
  temperature: Optional[float] = None
77
  top_p: Optional[float] = None
78
 
79
- # ----------------------------
80
- # Prompt
81
- # ----------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  ALLOWED_LABELS = [
83
  "none",
84
  "faulty generalization",
@@ -99,12 +110,13 @@ ALLOWED_LABELS = [
99
 
100
  LABELS_STR = ", ".join([f'"{x}"' for x in ALLOWED_LABELS])
101
 
102
- PROMPT_TEMPLATE = f"""You are a logical fallacy detection assistant.
 
103
 
104
- You MUST choose labels ONLY from this list (use the exact string):
105
  {LABELS_STR}
106
 
107
- Return ONLY valid JSON with this schema:
108
  {{
109
  "has_fallacy": boolean,
110
  "fallacies": [
@@ -118,31 +130,96 @@ Return ONLY valid JSON with this schema:
118
  "overall_explanation": string
119
  }}
120
 
121
- Rules:
122
- Output ONLY JSON. No markdown.
123
- If no fallacy: has_fallacy=false and fallacies=[].
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
  INPUT:
126
  {{text}}
127
 
128
  OUTPUT:"""
129
 
130
- def build_messages(text: str) -> list[dict]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  return [
132
- {"role": "system", "content": "Output only JSON. Produce exactly one JSON object and stop."},
133
- {"role": "user", "content": PROMPT_TEMPLATE.replace("{text}", text)},
134
  ]
135
 
136
- # ----------------------------
137
- # Logging helpers
138
- # ----------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  def _log(rid: str, msg: str):
140
- # rid = request id to correlate logs
141
  print(f"[{rid}] {msg}", flush=True)
142
 
143
- # ----------------------------
144
- # JSON extraction helpers
145
- # ----------------------------
 
146
  def stop_at_complete_json(text: str) -> Optional[str]:
147
  start = text.find("{")
148
  if start == -1:
@@ -174,6 +251,7 @@ def stop_at_complete_json(text: str) -> Optional[str]:
174
  return text[start : i + 1]
175
  return None
176
 
 
177
  def extract_first_json_obj(s: str) -> Optional[Dict[str, Any]]:
178
  cut = stop_at_complete_json(s) or s
179
  start = cut.find("{")
@@ -186,14 +264,16 @@ def extract_first_json_obj(s: str) -> Optional[Dict[str, Any]]:
186
  except Exception:
187
  return None
188
 
189
- # ----------------------------
 
190
  # Model load
191
- # ----------------------------
192
  llm: Optional[Llama] = None
193
  model_path: Optional[str] = None
194
  load_error: Optional[str] = None
195
  loaded_at_ts: Optional[float] = None
196
 
 
197
  def load_llama() -> None:
198
  global llm, model_path, load_error, loaded_at_ts
199
 
@@ -234,13 +314,16 @@ def load_llama() -> None:
234
  load_error = repr(e)
235
  print(f"❌ Startup FAILED: {load_error}", flush=True)
236
 
 
237
  @app.on_event("startup")
238
  def _startup():
239
  load_llama()
240
 
 
241
  @app.get("/")
242
  def root():
243
- return {"ok": True, "hint": "Use GET /health or POST /analyze"}
 
244
 
245
  @app.get("/health")
246
  def health():
@@ -257,10 +340,11 @@ def health():
257
  "loaded_at_ts": loaded_at_ts,
258
  }
259
 
260
- # ----------------------------
261
- # Param selection (light vs normal)
262
- # ----------------------------
263
- def pick_params(req: AnalyzeRequest) -> Dict[str, Any]:
 
264
  if req.light:
265
  params = {
266
  "max_new_tokens": LIGHT_MAX_NEW_TOKENS,
@@ -273,10 +357,9 @@ def pick_params(req: AnalyzeRequest) -> Dict[str, Any]:
273
  "max_new_tokens": MAX_NEW_TOKENS_DEFAULT,
274
  "temperature": TEMPERATURE_DEFAULT,
275
  "top_p": TOP_P_DEFAULT,
276
- "n_batch": N_BATCH, # keep default
277
  }
278
 
279
- # Apply per-request overrides (if provided)
280
  if req.max_new_tokens is not None:
281
  params["max_new_tokens"] = int(req.max_new_tokens)
282
  if req.temperature is not None:
@@ -284,20 +367,115 @@ def pick_params(req: AnalyzeRequest) -> Dict[str, Any]:
284
  if req.top_p is not None:
285
  params["top_p"] = float(req.top_p)
286
 
287
- # Hard safety caps on CPU
288
- params["max_new_tokens"] = max(1, min(int(params["max_new_tokens"]), 300))
289
  params["temperature"] = max(0.0, min(float(params["temperature"]), 1.5))
290
  params["top_p"] = max(0.05, min(float(params["top_p"]), 1.0))
291
  params["n_batch"] = max(16, min(int(params["n_batch"]), 512))
292
-
293
  return params
294
 
295
- # ----------------------------
296
- # Cached generate - separated by mode + params
297
- # ----------------------------
298
- @lru_cache(maxsize=256)
299
- def _cached_generate(
300
- text: str,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
301
  light: bool,
302
  max_new_tokens: int,
303
  temperature: float,
@@ -307,14 +485,27 @@ def _cached_generate(
307
  if llm is None:
308
  return {"ok": False, "error": "model_not_loaded", "detail": load_error}
309
 
310
- # Change batch for this call (llama-cpp-python supports runtime override)
311
- # Some versions accept it; if yours doesn't, it will be ignored harmlessly.
312
  try:
313
  llm.n_batch = int(n_batch) # type: ignore[attr-defined]
314
  except Exception:
315
  pass
316
 
317
- messages = build_messages(text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
318
 
319
  out = llm.create_chat_completion(
320
  messages=messages,
@@ -331,15 +522,44 @@ def _cached_generate(
331
 
332
  return {"ok": True, "result": obj}
333
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
334
  @app.post("/analyze")
335
  async def analyze(req: AnalyzeRequest) -> Dict[str, Any]:
336
  rid = uuid.uuid4().hex[:10]
337
  t0 = time.time()
338
 
339
- _log(rid, f"📩 Request received (light={req.light}) chars={len(req.text)}")
340
 
341
  if not req.text or not req.text.strip():
342
- _log(rid, "⚠️ Empty text")
343
  return {"ok": False, "error": "empty_text"}
344
 
345
  params = pick_params(req)
@@ -348,37 +568,29 @@ async def analyze(req: AnalyzeRequest) -> Dict[str, Any]:
348
  f"⚙️ Params: max_new_tokens={params['max_new_tokens']} temp={params['temperature']} top_p={params['top_p']} n_batch={params['n_batch']}",
349
  )
350
 
351
- # serialize requests on CPU
 
352
  async with GEN_LOCK:
353
- _log(rid, "🔒 Acquired GEN_LOCK")
354
  t_lock = time.time()
355
 
356
- _log(rid, "🧱 Building prompt/messages")
357
- t1 = time.time()
358
-
359
- # Generate
360
- _log(rid, "🧠 Generating...")
361
- t2 = time.time()
362
- res = _cached_generate(
363
- req.text,
364
  bool(req.light),
365
  int(params["max_new_tokens"]),
366
  float(params["temperature"]),
367
  float(params["top_p"]),
368
  int(params["n_batch"]),
369
  )
370
- t3 = time.time()
371
 
372
- if not res.get("ok"):
373
- _log(rid, f"❌ Generation failed: {res.get('error')}")
374
- else:
375
- _log(rid, "✅ JSON parsed OK")
376
 
377
- elapsed_total = time.time() - t0
378
- elapsed_lock = time.time() - t_lock
379
- _log(rid, f"⏱ Done. gen_time={t3 - t2:.2f}s total={elapsed_total:.2f}s (under lock {elapsed_lock:.2f}s)")
380
-
381
- # return with timings
382
  return {
383
  **res,
384
  "meta": {
@@ -390,9 +602,165 @@ async def analyze(req: AnalyzeRequest) -> Dict[str, Any]:
390
  "top_p": float(params["top_p"]),
391
  "n_batch": int(params["n_batch"]),
392
  },
393
- "timings_s": {
394
- "total": round(elapsed_total, 3),
395
- "gen": round(t3 - t2, 3),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
396
  },
 
397
  },
398
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import time
5
  import uuid
6
  import asyncio
7
+ from typing import Any, Dict, Optional, List
8
  from functools import lru_cache
9
 
10
  from fastapi import FastAPI
11
  from fastapi.middleware.cors import CORSMiddleware
12
+ from pydantic import BaseModel, Field
13
  from huggingface_hub import hf_hub_download
14
  from llama_cpp import Llama
15
 
16
+
17
+ # ============================
18
  # Config (model)
19
+ # ============================
20
  GGUF_REPO_ID = os.getenv("GGUF_REPO_ID", "maxime-antoine-dev/fades-mistral-v02-gguf")
21
  GGUF_FILENAME = os.getenv("GGUF_FILENAME", "mistral_v02_fades.Q4_K_M.gguf")
22
 
23
  # Model load params (fixed once at startup)
 
24
  N_CTX = int(os.getenv("N_CTX", "1536"))
25
  CPU_COUNT = os.cpu_count() or 4
26
  N_THREADS = int(os.getenv("N_THREADS", str(min(8, max(1, CPU_COUNT - 1)))))
 
31
  TEMPERATURE_DEFAULT = float(os.getenv("TEMPERATURE", "0.0"))
32
  TOP_P_DEFAULT = float(os.getenv("TOP_P", "0.95"))
33
 
34
+ # "Light" generation params
35
  LIGHT_MAX_NEW_TOKENS = int(os.getenv("LIGHT_MAX_NEW_TOKENS", "60"))
36
  LIGHT_TEMPERATURE = float(os.getenv("LIGHT_TEMPERATURE", "0.0"))
37
  LIGHT_TOP_P = float(os.getenv("LIGHT_TOP_P", "0.9"))
38
 
39
+ # "Light" runtime knobs
40
  LIGHT_N_BATCH = int(os.getenv("LIGHT_N_BATCH", "64"))
41
 
42
  # One request at a time on CPU
 
44
 
45
  app = FastAPI(title="FADES Fallacy Detector (GGUF / llama.cpp)")
46
 
47
+
48
+ # ============================
49
  # CORS (for browser front-ends)
50
+ # ============================
 
51
  _CORS_ORIGINS = os.getenv("CORS_ALLOW_ORIGINS", "*").strip()
52
  if _CORS_ORIGINS == "*" or not _CORS_ORIGINS:
53
  allow_origins = ["*"]
54
  else:
55
  allow_origins = [o.strip() for o in _CORS_ORIGINS.split(",") if o.strip()]
56
 
 
57
  app.add_middleware(
58
  CORSMiddleware,
59
  allow_origins=allow_origins,
 
62
  allow_headers=["*"],
63
  )
64
 
65
+
66
+ # ============================
67
+ # Schemas
68
+ # ============================
69
class GenParams(BaseModel):
    """Shared generation knobs accepted by both /analyze and /rewrite request bodies."""

    # if True => use "light" parameters
    light: bool = False

    # optional overrides (applied after picking light/normal defaults)
    max_new_tokens: Optional[int] = None  # capped by pick_params' safety limits
    temperature: Optional[float] = None  # clamped to [0.0, 1.5] by pick_params
    top_p: Optional[float] = None  # clamped to [0.05, 1.0] by pick_params
76
 
77
+
78
class AnalyzeRequest(GenParams):
    """Body for POST /analyze: the text to scan for fallacies, plus inherited GenParams knobs."""

    # text to analyze; the route rejects it with "empty_text" when blank after strip()
    text: str
80
+
81
+
82
class RewriteRequest(GenParams):
    """Body for POST /rewrite: replace one fallacious quote inside a larger text."""

    # full text containing the quote; rejected with "empty_text" when blank
    text: str
    quote: str = Field(..., description="Verbatim substring that must be replaced.")
    fallacy_type: str = Field(..., description="Fallacy type of the quote.")
    rationale: str = Field(..., description="Why the quote is fallacious.")
    occurrence: int = Field(0, description="Which occurrence of quote to replace (0-based).")
88
+
89
+
90
+ # ============================
91
+ # Labels & Prompts
92
+ # ============================
93
  ALLOWED_LABELS = [
94
  "none",
95
  "faulty generalization",
 
110
 
111
  LABELS_STR = ", ".join([f'"{x}"' for x in ALLOWED_LABELS])
112
 
113
+ # Stronger /analyze prompt: forces specificity and forbids the "template" sentence
114
+ ANALYZE_PROMPT = f"""You are a fallacy detection assistant.
115
 
116
+ You MUST choose labels ONLY from this list (exact string):
117
  {LABELS_STR}
118
 
119
+ You MUST return ONLY valid JSON with this schema:
120
  {{
121
  "has_fallacy": boolean,
122
  "fallacies": [
 
130
  "overall_explanation": string
131
  }}
132
 
133
+ Hard rules:
134
+ - Output ONLY JSON. No markdown. No extra text.
135
+ - evidence_quotes MUST be verbatim substrings copied from the input text (no paraphrase).
136
+ - Keep each evidence quote short (prefer 1–2 sentences; max 240 chars).
137
+ - confidence MUST be a real probability between 0.0 and 1.0 (use 2 decimals).
138
+ It MUST NOT be always the same across examples. Calibrate it:
139
+ * 0.90–1.00: very explicit, unambiguous match, clear cue words.
140
+ * 0.70–0.89: strong match but some ambiguity or missing premise.
141
+ * 0.40–0.69: plausible but weak/partial evidence.
142
+ * 0.10–0.39: very uncertain.
143
+ - The rationale MUST be specific to the evidence (2–4 sentences):
144
+ Explain (1) what the quote claims, (2) why that matches the fallacy label,
145
+ (3) what logical step is invalid or missing.
146
+ DO NOT use generic filler. Do NOT reuse stock phrases.
147
+ In particular, you MUST NOT output this sentence:
148
+ "The input contains fallacious reasoning consistent with the predicted type(s)."
149
+ - overall_explanation MUST also be specific (2–5 sentences): summarize the reasoning issues and reference the key cue(s).
150
+ - If no fallacy: has_fallacy=false and fallacies=[] and overall_explanation explains briefly why.
151
 
152
  INPUT:
153
  {{text}}
154
 
155
  OUTPUT:"""
156
 
157
+ # /rewrite prompt: returns ONLY a replacement substring for the quote (server does the replacement)
158
# /rewrite prompt: returns ONLY a replacement substring for the quote (server does the replacement)
REWRITE_PROMPT = """You are rewriting a small quoted span inside a larger text.

Goal:
- You MUST propose a replacement for the QUOTE only.
- The replacement should remove the fallacious reasoning described, while keeping the same tone/style/tense/entities.
- The replacement MUST be plausible in the surrounding context and should be similar length (roughly +/- 40%).
- Do NOT change anything outside the quote. Do NOT add new facts not implied by the original.
- Do NOT introduce new fallacies.

Return ONLY valid JSON with this schema:
{
"replacement_quote": string,
"why_this_fix": string
}

Hard rules:
- Output ONLY JSON. No markdown. No extra text.
- replacement_quote should be standalone text (no surrounding quotes).
- why_this_fix: 1–3 sentences, specific.

INPUT_TEXT:
{text}

QUOTE_TO_REWRITE:
{quote}

FALLACY_TYPE:
{fallacy_type}

WHY_FALLACIOUS:
{rationale}

OUTPUT:"""


def build_analyze_messages(text: str) -> List[Dict[str, str]]:
    """Build the chat messages for the /analyze task.

    The system message pins strict-JSON output; the user message is the
    analyze template with the {text} placeholder substituted.
    """
    return [
        {"role": "system", "content": "Return only JSON. Exactly one JSON object. No extra text."},
        {"role": "user", "content": ANALYZE_PROMPT.replace("{text}", text)},
    ]


def build_rewrite_messages(text: str, quote: str, fallacy_type: str, rationale: str) -> List[Dict[str, str]]:
    """Build the chat messages for the /rewrite task.

    BUG FIX: REWRITE_PROMPT contains literal JSON braces ("{" / "}") in its
    schema section, so str.format() raised ValueError/KeyError on every call.
    Substitute the placeholders with str.replace() instead — the same
    mechanism already used for ANALYZE_PROMPT — which leaves the literal
    braces untouched.
    """
    prompt = (
        REWRITE_PROMPT
        .replace("{text}", text)
        .replace("{quote}", quote)
        .replace("{fallacy_type}", fallacy_type)
        .replace("{rationale}", rationale)
    )
    return [
        {"role": "system", "content": "Return only JSON. Exactly one JSON object. No extra text."},
        {"role": "user", "content": prompt},
    ]
211
+
212
+
213
+ # ============================
214
+ # Logging
215
+ # ============================
216
def _log(rid: str, msg: str) -> None:
    # Prefix every log line with the short request id so concurrent requests
    # can be correlated in stdout; flush so logs appear immediately on HF Spaces.
    print(f"[{rid}] {msg}", flush=True)
218
 
219
+
220
+ # ============================
221
+ # Robust JSON extraction
222
+ # ============================
223
  def stop_at_complete_json(text: str) -> Optional[str]:
224
  start = text.find("{")
225
  if start == -1:
 
251
  return text[start : i + 1]
252
  return None
253
 
254
+
255
  def extract_first_json_obj(s: str) -> Optional[Dict[str, Any]]:
256
  cut = stop_at_complete_json(s) or s
257
  start = cut.find("{")
 
264
  except Exception:
265
  return None
266
 
267
+
268
+ # ============================
269
  # Model load
270
+ # ============================
271
  llm: Optional[Llama] = None
272
  model_path: Optional[str] = None
273
  load_error: Optional[str] = None
274
  loaded_at_ts: Optional[float] = None
275
 
276
+
277
  def load_llama() -> None:
278
  global llm, model_path, load_error, loaded_at_ts
279
 
 
314
  load_error = repr(e)
315
  print(f"❌ Startup FAILED: {load_error}", flush=True)
316
 
317
+
318
@app.on_event("startup")
def _startup() -> None:
    # Load the GGUF model once at process start; failures are recorded in
    # load_error by load_llama() rather than crashing the app.
    load_llama()
321
 
322
+
323
@app.get("/")
def root():
    """Liveness landing route that points clients at the real endpoints."""
    endpoints_hint = "Use GET /health, POST /analyze, POST /rewrite"
    return {"ok": True, "hint": endpoints_hint}
326
+
327
 
328
  @app.get("/health")
329
  def health():
 
340
  "loaded_at_ts": loaded_at_ts,
341
  }
342
 
343
+
344
+ # ============================
345
+ # Param selection
346
+ # ============================
347
+ def pick_params(req: GenParams) -> Dict[str, Any]:
348
  if req.light:
349
  params = {
350
  "max_new_tokens": LIGHT_MAX_NEW_TOKENS,
 
357
  "max_new_tokens": MAX_NEW_TOKENS_DEFAULT,
358
  "temperature": TEMPERATURE_DEFAULT,
359
  "top_p": TOP_P_DEFAULT,
360
+ "n_batch": N_BATCH,
361
  }
362
 
 
363
  if req.max_new_tokens is not None:
364
  params["max_new_tokens"] = int(req.max_new_tokens)
365
  if req.temperature is not None:
 
367
  if req.top_p is not None:
368
  params["top_p"] = float(req.top_p)
369
 
370
+ # Safety caps
371
+ params["max_new_tokens"] = max(1, min(int(params["max_new_tokens"]), 400))
372
  params["temperature"] = max(0.0, min(float(params["temperature"]), 1.5))
373
  params["top_p"] = max(0.05, min(float(params["top_p"]), 1.0))
374
  params["n_batch"] = max(16, min(int(params["n_batch"]), 512))
 
375
  return params
376
 
377
+
378
+ # ============================
379
+ # Output sanitation / validation
380
+ # ============================
381
+ def _clamp01(x: Any, default: float = 0.5) -> float:
382
+ try:
383
+ v = float(x)
384
+ except Exception:
385
+ return default
386
+ if v < 0.0:
387
+ return 0.0
388
+ if v > 1.0:
389
+ return 1.0
390
+ return v
391
+
392
+
393
+ def _is_allowed_label(lbl: Any) -> bool:
394
+ return isinstance(lbl, str) and lbl in ALLOWED_LABELS and lbl != "none"
395
+
396
+
397
+ def sanitize_analyze_output(obj: Dict[str, Any], input_text: str) -> Dict[str, Any]:
398
+ """
399
+ Enforce shape, clamp confidence, drop invalid labels,
400
+ enforce evidence_quotes being substrings.
401
+ """
402
+ has_fallacy = bool(obj.get("has_fallacy", False))
403
+ fallacies_in = obj.get("fallacies", [])
404
+ if not isinstance(fallacies_in, list):
405
+ fallacies_in = []
406
+
407
+ fallacies_out = []
408
+ for f in fallacies_in:
409
+ if not isinstance(f, dict):
410
+ continue
411
+ f_type = f.get("type")
412
+ if not _is_allowed_label(f_type):
413
+ continue
414
+
415
+ conf = _clamp01(f.get("confidence", 0.5))
416
+ # keep 2 decimals for nicer UI
417
+ conf = float(f"{conf:.2f}")
418
+
419
+ ev = f.get("evidence_quotes", [])
420
+ if not isinstance(ev, list):
421
+ ev = []
422
+ ev_clean: List[str] = []
423
+ for q in ev:
424
+ if not isinstance(q, str):
425
+ continue
426
+ qq = q.strip()
427
+ if not qq:
428
+ continue
429
+ # evidence MUST be substring
430
+ if qq in input_text:
431
+ # keep short, but don't hard-cut if it breaks substring matching
432
+ if len(qq) <= 240:
433
+ ev_clean.append(qq)
434
+ else:
435
+ # if too long, try to keep first 240 if still substring (rare); else keep as-is
436
+ short = qq[:240]
437
+ if short in input_text:
438
+ ev_clean.append(short)
439
+ else:
440
+ ev_clean.append(qq)
441
+
442
+ rationale = f.get("rationale")
443
+ if not isinstance(rationale, str):
444
+ rationale = ""
445
+ rationale = rationale.strip()
446
+
447
+ fallacies_out.append(
448
+ {
449
+ "type": f_type,
450
+ "confidence": conf,
451
+ "evidence_quotes": ev_clean[:3],
452
+ "rationale": rationale,
453
+ }
454
+ )
455
+
456
+ overall = obj.get("overall_explanation")
457
+ if not isinstance(overall, str):
458
+ overall = ""
459
+ overall = overall.strip()
460
+
461
+ # If no fallacies survived sanitation, force no-fallacy state
462
+ if len(fallacies_out) == 0:
463
+ has_fallacy = False
464
+
465
+ return {
466
+ "has_fallacy": has_fallacy,
467
+ "fallacies": fallacies_out,
468
+ "overall_explanation": overall,
469
+ }
470
+
471
+
472
+ # ============================
473
+ # Cached generation (task-aware)
474
+ # ============================
475
+ @lru_cache(maxsize=512)
476
+ def _cached_chat_completion(
477
+ task: str,
478
+ payload: str,
479
  light: bool,
480
  max_new_tokens: int,
481
  temperature: float,
 
485
  if llm is None:
486
  return {"ok": False, "error": "model_not_loaded", "detail": load_error}
487
 
 
 
488
  try:
489
  llm.n_batch = int(n_batch) # type: ignore[attr-defined]
490
  except Exception:
491
  pass
492
 
493
+ try:
494
+ data = json.loads(payload)
495
+ except Exception:
496
+ return {"ok": False, "error": "bad_payload"}
497
+
498
+ if task == "analyze":
499
+ messages = build_analyze_messages(data["text"])
500
+ elif task == "rewrite":
501
+ messages = build_rewrite_messages(
502
+ data["text"],
503
+ data["quote"],
504
+ data["fallacy_type"],
505
+ data["rationale"],
506
+ )
507
+ else:
508
+ return {"ok": False, "error": "unknown_task"}
509
 
510
  out = llm.create_chat_completion(
511
  messages=messages,
 
522
 
523
  return {"ok": True, "result": obj}
524
 
525
+
526
+ def _occurrence_index(text: str, sub: str, occurrence: int) -> int:
527
+ if occurrence < 0:
528
+ return -1
529
+ start = 0
530
+ for _ in range(occurrence + 1):
531
+ idx = text.find(sub, start)
532
+ if idx == -1:
533
+ return -1
534
+ start = idx + max(1, len(sub))
535
+ return idx
536
+
537
+
538
+ def _replace_nth(text: str, old: str, new: str, occurrence: int) -> Dict[str, Any]:
539
+ idx = _occurrence_index(text, old, occurrence)
540
+ if idx == -1:
541
+ return {"ok": False, "error": "quote_not_found"}
542
+ return {
543
+ "ok": True,
544
+ "rewritten_text": text[:idx] + new + text[idx + len(old) :],
545
+ "start_char": idx,
546
+ "end_char": idx + len(new),
547
+ "old_start_char": idx,
548
+ "old_end_char": idx + len(old),
549
+ }
550
+
551
+
552
+ # ============================
553
+ # Routes
554
+ # ============================
555
@app.post("/analyze")
async def analyze(req: AnalyzeRequest) -> Dict[str, Any]:
    """Detect logical fallacies in req.text.

    Returns {"ok": True, "result": <sanitized analyze object>, "meta": ...}
    or {"ok": False, "error": ...} plus the same meta/timing envelope.
    Defect fixed: the pasted source lost all indentation; this is the
    reconstructed, properly structured handler.
    """
    rid = uuid.uuid4().hex[:10]
    t0 = time.time()

    _log(rid, f"📩 /analyze received (light={req.light}) chars={len(req.text) if req.text else 0}")

    if not req.text or not req.text.strip():
        return {"ok": False, "error": "empty_text"}

    params = pick_params(req)
    _log(
        rid,
        f"⚙️ Params: max_new_tokens={params['max_new_tokens']} temp={params['temperature']} top_p={params['top_p']} n_batch={params['n_batch']}",
    )

    # Hashable payload so lru_cache on _cached_chat_completion can key on it.
    payload = json.dumps({"text": req.text}, ensure_ascii=False)

    # Serialize generation: one llama.cpp call at a time on CPU.
    # NOTE(review): post-generation handling kept under GEN_LOCK as in the
    # pre-diff handler — confirm the intended lock scope.
    async with GEN_LOCK:
        t_lock = time.time()

        _log(rid, "🧠 Generating analyze...")
        t_gen0 = time.time()
        res = _cached_chat_completion(
            "analyze",
            payload,
            bool(req.light),
            int(params["max_new_tokens"]),
            float(params["temperature"]),
            float(params["top_p"]),
            int(params["n_batch"]),
        )
        t_gen1 = time.time()

        elapsed_total = time.time() - t0
        elapsed_lock = time.time() - t_lock

        if not res.get("ok"):
            _log(rid, f"❌ /analyze failed: {res.get('error')}")
            return {
                **res,
                "meta": {
                    "request_id": rid,
                    "light": bool(req.light),
                    "params": {
                        "max_new_tokens": int(params["max_new_tokens"]),
                        "temperature": float(params["temperature"]),
                        "top_p": float(params["top_p"]),
                        "n_batch": int(params["n_batch"]),
                    },
                    "timings_s": {"total": round(elapsed_total, 3), "gen": round(t_gen1 - t_gen0, 3)},
                },
            }

        # sanitize output for stability (substrings, labels, confidence clamp)
        clean = sanitize_analyze_output(res["result"], req.text)

        _log(rid, f"✅ /analyze ok fallacies={len(clean.get('fallacies', []))} total={elapsed_total:.2f}s")
        return {
            "ok": True,
            "result": clean,
            "meta": {
                "request_id": rid,
                "light": bool(req.light),
                "params": {
                    "max_new_tokens": int(params["max_new_tokens"]),
                    "temperature": float(params["temperature"]),
                    "top_p": float(params["top_p"]),
                    "n_batch": int(params["n_batch"]),
                },
                "timings_s": {
                    "total": round(elapsed_total, 3),
                    "gen": round(t_gen1 - t_gen0, 3),
                    "under_lock": round(elapsed_lock, 3),
                },
            },
        }
632
+
633
+
634
@app.post("/rewrite")
async def rewrite(req: RewriteRequest) -> Dict[str, Any]:
    """Rewrite one fallacious quote inside req.text.

    The model proposes only a replacement for the quote; the server performs
    the substitution itself (_replace_nth), guaranteeing nothing outside the
    quoted span changes. Defect fixed: the pasted source lost all
    indentation; this is the reconstructed, properly structured handler.
    """
    rid = uuid.uuid4().hex[:10]
    t0 = time.time()

    _log(
        rid,
        f"📩 /rewrite received (light={req.light}) text_chars={len(req.text) if req.text else 0} quote_chars={len(req.quote) if req.quote else 0}",
    )

    if not req.text or not req.text.strip():
        return {"ok": False, "error": "empty_text"}
    if not req.quote or not req.quote.strip():
        return {"ok": False, "error": "empty_quote"}

    quote = req.quote.strip()
    occurrence = int(req.occurrence or 0)

    # validate quote existence early
    if _occurrence_index(req.text, quote, occurrence) == -1:
        return {"ok": False, "error": "quote_not_found", "detail": {"occurrence": occurrence}}

    params = pick_params(req)
    # rewrite generally needs a bit more room than light analyze if you want
    # fluent replacements (still controllable by request overrides)
    if req.light and req.max_new_tokens is None:
        params["max_new_tokens"] = max(params["max_new_tokens"], 80)

    _log(
        rid,
        f"⚙️ Params: max_new_tokens={params['max_new_tokens']} temp={params['temperature']} top_p={params['top_p']} n_batch={params['n_batch']}",
    )

    # Hashable payload so lru_cache on _cached_chat_completion can key on it.
    payload = json.dumps(
        {
            "text": req.text,
            "quote": quote,
            "fallacy_type": req.fallacy_type,
            "rationale": req.rationale,
        },
        ensure_ascii=False,
    )

    # NOTE(review): post-generation handling kept under GEN_LOCK, mirroring
    # /analyze — confirm the intended lock scope.
    async with GEN_LOCK:
        t_lock = time.time()

        _log(rid, "🧠 Generating rewrite replacement_quote...")
        t_gen0 = time.time()
        res = _cached_chat_completion(
            "rewrite",
            payload,
            bool(req.light),
            int(params["max_new_tokens"]),
            float(params["temperature"]),
            float(params["top_p"]),
            int(params["n_batch"]),
        )
        t_gen1 = time.time()

        elapsed_total = time.time() - t0
        elapsed_lock = time.time() - t_lock

        if not res.get("ok"):
            _log(rid, f"❌ /rewrite failed: {res.get('error')}")
            return {
                **res,
                "meta": {
                    "request_id": rid,
                    "light": bool(req.light),
                    "params": {
                        "max_new_tokens": int(params["max_new_tokens"]),
                        "temperature": float(params["temperature"]),
                        "top_p": float(params["top_p"]),
                        "n_batch": int(params["n_batch"]),
                    },
                    "timings_s": {"total": round(elapsed_total, 3), "gen": round(t_gen1 - t_gen0, 3)},
                },
            }

        obj = res["result"]
        if not isinstance(obj, dict):
            return {"ok": False, "error": "bad_rewrite_output"}

        replacement = obj.get("replacement_quote")
        if not isinstance(replacement, str):
            return {"ok": False, "error": "missing_replacement_quote", "raw": obj}

        replacement = replacement.strip()
        if not replacement:
            return {"ok": False, "error": "empty_replacement_quote", "raw": obj}

        why = obj.get("why_this_fix")
        if not isinstance(why, str):
            why = ""
        why = why.strip()

        # server-side enforced: ONLY the quote is changed
        rep = _replace_nth(req.text, quote, replacement, occurrence)
        if not rep.get("ok"):
            return {"ok": False, "error": rep.get("error", "replace_failed")}

        _log(rid, f"✅ /rewrite ok total={elapsed_total:.2f}s")
        return {
            "ok": True,
            "result": {
                "rewritten_text": rep["rewritten_text"],
                "old_quote": quote,
                "replacement_quote": replacement,
                "why_this_fix": why,
                "occurrence": occurrence,
                "span": {
                    "old_start_char": rep["old_start_char"],
                    "old_end_char": rep["old_end_char"],
                    "new_start_char": rep["start_char"],
                    "new_end_char": rep["end_char"],
                },
            },
            "meta": {
                "request_id": rid,
                "light": bool(req.light),
                "params": {
                    "max_new_tokens": int(params["max_new_tokens"]),
                    "temperature": float(params["temperature"]),
                    "top_p": float(params["top_p"]),
                    "n_batch": int(params["n_batch"]),
                },
                "timings_s": {
                    "total": round(elapsed_total, 3),
                    "gen": round(t_gen1 - t_gen0, 3),
                    "under_lock": round(elapsed_lock, 3),
                },
            },
        }