maxime-antoine-dev committed on
Commit
d48c265
Β·
1 Parent(s): 07cd6a9

fixed infinite gen

Browse files
Files changed (1) hide show
  1. main.py +215 -145
main.py CHANGED
@@ -5,8 +5,7 @@ import time
5
  import uuid
6
  import asyncio
7
  import re
8
- from typing import Any, Dict, Optional, List
9
- from functools import lru_cache
10
 
11
  from fastapi import FastAPI
12
  from fastapi.middleware.cors import CORSMiddleware
@@ -40,6 +39,13 @@ LIGHT_TOP_P = float(os.getenv("LIGHT_TOP_P", "0.9"))
40
  # "Light" runtime knobs
41
  LIGHT_N_BATCH = int(os.getenv("LIGHT_N_BATCH", "64"))
42
 
 
 
 
 
 
 
 
43
  # One request at a time on CPU
44
  GEN_LOCK = asyncio.Lock()
45
 
@@ -68,12 +74,11 @@ app.add_middleware(
68
  # Schemas
69
  # ============================
70
  class GenParams(BaseModel):
71
- # if True => use "light" parameters
72
  light: bool = False
73
- # optional overrides (applied after picking light/normal defaults)
74
  max_new_tokens: Optional[int] = None
75
  temperature: Optional[float] = None
76
  top_p: Optional[float] = None
 
77
 
78
 
79
  class AnalyzeRequest(GenParams):
@@ -108,10 +113,11 @@ ALLOWED_LABELS = [
108
  "miscellaneous",
109
  "intentional",
110
  ]
111
-
112
  LABELS_STR = ", ".join([f'"{x}"' for x in ALLOWED_LABELS])
113
 
114
- # Stronger /analyze prompt: forces specificity and forbids the "template" sentence
 
 
115
  ANALYZE_PROMPT = f"""You are a fallacy detection assistant.
116
 
117
  You MUST choose labels ONLY from this list (exact string):
@@ -135,29 +141,21 @@ Hard rules:
135
  - Output ONLY JSON. No markdown. No extra text.
136
  - evidence_quotes MUST be verbatim substrings copied from the input text (no paraphrase).
137
  - Keep each evidence quote short (prefer 1–2 sentences; max 240 chars).
138
- - confidence MUST be a real probability between 0.0 and 1.0 (use 2 decimals).
139
- It MUST NOT be always the same across examples. Calibrate it:
140
- * 0.90–1.00: very explicit, unambiguous match, clear cue words.
141
- * 0.70–0.89: strong match but some ambiguity or missing premise.
142
- * 0.40–0.69: plausible but weak/partial evidence.
143
- * 0.10–0.39: very uncertain.
144
- - The rationale MUST be specific to the evidence (2–4 sentences):
145
- Explain (1) what the quote claims, (2) why that matches the fallacy label,
146
- (3) what logical step is invalid or missing.
147
- DO NOT use generic filler. Do NOT reuse stock phrases.
148
- In particular, you MUST NOT output this sentence:
149
  "The input contains fallacious reasoning consistent with the predicted type(s)."
150
- - overall_explanation MUST also be specific (2–5 sentences): summarize the reasoning issues and reference the key cue(s).
151
- - If no fallacy: has_fallacy=false and fallacies=[] and overall_explanation explains briefly why.
 
 
152
 
153
  INPUT:
154
  {{text}}
155
 
156
- OUTPUT:"""
157
 
158
- # /rewrite prompt: returns ONLY a replacement substring for the quote (server does the replacement)
159
- # IMPORTANT: braces are escaped so .format() does not treat the JSON schema as placeholders.
160
- REWRITE_PROMPT = """You are rewriting a small quoted span inside a larger text.
161
 
162
  Goal:
163
  - You MUST propose a replacement for the QUOTE only.
@@ -177,19 +175,22 @@ Hard rules:
177
  - replacement_quote should be standalone text (no surrounding quotes).
178
  - why_this_fix: 1–3 sentences, specific.
179
 
 
 
 
180
  INPUT_TEXT:
181
- {text}
182
 
183
  QUOTE_TO_REWRITE:
184
- {quote}
185
 
186
  FALLACY_TYPE:
187
- {fallacy_type}
188
 
189
  WHY_FALLACIOUS:
190
- {rationale}
191
 
192
- OUTPUT:"""
193
 
194
 
195
  def build_analyze_messages(text: str) -> List[Dict[str, str]]:
@@ -200,11 +201,12 @@ def build_analyze_messages(text: str) -> List[Dict[str, str]]:
200
 
201
 
202
  def build_rewrite_messages(text: str, quote: str, fallacy_type: str, rationale: str) -> List[Dict[str, str]]:
203
- prompt = REWRITE_PROMPT.format(
204
- text=text,
205
- quote=quote,
206
- fallacy_type=fallacy_type,
207
- rationale=rationale,
 
208
  )
209
  return [
210
  {"role": "system", "content": "Return only JSON. Exactly one JSON object. No extra text."},
@@ -220,8 +222,17 @@ def _log(rid: str, msg: str):
220
 
221
 
222
  # ============================
223
- # Robust JSON extraction
224
  # ============================
 
 
 
 
 
 
 
 
 
225
  def stop_at_complete_json(text: str) -> Optional[str]:
226
  start = text.find("{")
227
  if start == -1:
@@ -255,6 +266,7 @@ def stop_at_complete_json(text: str) -> Optional[str]:
255
 
256
 
257
  def extract_first_json_obj(s: str) -> Optional[Dict[str, Any]]:
 
258
  cut = stop_at_complete_json(s) or s
259
  start = cut.find("{")
260
  end = cut.rfind("}")
@@ -267,6 +279,90 @@ def extract_first_json_obj(s: str) -> Optional[Dict[str, Any]]:
267
  return None
268
 
269
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
  # ============================
271
  # Model load
272
  # ============================
@@ -353,6 +449,7 @@ def pick_params(req: GenParams) -> Dict[str, Any]:
353
  "temperature": LIGHT_TEMPERATURE,
354
  "top_p": LIGHT_TOP_P,
355
  "n_batch": LIGHT_N_BATCH,
 
356
  }
357
  else:
358
  params = {
@@ -360,6 +457,7 @@ def pick_params(req: GenParams) -> Dict[str, Any]:
360
  "temperature": TEMPERATURE_DEFAULT,
361
  "top_p": TOP_P_DEFAULT,
362
  "n_batch": N_BATCH,
 
363
  }
364
 
365
  if req.max_new_tokens is not None:
@@ -368,47 +466,34 @@ def pick_params(req: GenParams) -> Dict[str, Any]:
368
  params["temperature"] = float(req.temperature)
369
  if req.top_p is not None:
370
  params["top_p"] = float(req.top_p)
 
 
371
 
372
  # Safety caps
373
  params["max_new_tokens"] = max(1, min(int(params["max_new_tokens"]), 400))
374
  params["temperature"] = max(0.0, min(float(params["temperature"]), 1.5))
375
  params["top_p"] = max(0.05, min(float(params["top_p"]), 1.0))
376
  params["n_batch"] = max(16, min(int(params["n_batch"]), 512))
 
377
  return params
378
 
379
 
380
  # ============================
381
  # Post-processing helpers
382
  # ============================
383
- # This exact sentence is a known training artefact that can leak into rationales/overall explanations.
384
- # We strip it server-side for stable outputs.
385
  _TEMPLATE_SENTENCE = "The input contains fallacious reasoning consistent with the predicted type(s)."
386
-
387
- # Match the sentence with minor variations (extra spaces / trailing punctuation), case-insensitive.
388
  _TEMPLATE_RE = re.compile(
389
- r"(?is)\bThe input contains fallacious reasoning consistent with the predicted type\(s\)\.\s*",
390
  )
391
 
392
 
393
  def strip_template_sentence(text: Any) -> str:
394
- """
395
- Remove the known stock sentence from model outputs, then clean up whitespace/punctuation.
396
- Safe to call on non-strings.
397
- """
398
  if not isinstance(text, str):
399
  return ""
400
  out = _TEMPLATE_RE.sub("", text)
401
-
402
- # Also strip any leftover exact substring variant (belt & suspenders)
403
  out = out.replace(_TEMPLATE_SENTENCE, "")
404
-
405
- # Normalize whitespace
406
  out = re.sub(r"\s{2,}", " ", out).strip()
407
-
408
- # Remove leading separators left behind
409
  out = re.sub(r"^[\s\-–—:;,\.\u2022]+", "", out).strip()
410
-
411
- # Fix occasional doubled punctuation
412
  out = out.replace("..", ".").replace(" ,", ",").strip()
413
  return out
414
 
@@ -433,11 +518,6 @@ def _is_allowed_label(lbl: Any) -> bool:
433
 
434
 
435
  def sanitize_analyze_output(obj: Dict[str, Any], input_text: str) -> Dict[str, Any]:
436
- """
437
- Enforce shape, clamp confidence, drop invalid labels,
438
- enforce evidence_quotes being substrings.
439
- Also strips known training artefacts from rationales/overall.
440
- """
441
  has_fallacy = bool(obj.get("has_fallacy", False))
442
  fallacies_in = obj.get("fallacies", [])
443
  if not isinstance(fallacies_in, list):
@@ -452,12 +532,12 @@ def sanitize_analyze_output(obj: Dict[str, Any], input_text: str) -> Dict[str, A
452
  continue
453
 
454
  conf = _clamp01(f.get("confidence", 0.5))
455
- # keep 2 decimals for nicer UI
456
  conf = float(f"{conf:.2f}")
457
 
458
  ev = f.get("evidence_quotes", [])
459
  if not isinstance(ev, list):
460
  ev = []
 
461
  ev_clean: List[str] = []
462
  for q in ev:
463
  if not isinstance(q, str):
@@ -465,23 +545,10 @@ def sanitize_analyze_output(obj: Dict[str, Any], input_text: str) -> Dict[str, A
465
  qq = q.strip()
466
  if not qq:
467
  continue
468
- # evidence MUST be substring
469
  if qq in input_text:
470
- # keep short, but don't hard-cut if it breaks substring matching
471
- if len(qq) <= 240:
472
- ev_clean.append(qq)
473
- else:
474
- # if too long, try to keep first 240 if still substring (rare); else keep as-is
475
- short = qq[:240]
476
- if short in input_text:
477
- ev_clean.append(short)
478
- else:
479
- ev_clean.append(qq)
480
-
481
- rationale = f.get("rationale")
482
- if not isinstance(rationale, str):
483
- rationale = ""
484
- rationale = strip_template_sentence(rationale)
485
 
486
  fallacies_out.append(
487
  {
@@ -492,12 +559,8 @@ def sanitize_analyze_output(obj: Dict[str, Any], input_text: str) -> Dict[str, A
492
  }
493
  )
494
 
495
- overall = obj.get("overall_explanation")
496
- if not isinstance(overall, str):
497
- overall = ""
498
- overall = strip_template_sentence(overall)
499
 
500
- # If no fallacies survived sanitation, force no-fallacy state
501
  if len(fallacies_out) == 0:
502
  has_fallacy = False
503
 
@@ -509,30 +572,23 @@ def sanitize_analyze_output(obj: Dict[str, Any], input_text: str) -> Dict[str, A
509
 
510
 
511
  def generate_overall_explanation(clean: Dict[str, Any]) -> str:
512
- """
513
- Build a non-duplicative overall explanation that (a) summarizes what happened and
514
- (b) highlights risks of the detected fallacy(ies).
515
- This intentionally does NOT copy any per-fallacy rationale verbatim.
516
- """
517
  has_fallacy = bool(clean.get("has_fallacy"))
518
  fallacies = clean.get("fallacies") or []
519
- if not isinstance(fallacies, list):
520
- fallacies = []
521
-
522
  if not has_fallacy or not fallacies:
523
  return (
524
  "No clear fallacious reasoning was detected in the text. "
525
- "The argument appears broadly consistent as written, though it may still depend on unstated assumptions."
526
  )
527
 
528
- # Unique types, preserve order
529
  types: List[str] = []
530
  for f in fallacies:
531
- t = f.get("type") if isinstance(f, dict) else None
532
- if isinstance(t, str) and t not in types:
533
- types.append(t)
 
534
 
535
- # Example cue quote (keep very short)
536
  example = ""
537
  for f in fallacies:
538
  if isinstance(f, dict):
@@ -560,7 +616,6 @@ def generate_overall_explanation(clean: Dict[str, Any]) -> str:
560
  "intentional": "It can be persuasive while bypassing careful reasoning, increasing the chance of manipulation.",
561
  }
562
 
563
- # Pick up to 2 risk sentences for the detected types
564
  risks: List[str] = []
565
  for t in types:
566
  rs = risk_map.get(t)
@@ -570,25 +625,45 @@ def generate_overall_explanation(clean: Dict[str, Any]) -> str:
570
  break
571
 
572
  types_str = ", ".join(types) if len(types) <= 3 else ", ".join(types[:3]) + "…"
573
- sentences: List[str] = []
574
- sentences.append(
575
  f"The text contains fallacious reasoning ({types_str}) that can make the conclusion seem stronger than the evidence supports."
576
  )
577
  if example:
578
- sentences.append(f'For example: "{example}".')
579
- if risks:
580
- sentences.append("Risk: " + " ".join(risks))
581
- else:
582
- sentences.append("Risk: it may mislead readers by presenting weak support as if it were decisive.")
583
 
584
- return " ".join(sentences).strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
585
 
586
 
587
  # ============================
588
- # Cached generation (task-aware)
589
  # ============================
590
- @lru_cache(maxsize=512)
591
- def _cached_chat_completion(
592
  task: str,
593
  payload: str,
594
  light: bool,
@@ -596,10 +671,16 @@ def _cached_chat_completion(
596
  temperature: float,
597
  top_p: float,
598
  n_batch: int,
 
599
  ) -> Dict[str, Any]:
600
  if llm is None:
601
  return {"ok": False, "error": "model_not_loaded", "detail": load_error}
602
 
 
 
 
 
 
603
  try:
604
  llm.n_batch = int(n_batch) # type: ignore[attr-defined]
605
  except Exception:
@@ -622,20 +703,33 @@ def _cached_chat_completion(
622
  else:
623
  return {"ok": False, "error": "unknown_task"}
624
 
 
625
  out = llm.create_chat_completion(
626
  messages=messages,
627
  max_tokens=int(max_new_tokens),
628
  temperature=float(temperature),
629
  top_p=float(top_p),
 
 
630
  stream=False,
631
  )
 
632
 
633
  raw = out["choices"][0]["message"]["content"]
 
 
634
  obj = extract_first_json_obj(raw)
635
  if obj is None:
636
- return {"ok": False, "error": "json_parse_error", "raw": raw}
 
637
 
638
- return {"ok": True, "result": obj}
 
 
 
 
 
 
639
 
640
 
641
  def _occurrence_index(text: str, sub: str, occurrence: int) -> int:
@@ -680,17 +774,14 @@ async def analyze(req: AnalyzeRequest) -> Dict[str, Any]:
680
  params = pick_params(req)
681
  _log(
682
  rid,
683
- f"βš™οΈ Params: max_new_tokens={params['max_new_tokens']} temp={params['temperature']} top_p={params['top_p']} n_batch={params['n_batch']}",
684
  )
685
 
686
  payload = json.dumps({"text": req.text}, ensure_ascii=False)
687
 
688
  async with GEN_LOCK:
689
- t_lock = time.time()
690
-
691
  _log(rid, "🧠 Generating analyze...")
692
- t_gen0 = time.time()
693
- res = _cached_chat_completion(
694
  "analyze",
695
  payload,
696
  bool(req.light),
@@ -698,11 +789,10 @@ async def analyze(req: AnalyzeRequest) -> Dict[str, Any]:
698
  float(params["temperature"]),
699
  float(params["top_p"]),
700
  int(params["n_batch"]),
 
701
  )
702
- t_gen1 = time.time()
703
 
704
  elapsed_total = time.time() - t0
705
- elapsed_lock = time.time() - t_lock
706
 
707
  if not res.get("ok"):
708
  _log(rid, f"❌ /analyze failed: {res.get('error')}")
@@ -716,14 +806,14 @@ async def analyze(req: AnalyzeRequest) -> Dict[str, Any]:
716
  "temperature": float(params["temperature"]),
717
  "top_p": float(params["top_p"]),
718
  "n_batch": int(params["n_batch"]),
 
719
  },
720
- "timings_s": {"total": round(elapsed_total, 3), "gen": round(t_gen1 - t_gen0, 3)},
721
  },
722
  }
723
 
724
- # sanitize output for stability (substrings, labels, confidence clamp) + strip training artefact
725
  clean = sanitize_analyze_output(res["result"], req.text)
726
- # overwrite overall explanation with a real summary + risk (and never copy rationales)
727
  clean["overall_explanation"] = generate_overall_explanation(clean)
728
 
729
  _log(rid, f"βœ… /analyze ok fallacies={len(clean.get('fallacies', []))} total={elapsed_total:.2f}s")
@@ -738,12 +828,9 @@ async def analyze(req: AnalyzeRequest) -> Dict[str, Any]:
738
  "temperature": float(params["temperature"]),
739
  "top_p": float(params["top_p"]),
740
  "n_batch": int(params["n_batch"]),
 
741
  },
742
- "timings_s": {
743
- "total": round(elapsed_total, 3),
744
- "gen": round(t_gen1 - t_gen0, 3),
745
- "under_lock": round(elapsed_lock, 3),
746
- },
747
  },
748
  }
749
 
@@ -766,21 +853,13 @@ async def rewrite(req: RewriteRequest) -> Dict[str, Any]:
766
  quote = req.quote.strip()
767
  occurrence = int(req.occurrence or 0)
768
 
769
- # validate quote existence early
770
  if _occurrence_index(req.text, quote, occurrence) == -1:
771
  return {"ok": False, "error": "quote_not_found", "detail": {"occurrence": occurrence}}
772
 
773
  params = pick_params(req)
774
- # rewrite generally needs a bit more room than light analyze if you want fluent replacements
775
- # (still controllable by request overrides)
776
  if req.light and req.max_new_tokens is None:
777
  params["max_new_tokens"] = max(params["max_new_tokens"], 80)
778
 
779
- _log(
780
- rid,
781
- f"βš™οΈ Params: max_new_tokens={params['max_new_tokens']} temp={params['temperature']} top_p={params['top_p']} n_batch={params['n_batch']}",
782
- )
783
-
784
  payload = json.dumps(
785
  {
786
  "text": req.text,
@@ -792,11 +871,8 @@ async def rewrite(req: RewriteRequest) -> Dict[str, Any]:
792
  )
793
 
794
  async with GEN_LOCK:
795
- t_lock = time.time()
796
-
797
  _log(rid, "🧠 Generating rewrite replacement_quote...")
798
- t_gen0 = time.time()
799
- res = _cached_chat_completion(
800
  "rewrite",
801
  payload,
802
  bool(req.light),
@@ -804,11 +880,10 @@ async def rewrite(req: RewriteRequest) -> Dict[str, Any]:
804
  float(params["temperature"]),
805
  float(params["top_p"]),
806
  int(params["n_batch"]),
 
807
  )
808
- t_gen1 = time.time()
809
 
810
  elapsed_total = time.time() - t0
811
- elapsed_lock = time.time() - t_lock
812
 
813
  if not res.get("ok"):
814
  _log(rid, f"❌ /rewrite failed: {res.get('error')}")
@@ -822,8 +897,9 @@ async def rewrite(req: RewriteRequest) -> Dict[str, Any]:
822
  "temperature": float(params["temperature"]),
823
  "top_p": float(params["top_p"]),
824
  "n_batch": int(params["n_batch"]),
 
825
  },
826
- "timings_s": {"total": round(elapsed_total, 3), "gen": round(t_gen1 - t_gen0, 3)},
827
  },
828
  }
829
 
@@ -840,11 +916,8 @@ async def rewrite(req: RewriteRequest) -> Dict[str, Any]:
840
  return {"ok": False, "error": "empty_replacement_quote", "raw": obj}
841
 
842
  why = obj.get("why_this_fix")
843
- if not isinstance(why, str):
844
- why = ""
845
- why = why.strip()
846
 
847
- # server-side enforced: ONLY the quote is changed
848
  rep = _replace_nth(req.text, quote, replacement, occurrence)
849
  if not rep.get("ok"):
850
  return {"ok": False, "error": rep.get("error", "replace_failed")}
@@ -873,11 +946,8 @@ async def rewrite(req: RewriteRequest) -> Dict[str, Any]:
873
  "temperature": float(params["temperature"]),
874
  "top_p": float(params["top_p"]),
875
  "n_batch": int(params["n_batch"]),
 
876
  },
877
- "timings_s": {
878
- "total": round(elapsed_total, 3),
879
- "gen": round(t_gen1 - t_gen0, 3),
880
- "under_lock": round(elapsed_lock, 3),
881
- },
882
  },
883
  }
 
5
  import uuid
6
  import asyncio
7
  import re
8
+ from typing import Any, Dict, Optional, List, Tuple
 
9
 
10
  from fastapi import FastAPI
11
  from fastapi.middleware.cors import CORSMiddleware
 
39
  # "Light" runtime knobs
40
  LIGHT_N_BATCH = int(os.getenv("LIGHT_N_BATCH", "64"))
41
 
42
+ # Anti-loop defaults
43
+ REPEAT_PENALTY_DEFAULT = float(os.getenv("REPEAT_PENALTY", "1.15"))
44
+
45
+ # Cache only SUCCESSFUL generations (TTL)
46
+ CACHE_TTL_S = int(os.getenv("CACHE_TTL_S", "300")) # 5 minutes
47
+ CACHE_MAX_ITEMS = int(os.getenv("CACHE_MAX_ITEMS", "512"))
48
+
49
  # One request at a time on CPU
50
  GEN_LOCK = asyncio.Lock()
51
 
 
74
  # Schemas
75
  # ============================
76
  class GenParams(BaseModel):
 
77
  light: bool = False
 
78
  max_new_tokens: Optional[int] = None
79
  temperature: Optional[float] = None
80
  top_p: Optional[float] = None
81
+ repeat_penalty: Optional[float] = None
82
 
83
 
84
  class AnalyzeRequest(GenParams):
 
113
  "miscellaneous",
114
  "intentional",
115
  ]
 
116
  LABELS_STR = ", ".join([f'"{x}"' for x in ALLOWED_LABELS])
117
 
118
+ END_SENTINEL = "<END_JSON>"
119
+ STOP_SEQS = [END_SENTINEL]
120
+
121
  ANALYZE_PROMPT = f"""You are a fallacy detection assistant.
122
 
123
  You MUST choose labels ONLY from this list (exact string):
 
141
  - Output ONLY JSON. No markdown. No extra text.
142
  - evidence_quotes MUST be verbatim substrings copied from the input text (no paraphrase).
143
  - Keep each evidence quote short (prefer 1–2 sentences; max 240 chars).
144
+ - confidence MUST be a real probability between 0.0 and 1.0 (use 2 decimals). It MUST NOT be always the same.
145
+ - The rationale MUST be specific (2–4 sentences). DO NOT use generic filler.
146
+ - You MUST NOT output this sentence anywhere:
 
 
 
 
 
 
 
 
147
  "The input contains fallacious reasoning consistent with the predicted type(s)."
148
+ - overall_explanation MUST be specific (2–5 sentences).
149
+
150
+ IMPORTANT TERMINATION:
151
+ - After the JSON object, output the token {END_SENTINEL} and stop.
152
 
153
  INPUT:
154
  {{text}}
155
 
156
+ OUTPUT (JSON then {END_SENTINEL}):"""
157
 
158
+ REWRITE_PROMPT = f"""You are rewriting a small quoted span inside a larger text.
 
 
159
 
160
  Goal:
161
  - You MUST propose a replacement for the QUOTE only.
 
175
  - replacement_quote should be standalone text (no surrounding quotes).
176
  - why_this_fix: 1–3 sentences, specific.
177
 
178
+ IMPORTANT TERMINATION:
179
+ - After the JSON object, output the token {END_SENTINEL} and stop.
180
+
181
  INPUT_TEXT:
182
+ {{text}}
183
 
184
  QUOTE_TO_REWRITE:
185
+ {{quote}}
186
 
187
  FALLACY_TYPE:
188
+ {{fallacy_type}}
189
 
190
  WHY_FALLACIOUS:
191
+ {{rationale}}
192
 
193
+ OUTPUT (JSON then {END_SENTINEL}):"""
194
 
195
 
196
  def build_analyze_messages(text: str) -> List[Dict[str, str]]:
 
201
 
202
 
203
  def build_rewrite_messages(text: str, quote: str, fallacy_type: str, rationale: str) -> List[Dict[str, str]]:
204
+ prompt = (
205
+ REWRITE_PROMPT
206
+ .replace("{text}", text)
207
+ .replace("{quote}", quote)
208
+ .replace("{fallacy_type}", fallacy_type)
209
+ .replace("{rationale}", rationale)
210
  )
211
  return [
212
  {"role": "system", "content": "Return only JSON. Exactly one JSON object. No extra text."},
 
222
 
223
 
224
  # ============================
225
+ # Robust JSON extraction + repair
226
  # ============================
227
def _strip_sentinel(s: str) -> str:
    """Return *s* truncated just before the first END_SENTINEL marker.

    Non-string input yields "". When the sentinel is absent the string
    is returned unchanged.
    """
    if not isinstance(s, str):
        return ""
    cut_at = s.find(END_SENTINEL)
    return s if cut_at == -1 else s[:cut_at]
234
+
235
+
236
  def stop_at_complete_json(text: str) -> Optional[str]:
237
  start = text.find("{")
238
  if start == -1:
 
266
 
267
 
268
  def extract_first_json_obj(s: str) -> Optional[Dict[str, Any]]:
269
+ s = _strip_sentinel(s)
270
  cut = stop_at_complete_json(s) or s
271
  start = cut.find("{")
272
  end = cut.rfind("}")
 
279
  return None
280
 
281
 
282
+ def _count_unescaped_quotes(s: str) -> int:
283
+ in_str = False
284
+ esc = False
285
+ count = 0
286
+ for ch in s:
287
+ if esc:
288
+ esc = False
289
+ continue
290
+ if ch == "\\":
291
+ esc = True
292
+ continue
293
+ if ch == '"':
294
+ count += 1
295
+ in_str = not in_str
296
+ return count
297
+
298
+
299
+ def _balance_braces_outside_strings(s: str) -> Tuple[int, int]:
300
+ opens = 0
301
+ closes = 0
302
+ in_str = False
303
+ esc = False
304
+ for ch in s:
305
+ if in_str:
306
+ if esc:
307
+ esc = False
308
+ elif ch == "\\":
309
+ esc = True
310
+ elif ch == '"':
311
+ in_str = False
312
+ continue
313
+ else:
314
+ if ch == '"':
315
+ in_str = True
316
+ continue
317
+ if ch == "{":
318
+ opens += 1
319
+ elif ch == "}":
320
+ closes += 1
321
+ return opens, closes
322
+
323
+
324
def try_repair_and_parse_json(raw: str) -> Optional[Dict[str, Any]]:
    """Best-effort parse of model output whose JSON may be truncated.

    Repair strategy:
      1. strip the END sentinel and slice from the first ``{``;
      2. hard-cap the candidate length so pathological/repetitive output
         cannot stall the server;
      3. close a dangling string when the unescaped-quote count is odd;
      4. append any missing closing braces (give up if there are extras);
      5. attempt ``json.loads``.

    Returns the parsed object, or None when the text is unrecoverable.
    """
    if not isinstance(raw, str):
        return None

    text = _strip_sentinel(raw)
    brace_at = text.find("{")
    if brace_at == -1:
        return None
    candidate = text[brace_at:].strip()

    # Keep the server responsive on huge repetitive payloads.
    limit = 50_000
    candidate = candidate[:limit]

    # An odd number of unescaped quotes means a string was left open.
    if _count_unescaped_quotes(candidate) % 2 == 1:
        candidate += '"'

    opened, closed = _balance_braces_outside_strings(candidate)
    if closed > opened:
        # More closers than openers: no safe repair exists.
        return None
    if opened > closed:
        candidate += "}" * (opened - closed)

    candidate = candidate.strip()
    try:
        return json.loads(candidate)
    except Exception:
        return None
364
+
365
+
366
  # ============================
367
  # Model load
368
  # ============================
 
449
  "temperature": LIGHT_TEMPERATURE,
450
  "top_p": LIGHT_TOP_P,
451
  "n_batch": LIGHT_N_BATCH,
452
+ "repeat_penalty": REPEAT_PENALTY_DEFAULT,
453
  }
454
  else:
455
  params = {
 
457
  "temperature": TEMPERATURE_DEFAULT,
458
  "top_p": TOP_P_DEFAULT,
459
  "n_batch": N_BATCH,
460
+ "repeat_penalty": REPEAT_PENALTY_DEFAULT,
461
  }
462
 
463
  if req.max_new_tokens is not None:
 
466
  params["temperature"] = float(req.temperature)
467
  if req.top_p is not None:
468
  params["top_p"] = float(req.top_p)
469
+ if req.repeat_penalty is not None:
470
+ params["repeat_penalty"] = float(req.repeat_penalty)
471
 
472
  # Safety caps
473
  params["max_new_tokens"] = max(1, min(int(params["max_new_tokens"]), 400))
474
  params["temperature"] = max(0.0, min(float(params["temperature"]), 1.5))
475
  params["top_p"] = max(0.05, min(float(params["top_p"]), 1.0))
476
  params["n_batch"] = max(16, min(int(params["n_batch"]), 512))
477
+ params["repeat_penalty"] = max(1.0, min(float(params["repeat_penalty"]), 1.5))
478
  return params
479
 
480
 
481
  # ============================
482
  # Post-processing helpers
483
  # ============================
 
 
484
  _TEMPLATE_SENTENCE = "The input contains fallacious reasoning consistent with the predicted type(s)."
 
 
485
  _TEMPLATE_RE = re.compile(
486
+ r"(?is)\bThe input contains fallacious reasoning consistent with the predicted type\(s\)\.\s*"
487
  )
488
 
489
 
490
  def strip_template_sentence(text: Any) -> str:
 
 
 
 
491
  if not isinstance(text, str):
492
  return ""
493
  out = _TEMPLATE_RE.sub("", text)
 
 
494
  out = out.replace(_TEMPLATE_SENTENCE, "")
 
 
495
  out = re.sub(r"\s{2,}", " ", out).strip()
 
 
496
  out = re.sub(r"^[\s\-–—:;,\.\u2022]+", "", out).strip()
 
 
497
  out = out.replace("..", ".").replace(" ,", ",").strip()
498
  return out
499
 
 
518
 
519
 
520
  def sanitize_analyze_output(obj: Dict[str, Any], input_text: str) -> Dict[str, Any]:
 
 
 
 
 
521
  has_fallacy = bool(obj.get("has_fallacy", False))
522
  fallacies_in = obj.get("fallacies", [])
523
  if not isinstance(fallacies_in, list):
 
532
  continue
533
 
534
  conf = _clamp01(f.get("confidence", 0.5))
 
535
  conf = float(f"{conf:.2f}")
536
 
537
  ev = f.get("evidence_quotes", [])
538
  if not isinstance(ev, list):
539
  ev = []
540
+
541
  ev_clean: List[str] = []
542
  for q in ev:
543
  if not isinstance(q, str):
 
545
  qq = q.strip()
546
  if not qq:
547
  continue
 
548
  if qq in input_text:
549
+ ev_clean.append(qq if len(qq) <= 240 else qq[:240])
550
+
551
+ rationale = strip_template_sentence(f.get("rationale", ""))
 
 
 
 
 
 
 
 
 
 
 
 
552
 
553
  fallacies_out.append(
554
  {
 
559
  }
560
  )
561
 
562
+ overall = strip_template_sentence(obj.get("overall_explanation", ""))
 
 
 
563
 
 
564
  if len(fallacies_out) == 0:
565
  has_fallacy = False
566
 
 
572
 
573
 
574
  def generate_overall_explanation(clean: Dict[str, Any]) -> str:
 
 
 
 
 
575
  has_fallacy = bool(clean.get("has_fallacy"))
576
  fallacies = clean.get("fallacies") or []
 
 
 
577
  if not has_fallacy or not fallacies:
578
  return (
579
  "No clear fallacious reasoning was detected in the text. "
580
+ "The argument appears broadly consistent as written, though it may still rely on unstated assumptions."
581
  )
582
 
583
+ # unique types
584
  types: List[str] = []
585
  for f in fallacies:
586
+ if isinstance(f, dict):
587
+ t = f.get("type")
588
+ if isinstance(t, str) and t not in types:
589
+ types.append(t)
590
 
591
+ # example
592
  example = ""
593
  for f in fallacies:
594
  if isinstance(f, dict):
 
616
  "intentional": "It can be persuasive while bypassing careful reasoning, increasing the chance of manipulation.",
617
  }
618
 
 
619
  risks: List[str] = []
620
  for t in types:
621
  rs = risk_map.get(t)
 
625
  break
626
 
627
  types_str = ", ".join(types) if len(types) <= 3 else ", ".join(types[:3]) + "…"
628
+ out = (
 
629
  f"The text contains fallacious reasoning ({types_str}) that can make the conclusion seem stronger than the evidence supports."
630
  )
631
  if example:
632
+ out += f' For example: "{example}".'
633
+ out += " Risk: " + (" ".join(risks) if risks else "it may mislead readers by presenting weak support as if it were decisive.")
634
+ return out.strip()
 
 
635
 
636
+
637
+ # ============================
638
+ # Success-only cache
639
+ # ============================
640
+ _SUCCESS_CACHE: Dict[Tuple[Any, ...], Tuple[float, Dict[str, Any]]] = {}
641
+
642
+
643
def _cache_get(key: Tuple[Any, ...]) -> Optional[Dict[str, Any]]:
    """Return the cached value for *key*, or None when absent or expired.

    Expired entries are evicted eagerly on lookup so the cache does not
    hand back stale generations.
    """
    entry = _SUCCESS_CACHE.get(key)
    if not entry:
        return None
    stored_at, value = entry
    if (time.time() - stored_at) > CACHE_TTL_S:
        _SUCCESS_CACHE.pop(key, None)
        return None
    return value
652
+
653
+
654
def _cache_put(key: Tuple[Any, ...], val: Dict[str, Any]) -> None:
    """Store *val* under *key* stamped with the current time.

    When the cache is at capacity, the entry with the oldest timestamp
    is evicted first (naive O(n) scan — acceptable at CACHE_MAX_ITEMS scale).
    """
    if len(_SUCCESS_CACHE) >= CACHE_MAX_ITEMS:
        oldest = min(_SUCCESS_CACHE.items(), key=lambda item: item[1][0])
        _SUCCESS_CACHE.pop(oldest[0], None)
    _SUCCESS_CACHE[key] = (time.time(), val)
661
 
662
 
663
  # ============================
664
+ # Completion (task-aware)
665
  # ============================
666
+ def _chat_completion(
 
667
  task: str,
668
  payload: str,
669
  light: bool,
 
671
  temperature: float,
672
  top_p: float,
673
  n_batch: int,
674
+ repeat_penalty: float,
675
  ) -> Dict[str, Any]:
676
  if llm is None:
677
  return {"ok": False, "error": "model_not_loaded", "detail": load_error}
678
 
679
+ key = (task, payload, light, max_new_tokens, temperature, top_p, n_batch, repeat_penalty)
680
+ cached = _cache_get(key)
681
+ if cached is not None:
682
+ return {"ok": True, "result": cached, "cached": True}
683
+
684
  try:
685
  llm.n_batch = int(n_batch) # type: ignore[attr-defined]
686
  except Exception:
 
703
  else:
704
  return {"ok": False, "error": "unknown_task"}
705
 
706
+ t0 = time.time()
707
  out = llm.create_chat_completion(
708
  messages=messages,
709
  max_tokens=int(max_new_tokens),
710
  temperature=float(temperature),
711
  top_p=float(top_p),
712
+ repeat_penalty=float(repeat_penalty),
713
+ stop=STOP_SEQS,
714
  stream=False,
715
  )
716
+ t1 = time.time()
717
 
718
  raw = out["choices"][0]["message"]["content"]
719
+ raw = _strip_sentinel(raw)
720
+
721
  obj = extract_first_json_obj(raw)
722
  if obj is None:
723
+ # attempt repair (close quote/braces) to avoid unusable responses
724
+ obj = try_repair_and_parse_json(raw)
725
 
726
+ if obj is None:
727
+ return {"ok": False, "error": "json_parse_error", "raw": raw, "gen_s": round(t1 - t0, 3)}
728
+
729
+ # success only: store in cache
730
+ _cache_put(key, obj)
731
+
732
+ return {"ok": True, "result": obj, "gen_s": round(t1 - t0, 3)}
733
 
734
 
735
  def _occurrence_index(text: str, sub: str, occurrence: int) -> int:
 
774
  params = pick_params(req)
775
  _log(
776
  rid,
777
+ f"βš™οΈ Params: max_new_tokens={params['max_new_tokens']} temp={params['temperature']} top_p={params['top_p']} n_batch={params['n_batch']} repeat_penalty={params['repeat_penalty']}",
778
  )
779
 
780
  payload = json.dumps({"text": req.text}, ensure_ascii=False)
781
 
782
  async with GEN_LOCK:
 
 
783
  _log(rid, "🧠 Generating analyze...")
784
+ res = _chat_completion(
 
785
  "analyze",
786
  payload,
787
  bool(req.light),
 
789
  float(params["temperature"]),
790
  float(params["top_p"]),
791
  int(params["n_batch"]),
792
+ float(params["repeat_penalty"]),
793
  )
 
794
 
795
  elapsed_total = time.time() - t0
 
796
 
797
  if not res.get("ok"):
798
  _log(rid, f"❌ /analyze failed: {res.get('error')}")
 
806
  "temperature": float(params["temperature"]),
807
  "top_p": float(params["top_p"]),
808
  "n_batch": int(params["n_batch"]),
809
+ "repeat_penalty": float(params["repeat_penalty"]),
810
  },
811
+ "timings_s": {"total": round(elapsed_total, 3), "gen": res.get("gen_s", None)},
812
  },
813
  }
814
 
 
815
  clean = sanitize_analyze_output(res["result"], req.text)
816
+ # ensure overall explanation is always a useful summary + risk
817
  clean["overall_explanation"] = generate_overall_explanation(clean)
818
 
819
  _log(rid, f"βœ… /analyze ok fallacies={len(clean.get('fallacies', []))} total={elapsed_total:.2f}s")
 
828
  "temperature": float(params["temperature"]),
829
  "top_p": float(params["top_p"]),
830
  "n_batch": int(params["n_batch"]),
831
+ "repeat_penalty": float(params["repeat_penalty"]),
832
  },
833
+ "timings_s": {"total": round(elapsed_total, 3), "gen": res.get("gen_s", None)},
 
 
 
 
834
  },
835
  }
836
 
 
853
  quote = req.quote.strip()
854
  occurrence = int(req.occurrence or 0)
855
 
 
856
  if _occurrence_index(req.text, quote, occurrence) == -1:
857
  return {"ok": False, "error": "quote_not_found", "detail": {"occurrence": occurrence}}
858
 
859
  params = pick_params(req)
 
 
860
  if req.light and req.max_new_tokens is None:
861
  params["max_new_tokens"] = max(params["max_new_tokens"], 80)
862
 
 
 
 
 
 
863
  payload = json.dumps(
864
  {
865
  "text": req.text,
 
871
  )
872
 
873
  async with GEN_LOCK:
 
 
874
  _log(rid, "🧠 Generating rewrite replacement_quote...")
875
+ res = _chat_completion(
 
876
  "rewrite",
877
  payload,
878
  bool(req.light),
 
880
  float(params["temperature"]),
881
  float(params["top_p"]),
882
  int(params["n_batch"]),
883
+ float(params["repeat_penalty"]),
884
  )
 
885
 
886
  elapsed_total = time.time() - t0
 
887
 
888
  if not res.get("ok"):
889
  _log(rid, f"❌ /rewrite failed: {res.get('error')}")
 
897
  "temperature": float(params["temperature"]),
898
  "top_p": float(params["top_p"]),
899
  "n_batch": int(params["n_batch"]),
900
+ "repeat_penalty": float(params["repeat_penalty"]),
901
  },
902
+ "timings_s": {"total": round(elapsed_total, 3), "gen": res.get("gen_s", None)},
903
  },
904
  }
905
 
 
916
  return {"ok": False, "error": "empty_replacement_quote", "raw": obj}
917
 
918
  why = obj.get("why_this_fix")
919
+ why = strip_template_sentence(why)
 
 
920
 
 
921
  rep = _replace_nth(req.text, quote, replacement, occurrence)
922
  if not rep.get("ok"):
923
  return {"ok": False, "error": rep.get("error", "replace_failed")}
 
946
  "temperature": float(params["temperature"]),
947
  "top_p": float(params["top_p"]),
948
  "n_batch": int(params["n_batch"]),
949
+ "repeat_penalty": float(params["repeat_penalty"]),
950
  },
951
+ "timings_s": {"total": round(elapsed_total, 3), "gen": res.get("gen_s", None)},
 
 
 
 
952
  },
953
  }