maxime-antoine-dev committed on
Commit
81e2856
·
1 Parent(s): afd3da3
Files changed (5) hide show
  1. logger_utils.py +0 -29
  2. main.py +538 -99
  3. model_runtime.py +0 -129
  4. prompts.py +0 -113
  5. utils.py +0 -171
logger_utils.py DELETED
@@ -1,29 +0,0 @@
1
- import time
2
- from contextlib import contextmanager
3
-
4
- def log(rid: str, msg: str) -> None:
5
- print(f"[{rid}] {msg}", flush=True)
6
-
7
- class StepLogger:
8
- """
9
- Lightweight structured step logger for server logs.
10
- """
11
- def __init__(self, rid: str, route: str):
12
- self.rid = rid
13
- self.route = route
14
-
15
- def info(self, message: str) -> None:
16
- log(self.rid, f"{self.route} {message}")
17
-
18
- @contextmanager
19
- def step(self, name: str):
20
- t0 = time.time()
21
- self.info(f"step={name} start")
22
- try:
23
- yield
24
- dt = time.time() - t0
25
- self.info(f"step={name} ok ({dt:.3f}s)")
26
- except Exception as e:
27
- dt = time.time() - t0
28
- self.info(f"step={name} fail ({dt:.3f}s) err={repr(e)}")
29
- raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main.py CHANGED
@@ -1,17 +1,17 @@
 
1
  import os
2
  import json
3
  import time
4
  import uuid
5
  import asyncio
6
- from typing import Any, Dict, Optional
 
7
 
8
  from fastapi import FastAPI
9
  from fastapi.middleware.cors import CORSMiddleware
10
  from pydantic import BaseModel, Field
11
-
12
- from logger_utils import StepLogger
13
- from utils import sanitize_analyze_output, occurrence_index, replace_nth, strip_template_sentence
14
- from model_runtime import load_llama, get_health, cached_chat_completion
15
 
16
 
17
  # ============================
@@ -20,27 +20,33 @@ from model_runtime import load_llama, get_health, cached_chat_completion
20
  GGUF_REPO_ID = os.getenv("GGUF_REPO_ID", "maxime-antoine-dev/fades-mistral-v02-gguf")
21
  GGUF_FILENAME = os.getenv("GGUF_FILENAME", "mistral_v02_fades.Q4_K_M.gguf")
22
 
 
23
  N_CTX = int(os.getenv("N_CTX", "1536"))
24
  CPU_COUNT = os.cpu_count() or 4
25
  N_THREADS = int(os.getenv("N_THREADS", str(min(8, max(1, CPU_COUNT - 1)))))
26
  N_BATCH = int(os.getenv("N_BATCH", "256"))
27
 
28
- MAX_NEW_TOKENS_DEFAULT = int(os.getenv("MAX_NEW_TOKENS", "300"))
 
29
  TEMPERATURE_DEFAULT = float(os.getenv("TEMPERATURE", "0.0"))
30
  TOP_P_DEFAULT = float(os.getenv("TOP_P", "0.95"))
31
 
 
32
  LIGHT_MAX_NEW_TOKENS = int(os.getenv("LIGHT_MAX_NEW_TOKENS", "60"))
33
  LIGHT_TEMPERATURE = float(os.getenv("LIGHT_TEMPERATURE", "0.0"))
34
  LIGHT_TOP_P = float(os.getenv("LIGHT_TOP_P", "0.9"))
 
 
35
  LIGHT_N_BATCH = int(os.getenv("LIGHT_N_BATCH", "64"))
36
 
 
37
  GEN_LOCK = asyncio.Lock()
38
 
39
  app = FastAPI(title="FADES Fallacy Detector (GGUF / llama.cpp)")
40
 
41
 
42
  # ============================
43
- # CORS
44
  # ============================
45
  _CORS_ORIGINS = os.getenv("CORS_ALLOW_ORIGINS", "*").strip()
46
  if _CORS_ORIGINS == "*" or not _CORS_ORIGINS:
@@ -61,14 +67,18 @@ app.add_middleware(
61
  # Schemas
62
  # ============================
63
  class GenParams(BaseModel):
 
64
  light: bool = False
 
65
  max_new_tokens: Optional[int] = None
66
  temperature: Optional[float] = None
67
  top_p: Optional[float] = None
68
 
 
69
  class AnalyzeRequest(GenParams):
70
  text: str
71
 
 
72
  class RewriteRequest(GenParams):
73
  text: str
74
  quote: str = Field(..., description="Verbatim substring that must be replaced.")
@@ -78,17 +88,236 @@ class RewriteRequest(GenParams):
78
 
79
 
80
  # ============================
81
- # Startup
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  # ============================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  @app.on_event("startup")
84
  def _startup():
85
- load_llama(
86
- gguf_repo_id=GGUF_REPO_ID,
87
- gguf_filename=GGUF_FILENAME,
88
- n_ctx=N_CTX,
89
- n_threads=N_THREADS,
90
- n_batch=N_BATCH,
91
- )
92
 
93
 
94
  @app.get("/")
@@ -98,17 +327,22 @@ def root():
98
 
99
  @app.get("/health")
100
  def health():
101
- return get_health(
102
- gguf_repo_id=GGUF_REPO_ID,
103
- gguf_filename=GGUF_FILENAME,
104
- n_ctx=N_CTX,
105
- n_threads=N_THREADS,
106
- n_batch=N_BATCH,
107
- )
 
 
 
 
 
108
 
109
 
110
  # ============================
111
- # Params selection
112
  # ============================
113
  def pick_params(req: GenParams) -> Dict[str, Any]:
114
  if req.light:
@@ -133,12 +367,209 @@ def pick_params(req: GenParams) -> Dict[str, Any]:
133
  if req.top_p is not None:
134
  params["top_p"] = float(req.top_p)
135
 
 
136
  params["max_new_tokens"] = max(1, min(int(params["max_new_tokens"]), 400))
137
  params["temperature"] = max(0.0, min(float(params["temperature"]), 1.5))
138
  params["top_p"] = max(0.05, min(float(params["top_p"]), 1.0))
139
  params["n_batch"] = max(16, min(int(params["n_batch"]), 512))
140
  return params
141
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
  # ============================
144
  # Routes
@@ -146,41 +577,42 @@ def pick_params(req: GenParams) -> Dict[str, Any]:
146
  @app.post("/analyze")
147
  async def analyze(req: AnalyzeRequest) -> Dict[str, Any]:
148
  rid = uuid.uuid4().hex[:10]
149
- L = StepLogger(rid, "/analyze")
150
  t0 = time.time()
151
 
152
- L.info(f"received light={req.light} chars={len(req.text) if req.text else 0}")
153
 
154
- with L.step("validate"):
155
- if not req.text or not req.text.strip():
156
- return {"ok": False, "error": "empty_text"}
157
 
158
- with L.step("pick_params"):
159
- params = pick_params(req)
 
 
 
160
 
161
  payload = json.dumps({"text": req.text}, ensure_ascii=False)
162
 
163
- with L.step("generate_under_lock"):
164
- async with GEN_LOCK:
165
- t_lock = time.time()
166
- t_gen0 = time.time()
167
-
168
- res = cached_chat_completion(
169
- "analyze",
170
- payload,
171
- int(params["max_new_tokens"]),
172
- float(params["temperature"]),
173
- float(params["top_p"]),
174
- int(params["n_batch"]),
175
- )
176
-
177
- t_gen1 = time.time()
178
- elapsed_lock = time.time() - t_lock
179
 
180
  elapsed_total = time.time() - t0
 
181
 
182
  if not res.get("ok"):
183
- L.info(f"failed err={res.get('error')}")
184
  return {
185
  **res,
186
  "meta": {
@@ -196,10 +628,10 @@ async def analyze(req: AnalyzeRequest) -> Dict[str, Any]:
196
  },
197
  }
198
 
199
- with L.step("sanitize"):
200
- clean = sanitize_analyze_output(res["result"], req.text)
201
 
202
- L.info(f"ok fallacies={len(clean.get('fallacies', []))} total={elapsed_total:.2f}s")
203
  return {
204
  "ok": True,
205
  "result": clean,
@@ -224,30 +656,35 @@ async def analyze(req: AnalyzeRequest) -> Dict[str, Any]:
224
  @app.post("/rewrite")
225
  async def rewrite(req: RewriteRequest) -> Dict[str, Any]:
226
  rid = uuid.uuid4().hex[:10]
227
- L = StepLogger(rid, "/rewrite")
228
  t0 = time.time()
229
 
230
- L.info(
231
- f"received light={req.light} text_chars={len(req.text) if req.text else 0} quote_chars={len(req.quote) if req.quote else 0}"
 
232
  )
233
 
234
- with L.step("validate"):
235
- if not req.text or not req.text.strip():
236
- return {"ok": False, "error": "empty_text"}
237
- if not req.quote or not req.quote.strip():
238
- return {"ok": False, "error": "empty_quote"}
239
 
240
  quote = req.quote.strip()
241
  occurrence = int(req.occurrence or 0)
242
 
243
- with L.step("quote_check"):
244
- if occurrence_index(req.text, quote, occurrence) == -1:
245
- return {"ok": False, "error": "quote_not_found", "detail": {"occurrence": occurrence}}
246
 
247
- with L.step("pick_params"):
248
- params = pick_params(req)
249
- if req.light and req.max_new_tokens is None:
250
- params["max_new_tokens"] = max(params["max_new_tokens"], 80)
 
 
 
 
 
 
251
 
252
  payload = json.dumps(
253
  {
@@ -259,27 +696,27 @@ async def rewrite(req: RewriteRequest) -> Dict[str, Any]:
259
  ensure_ascii=False,
260
  )
261
 
262
- with L.step("generate_under_lock"):
263
- async with GEN_LOCK:
264
- t_lock = time.time()
265
- t_gen0 = time.time()
266
-
267
- res = cached_chat_completion(
268
- "rewrite",
269
- payload,
270
- int(params["max_new_tokens"]),
271
- float(params["temperature"]),
272
- float(params["top_p"]),
273
- int(params["n_batch"]),
274
- )
275
-
276
- t_gen1 = time.time()
277
- elapsed_lock = time.time() - t_lock
278
 
279
  elapsed_total = time.time() - t0
 
280
 
281
  if not res.get("ok"):
282
- L.info(f"failed err={res.get('error')}")
283
  return {
284
  **res,
285
  "meta": {
@@ -295,27 +732,29 @@ async def rewrite(req: RewriteRequest) -> Dict[str, Any]:
295
  },
296
  }
297
 
298
- with L.step("validate_model_output"):
299
- obj = res["result"]
300
- if not isinstance(obj, dict):
301
- return {"ok": False, "error": "bad_rewrite_output"}
 
 
 
302
 
303
- replacement = obj.get("replacement_quote")
304
- if not isinstance(replacement, str):
305
- return {"ok": False, "error": "missing_replacement_quote", "raw": obj}
306
- replacement = replacement.strip()
307
- if not replacement:
308
- return {"ok": False, "error": "empty_replacement_quote", "raw": obj}
309
 
310
- why = obj.get("why_this_fix", "")
311
- why = strip_template_sentence(str(why).strip())
 
 
312
 
313
- with L.step("replace"):
314
- rep = replace_nth(req.text, quote, replacement, occurrence)
315
- if not rep.get("ok"):
316
- return {"ok": False, "error": rep.get("error", "replace_failed")}
317
 
318
- L.info(f"ok total={elapsed_total:.2f}s")
319
  return {
320
  "ok": True,
321
  "result": {
 
1
+ # main.py
2
  import os
3
  import json
4
  import time
5
  import uuid
6
  import asyncio
7
+ from typing import Any, Dict, Optional, List
8
+ from functools import lru_cache
9
 
10
  from fastapi import FastAPI
11
  from fastapi.middleware.cors import CORSMiddleware
12
  from pydantic import BaseModel, Field
13
+ from huggingface_hub import hf_hub_download
14
+ from llama_cpp import Llama
 
 
15
 
16
 
17
  # ============================
 
20
  GGUF_REPO_ID = os.getenv("GGUF_REPO_ID", "maxime-antoine-dev/fades-mistral-v02-gguf")
21
  GGUF_FILENAME = os.getenv("GGUF_FILENAME", "mistral_v02_fades.Q4_K_M.gguf")
22
 
23
+ # Model load params (fixed once at startup)
24
  N_CTX = int(os.getenv("N_CTX", "1536"))
25
  CPU_COUNT = os.cpu_count() or 4
26
  N_THREADS = int(os.getenv("N_THREADS", str(min(8, max(1, CPU_COUNT - 1)))))
27
  N_BATCH = int(os.getenv("N_BATCH", "256"))
28
 
29
+ # Default generation params ("normal")
30
+ MAX_NEW_TOKENS_DEFAULT = int(os.getenv("MAX_NEW_TOKENS", "180"))
31
  TEMPERATURE_DEFAULT = float(os.getenv("TEMPERATURE", "0.0"))
32
  TOP_P_DEFAULT = float(os.getenv("TOP_P", "0.95"))
33
 
34
+ # "Light" generation params
35
  LIGHT_MAX_NEW_TOKENS = int(os.getenv("LIGHT_MAX_NEW_TOKENS", "60"))
36
  LIGHT_TEMPERATURE = float(os.getenv("LIGHT_TEMPERATURE", "0.0"))
37
  LIGHT_TOP_P = float(os.getenv("LIGHT_TOP_P", "0.9"))
38
+
39
+ # "Light" runtime knobs
40
  LIGHT_N_BATCH = int(os.getenv("LIGHT_N_BATCH", "64"))
41
 
42
+ # One request at a time on CPU
43
  GEN_LOCK = asyncio.Lock()
44
 
45
  app = FastAPI(title="FADES Fallacy Detector (GGUF / llama.cpp)")
46
 
47
 
48
  # ============================
49
+ # CORS (for browser front-ends)
50
  # ============================
51
  _CORS_ORIGINS = os.getenv("CORS_ALLOW_ORIGINS", "*").strip()
52
  if _CORS_ORIGINS == "*" or not _CORS_ORIGINS:
 
67
  # Schemas
68
  # ============================
69
  class GenParams(BaseModel):
70
+ # if True => use "light" parameters
71
  light: bool = False
72
+ # optional overrides (applied after picking light/normal defaults)
73
  max_new_tokens: Optional[int] = None
74
  temperature: Optional[float] = None
75
  top_p: Optional[float] = None
76
 
77
+
78
  class AnalyzeRequest(GenParams):
79
  text: str
80
 
81
+
82
  class RewriteRequest(GenParams):
83
  text: str
84
  quote: str = Field(..., description="Verbatim substring that must be replaced.")
 
88
 
89
 
90
  # ============================
91
+ # Labels & Prompts
92
+ # ============================
93
+ ALLOWED_LABELS = [
94
+ "none",
95
+ "faulty generalization",
96
+ "false causality",
97
+ "circular reasoning",
98
+ "ad populum",
99
+ "ad hominem",
100
+ "fallacy of logic",
101
+ "appeal to emotion",
102
+ "false dilemma",
103
+ "equivocation",
104
+ "fallacy of extension",
105
+ "fallacy of relevance",
106
+ "fallacy of credibility",
107
+ "miscellaneous",
108
+ "intentional",
109
+ ]
110
+
111
+ LABELS_STR = ", ".join([f'"{x}"' for x in ALLOWED_LABELS])
112
+
113
+ # Stronger /analyze prompt: forces specificity and forbids the "template" sentence
114
+ ANALYZE_PROMPT = f"""You are a fallacy detection assistant.
115
+
116
+ You MUST choose labels ONLY from this list (exact string):
117
+ {LABELS_STR}
118
+
119
+ You MUST return ONLY valid JSON with this schema:
120
+ {{
121
+ "has_fallacy": boolean,
122
+ "fallacies": [
123
+ {{
124
+ "type": string,
125
+ "confidence": number,
126
+ "evidence_quotes": [string],
127
+ "rationale": string
128
+ }}
129
+ ],
130
+ "overall_explanation": string
131
+ }}
132
+
133
+ Hard rules:
134
+ - Output ONLY JSON. No markdown. No extra text.
135
+ - evidence_quotes MUST be verbatim substrings copied from the input text (no paraphrase).
136
+ - Keep each evidence quote short (prefer 1–2 sentences; max 240 chars).
137
+ - confidence MUST be a real probability between 0.0 and 1.0 (use 2 decimals).
138
+ It MUST NOT be always the same across examples. Calibrate it:
139
+ * 0.90–1.00: very explicit, unambiguous match, clear cue words.
140
+ * 0.70–0.89: strong match but some ambiguity or missing premise.
141
+ * 0.40–0.69: plausible but weak/partial evidence.
142
+ * 0.10–0.39: very uncertain.
143
+ - The rationale MUST be specific to the evidence (2–4 sentences):
144
+ Explain (1) what the quote claims, (2) why that matches the fallacy label,
145
+ (3) what logical step is invalid or missing.
146
+ DO NOT use generic filler. Do NOT reuse stock phrases.
147
+ In particular, you MUST NOT output this sentence:
148
+ "The input contains fallacious reasoning consistent with the predicted type(s)."
149
+ - overall_explanation MUST also be specific (2–5 sentences): summarize the reasoning issues and reference the key cue(s).
150
+ - If no fallacy: has_fallacy=false and fallacies=[] and overall_explanation explains briefly why.
151
+
152
+ INPUT:
153
+ {{text}}
154
+
155
+ OUTPUT:"""
156
+
157
+ # /rewrite prompt: returns ONLY a replacement substring for the quote (server does the replacement)
158
+ REWRITE_PROMPT = """You are rewriting a small quoted span inside a larger text.
159
+
160
+ Goal:
161
+ - You MUST propose a replacement for the QUOTE only.
162
+ - The replacement should remove the fallacious reasoning described, while keeping the same tone/style/tense/entities.
163
+ - The replacement MUST be plausible in the surrounding context and should be similar length (roughly +/- 40%).
164
+ - Do NOT change anything outside the quote. Do NOT add new facts not implied by the original.
165
+ - Do NOT introduce new fallacies.
166
+
167
+ Return ONLY valid JSON with this schema:
168
+ {
169
+ "replacement_quote": string,
170
+ "why_this_fix": string
171
+ }
172
+
173
+ Hard rules:
174
+ - Output ONLY JSON. No markdown. No extra text.
175
+ - replacement_quote should be standalone text (no surrounding quotes).
176
+ - why_this_fix: 1–3 sentences, specific.
177
+
178
+ INPUT_TEXT:
179
+ {text}
180
+
181
+ QUOTE_TO_REWRITE:
182
+ {quote}
183
+
184
+ FALLACY_TYPE:
185
+ {fallacy_type}
186
+
187
+ WHY_FALLACIOUS:
188
+ {rationale}
189
+
190
+ OUTPUT:"""
191
+
192
+
193
def build_analyze_messages(text: str) -> List[Dict[str, str]]:
    """Build the two-message chat prompt for the /analyze task.

    Substitutes the input via str.replace rather than str.format because
    ANALYZE_PROMPT contains literal JSON braces (emitted as {{ }} in the
    f-string above) that .format() would try to parse as fields.
    """
    return [
        {"role": "system", "content": "Return only JSON. Exactly one JSON object. No extra text."},
        {"role": "user", "content": ANALYZE_PROMPT.replace("{text}", text)},
    ]
198
+
199
+
200
def build_rewrite_messages(text: str, quote: str, fallacy_type: str, rationale: str) -> List[Dict[str, str]]:
    """Build the two-message chat prompt for the /rewrite task.

    Bug fix: REWRITE_PROMPT contains literal, un-doubled JSON braces in its
    schema section ('{' ... '}'), so calling REWRITE_PROMPT.format(...) makes
    str.format parse that brace pair as a replacement field and raise
    KeyError at runtime, crashing every /rewrite request. Substitute each
    placeholder with str.replace instead, mirroring build_analyze_messages.

    NOTE: replacements happen sequentially, so a literal "{quote}" inside
    *text* would also be substituted — acceptable for this prompt format.
    """
    prompt = (
        REWRITE_PROMPT
        .replace("{text}", text)
        .replace("{quote}", quote)
        .replace("{fallacy_type}", fallacy_type)
        .replace("{rationale}", rationale)
    )
    return [
        {"role": "system", "content": "Return only JSON. Exactly one JSON object. No extra text."},
        {"role": "user", "content": prompt},
    ]
211
+
212
+
213
  # ============================
214
+ # Logging
215
+ # ============================
216
def _log(rid: str, msg: str) -> None:
    """Print one request-scoped log line tagged with the short request id.

    flush=True so the line appears immediately in container/Space logs.
    """
    print(f"[{rid}] {msg}", flush=True)
218
+
219
+
220
+ # ============================
221
+ # Robust JSON extraction
222
+ # ============================
223
def stop_at_complete_json(text: str) -> Optional[str]:
    """Return the first brace-balanced candidate JSON object in *text*.

    Scans from the first '{' while tracking string/escape state so braces
    inside string literals do not affect the depth count. Returns None when
    no '{' exists or the object never closes.
    """
    first = text.find("{")
    if first < 0:
        return None

    depth = 0
    inside_string = False
    escaped = False

    for pos in range(first, len(text)):
        ch = text[pos]
        if inside_string:
            # Inside a string literal: only escape handling and the closing
            # quote matter; braces here are plain characters.
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                inside_string = False
        elif ch == '"':
            inside_string = True
        elif ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return text[first : pos + 1]

    # Ran off the end without balancing the braces.
    return None


def extract_first_json_obj(s: str) -> Optional[Dict[str, Any]]:
    """Parse and return the first JSON object found in *s*, or None.

    Prefers the brace-balanced span from stop_at_complete_json; falls back
    to the outermost '{' ... '}' slice of the raw string, then attempts a
    single json.loads. Any parse failure yields None rather than raising.
    """
    window = stop_at_complete_json(s) or s
    lo = window.find("{")
    hi = window.rfind("}")
    if lo == -1 or hi <= lo:
        return None
    candidate = window[lo : hi + 1].strip()
    try:
        return json.loads(candidate)
    except Exception:
        return None
266
+
267
+
268
+ # ============================
269
+ # Model load
270
+ # ============================
271
+ llm: Optional[Llama] = None
272
+ model_path: Optional[str] = None
273
+ load_error: Optional[str] = None
274
+ loaded_at_ts: Optional[float] = None
275
+
276
+
277
def load_llama() -> None:
    """Download the GGUF file from the Hub and load it with llama.cpp.

    Sets the module globals llm / model_path / loaded_at_ts on success, or
    load_error (repr of the exception) on failure; /health reports these.
    Never raises: failures are caught and logged so the server still boots.

    NOTE(review): on failure, `llm` is NOT reset to None (the deleted
    model_runtime.py version did reset it) — on a failed *re*load a stale
    model would remain live. Confirm whether that is intended.
    """
    global llm, model_path, load_error, loaded_at_ts

    print("=== FADES startup ===", flush=True)
    print(f"GGUF_REPO_ID={GGUF_REPO_ID}", flush=True)
    print(f"GGUF_FILENAME={GGUF_FILENAME}", flush=True)
    print(f"N_CTX={N_CTX} N_THREADS={N_THREADS} N_BATCH={N_BATCH}", flush=True)

    try:
        t0 = time.time()
        # Download (or reuse the local HF cache copy of) the GGUF weights.
        mp = hf_hub_download(
            repo_id=GGUF_REPO_ID,
            filename=GGUF_FILENAME,
            token=os.getenv("HF_TOKEN"),
        )
        t1 = time.time()
        print(f"✅ GGUF downloaded: {mp} ({t1 - t0:.1f}s)", flush=True)

        t2 = time.time()
        # CPU-only inference: n_gpu_layers=0; verbose=False keeps logs clean.
        llm_local = Llama(
            model_path=mp,
            n_ctx=N_CTX,
            n_threads=N_THREADS,
            n_batch=N_BATCH,
            n_gpu_layers=0,
            verbose=False,
        )
        t3 = time.time()
        print(f"✅ Model loaded: ({t3 - t2:.1f}s) n_ctx={N_CTX} threads={N_THREADS} batch={N_BATCH}", flush=True)

        # Publish globals only after the model constructed successfully.
        llm = llm_local
        model_path = mp
        load_error = None
        loaded_at_ts = time.time()
        print("=== Startup OK ===", flush=True)

    except Exception as e:
        load_error = repr(e)
        print(f"❌ Startup FAILED: {load_error}", flush=True)
316
+
317
+
318
@app.on_event("startup")
def _startup():
    """FastAPI startup hook: download and load the GGUF model once."""
    # NOTE(review): @app.on_event is deprecated in recent FastAPI releases in
    # favor of lifespan handlers — confirm the pinned FastAPI version.
    load_llama()
 
 
 
 
 
 
321
 
322
 
323
  @app.get("/")
 
327
 
328
@app.get("/health")
def health():
    """Readiness probe: model/load status plus the effective runtime config.

    "ok" is True only when the model object exists and the last load attempt
    recorded no error.
    """
    model_ready = llm is not None
    return dict(
        ok=model_ready and load_error is None,
        model_loaded=model_ready,
        load_error=load_error,
        gguf_repo=GGUF_REPO_ID,
        gguf_filename=GGUF_FILENAME,
        model_path=model_path,
        n_ctx=N_CTX,
        n_threads=N_THREADS,
        n_batch=N_BATCH,
        loaded_at_ts=loaded_at_ts,
    )
342
 
343
 
344
  # ============================
345
+ # Param selection
346
  # ============================
347
  def pick_params(req: GenParams) -> Dict[str, Any]:
348
  if req.light:
 
367
  if req.top_p is not None:
368
  params["top_p"] = float(req.top_p)
369
 
370
+ # Safety caps
371
  params["max_new_tokens"] = max(1, min(int(params["max_new_tokens"]), 400))
372
  params["temperature"] = max(0.0, min(float(params["temperature"]), 1.5))
373
  params["top_p"] = max(0.05, min(float(params["top_p"]), 1.0))
374
  params["n_batch"] = max(16, min(int(params["n_batch"]), 512))
375
  return params
376
 
377
+ # # ============================
378
+ # # Post-processing: remove template sentence
379
+ # # ============================
380
+ # # This catches the exact sentence + small punctuation variations (case-insensitive).
381
+ # # Also works if the model prefixes rationales with it.
382
+ # _TEMPLATE_RE = re.compile(
383
+ # r"\bthe input contains fallacious reasoning consistent with the predicted type\(s\)\b\.?",
384
+ # flags=re.IGNORECASE,
385
+ # )
386
+
387
+ # def strip_template_sentence(text: str) -> str:
388
+ # if not isinstance(text, str):
389
+ # return ""
390
+ # out = _TEMPLATE_RE.sub("", text)
391
+
392
+ # # Cleanup common leftovers (double spaces, leading punctuation)
393
+ # out = out.replace("..", ".").strip()
394
+ # out = re.sub(r"\s{2,}", " ", out)
395
+ # out = re.sub(r"^\s*[\-–—:;,\.\s]+", "", out).strip()
396
+ # return out
397
+
398
+
399
+
400
+ # ============================
401
+ # Output sanitation / validation
402
+ # ============================
403
+ def _clamp01(x: Any, default: float = 0.5) -> float:
404
+ try:
405
+ v = float(x)
406
+ except Exception:
407
+ return default
408
+ if v < 0.0:
409
+ return 0.0
410
+ if v > 1.0:
411
+ return 1.0
412
+ return v
413
+
414
+
415
def _is_allowed_label(lbl: Any) -> bool:
    """True iff *lbl* is a string from ALLOWED_LABELS other than "none"."""
    if not isinstance(lbl, str):
        return False
    return lbl != "none" and lbl in ALLOWED_LABELS
417
+
418
+
419
def sanitize_analyze_output(obj: Dict[str, Any], input_text: str) -> Dict[str, Any]:
    """
    Normalize the model's /analyze JSON before returning it to clients:
    enforce the response shape, clamp confidence into [0, 1] (2 decimals),
    drop fallacy entries whose label is not in ALLOWED_LABELS (or is "none"),
    and keep only evidence quotes that are verbatim substrings of the input.
    If no fallacy entry survives, has_fallacy is forced to False.
    """
    has_fallacy = bool(obj.get("has_fallacy", False))
    fallacies_in = obj.get("fallacies", [])
    if not isinstance(fallacies_in, list):
        fallacies_in = []

    fallacies_out = []
    for f in fallacies_in:
        # Skip malformed entries entirely rather than guessing at fields.
        if not isinstance(f, dict):
            continue
        f_type = f.get("type")
        if not _is_allowed_label(f_type):
            continue

        conf = _clamp01(f.get("confidence", 0.5))
        # keep 2 decimals for nicer UI
        conf = float(f"{conf:.2f}")

        ev = f.get("evidence_quotes", [])
        if not isinstance(ev, list):
            ev = []
        ev_clean: List[str] = []
        for q in ev:
            if not isinstance(q, str):
                continue
            qq = q.strip()
            if not qq:
                continue
            # Evidence MUST be a verbatim substring of the input; quotes that
            # are not are silently dropped (no else branch).
            if qq in input_text:
                # keep short, but don't hard-cut if it breaks substring matching
                if len(qq) <= 240:
                    ev_clean.append(qq)
                else:
                    # if too long, try to keep first 240 if still substring (rare); else keep as-is
                    # NOTE(review): since qq is a substring of input_text, any
                    # prefix of it is too — the else branch below is unreachable.
                    short = qq[:240]
                    if short in input_text:
                        ev_clean.append(short)
                    else:
                        ev_clean.append(qq)

        rationale = f.get("rationale")
        if not isinstance(rationale, str):
            rationale = ""
        rationale = rationale.strip()

        fallacies_out.append(
            {
                "type": f_type,
                "confidence": conf,
                # At most 3 evidence quotes per fallacy entry.
                "evidence_quotes": ev_clean[:3],
                "rationale": rationale,
            }
        )

    overall = obj.get("overall_explanation")
    if not isinstance(overall, str):
        overall = ""
    overall = overall.strip()

    # If no fallacies survived sanitation, force no-fallacy state
    if len(fallacies_out) == 0:
        has_fallacy = False

    return {
        "has_fallacy": has_fallacy,
        "fallacies": fallacies_out,
        "overall_explanation": overall,
    }
492
+
493
+
494
+ # ============================
495
+ # Cached generation (task-aware)
496
+ # ============================
497
@lru_cache(maxsize=512)
def _generate_chat_completion(
    task: str,
    payload: str,
    light: bool,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    n_batch: int,
) -> Dict[str, Any]:
    """Run one chat completion and parse its JSON output (memoized).

    Safe to cache because generation is deterministic at the defaults used
    here; all parameters participate in the cache key. Only called once the
    model is known to be loaded (see _cached_chat_completion).
    """
    try:
        # Best-effort runtime knob; llama-cpp-python may not expose it.
        llm.n_batch = int(n_batch)  # type: ignore[attr-defined]
    except Exception:
        pass

    try:
        data = json.loads(payload)
    except Exception:
        return {"ok": False, "error": "bad_payload"}

    if task == "analyze":
        messages = build_analyze_messages(data["text"])
    elif task == "rewrite":
        messages = build_rewrite_messages(
            data["text"],
            data["quote"],
            data["fallacy_type"],
            data["rationale"],
        )
    else:
        return {"ok": False, "error": "unknown_task"}

    out = llm.create_chat_completion(
        messages=messages,
        max_tokens=int(max_new_tokens),
        temperature=float(temperature),
        top_p=float(top_p),
        stream=False,
    )

    raw = out["choices"][0]["message"]["content"]
    obj = extract_first_json_obj(raw)
    if obj is None:
        return {"ok": False, "error": "json_parse_error", "raw": raw}

    return {"ok": True, "result": obj}


def _cached_chat_completion(
    task: str,
    payload: str,
    light: bool,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    n_batch: int,
) -> Dict[str, Any]:
    """Task-aware cached generation entry point (same interface as before).

    Bug fix: the previous version performed the `llm is None` check *inside*
    the lru_cache'd function, so a request that arrived before the model
    finished loading permanently cached a "model_not_loaded" failure for its
    payload and kept replaying it after startup completed. The readiness
    check now lives outside the cache.
    """
    if llm is None:
        return {"ok": False, "error": "model_not_loaded", "detail": load_error}
    return _generate_chat_completion(
        task,
        payload,
        bool(light),
        int(max_new_tokens),
        float(temperature),
        float(top_p),
        int(n_batch),
    )
546
+
547
+
548
+ def _occurrence_index(text: str, sub: str, occurrence: int) -> int:
549
+ if occurrence < 0:
550
+ return -1
551
+ start = 0
552
+ for _ in range(occurrence + 1):
553
+ idx = text.find(sub, start)
554
+ if idx == -1:
555
+ return -1
556
+ start = idx + max(1, len(sub))
557
+ return idx
558
+
559
+
560
+ def _replace_nth(text: str, old: str, new: str, occurrence: int) -> Dict[str, Any]:
561
+ idx = _occurrence_index(text, old, occurrence)
562
+ if idx == -1:
563
+ return {"ok": False, "error": "quote_not_found"}
564
+ return {
565
+ "ok": True,
566
+ "rewritten_text": text[:idx] + new + text[idx + len(old) :],
567
+ "start_char": idx,
568
+ "end_char": idx + len(new),
569
+ "old_start_char": idx,
570
+ "old_end_char": idx + len(old),
571
+ }
572
+
573
 
574
  # ============================
575
  # Routes
 
577
  @app.post("/analyze")
578
  async def analyze(req: AnalyzeRequest) -> Dict[str, Any]:
579
  rid = uuid.uuid4().hex[:10]
 
580
  t0 = time.time()
581
 
582
+ _log(rid, f"📩 /analyze received (light={req.light}) chars={len(req.text) if req.text else 0}")
583
 
584
+ if not req.text or not req.text.strip():
585
+ return {"ok": False, "error": "empty_text"}
 
586
 
587
+ params = pick_params(req)
588
+ _log(
589
+ rid,
590
+ f"⚙️ Params: max_new_tokens={params['max_new_tokens']} temp={params['temperature']} top_p={params['top_p']} n_batch={params['n_batch']}",
591
+ )
592
 
593
  payload = json.dumps({"text": req.text}, ensure_ascii=False)
594
 
595
+ async with GEN_LOCK:
596
+ t_lock = time.time()
597
+
598
+ _log(rid, "🧠 Generating analyze...")
599
+ t_gen0 = time.time()
600
+ res = _cached_chat_completion(
601
+ "analyze",
602
+ payload,
603
+ bool(req.light),
604
+ int(params["max_new_tokens"]),
605
+ float(params["temperature"]),
606
+ float(params["top_p"]),
607
+ int(params["n_batch"]),
608
+ )
609
+ t_gen1 = time.time()
 
610
 
611
  elapsed_total = time.time() - t0
612
+ elapsed_lock = time.time() - t_lock
613
 
614
  if not res.get("ok"):
615
+ _log(rid, f"❌ /analyze failed: {res.get('error')}")
616
  return {
617
  **res,
618
  "meta": {
 
628
  },
629
  }
630
 
631
+ # sanitize output for stability (substrings, labels, confidence clamp)
632
+ clean = sanitize_analyze_output(res["result"], req.text)
633
 
634
+ _log(rid, f"✅ /analyze ok fallacies={len(clean.get('fallacies', []))} total={elapsed_total:.2f}s")
635
  return {
636
  "ok": True,
637
  "result": clean,
 
656
  @app.post("/rewrite")
657
  async def rewrite(req: RewriteRequest) -> Dict[str, Any]:
658
  rid = uuid.uuid4().hex[:10]
 
659
  t0 = time.time()
660
 
661
+ _log(
662
+ rid,
663
+ f"📩 /rewrite received (light={req.light}) text_chars={len(req.text) if req.text else 0} quote_chars={len(req.quote) if req.quote else 0}",
664
  )
665
 
666
+ if not req.text or not req.text.strip():
667
+ return {"ok": False, "error": "empty_text"}
668
+ if not req.quote or not req.quote.strip():
669
+ return {"ok": False, "error": "empty_quote"}
 
670
 
671
  quote = req.quote.strip()
672
  occurrence = int(req.occurrence or 0)
673
 
674
+ # validate quote existence early
675
+ if _occurrence_index(req.text, quote, occurrence) == -1:
676
+ return {"ok": False, "error": "quote_not_found", "detail": {"occurrence": occurrence}}
677
 
678
+ params = pick_params(req)
679
+ # rewrite generally needs a bit more room than light analyze if you want fluent replacements
680
+ # (still controllable by request overrides)
681
+ if req.light and req.max_new_tokens is None:
682
+ params["max_new_tokens"] = max(params["max_new_tokens"], 80)
683
+
684
+ _log(
685
+ rid,
686
+ f"⚙️ Params: max_new_tokens={params['max_new_tokens']} temp={params['temperature']} top_p={params['top_p']} n_batch={params['n_batch']}",
687
+ )
688
 
689
  payload = json.dumps(
690
  {
 
696
  ensure_ascii=False,
697
  )
698
 
699
+ async with GEN_LOCK:
700
+ t_lock = time.time()
701
+
702
+ _log(rid, "🧠 Generating rewrite replacement_quote...")
703
+ t_gen0 = time.time()
704
+ res = _cached_chat_completion(
705
+ "rewrite",
706
+ payload,
707
+ bool(req.light),
708
+ int(params["max_new_tokens"]),
709
+ float(params["temperature"]),
710
+ float(params["top_p"]),
711
+ int(params["n_batch"]),
712
+ )
713
+ t_gen1 = time.time()
 
714
 
715
  elapsed_total = time.time() - t0
716
+ elapsed_lock = time.time() - t_lock
717
 
718
  if not res.get("ok"):
719
+ _log(rid, f"❌ /rewrite failed: {res.get('error')}")
720
  return {
721
  **res,
722
  "meta": {
 
732
  },
733
  }
734
 
735
+ obj = res["result"]
736
+ if not isinstance(obj, dict):
737
+ return {"ok": False, "error": "bad_rewrite_output"}
738
+
739
+ replacement = obj.get("replacement_quote")
740
+ if not isinstance(replacement, str):
741
+ return {"ok": False, "error": "missing_replacement_quote", "raw": obj}
742
 
743
+ replacement = replacement.strip()
744
+ if not replacement:
745
+ return {"ok": False, "error": "empty_replacement_quote", "raw": obj}
 
 
 
746
 
747
+ why = obj.get("why_this_fix")
748
+ if not isinstance(why, str):
749
+ why = ""
750
+ why = why.strip()
751
 
752
+ # server-side enforced: ONLY the quote is changed
753
+ rep = _replace_nth(req.text, quote, replacement, occurrence)
754
+ if not rep.get("ok"):
755
+ return {"ok": False, "error": rep.get("error", "replace_failed")}
756
 
757
+ _log(rid, f"✅ /rewrite ok total={elapsed_total:.2f}s")
758
  return {
759
  "ok": True,
760
  "result": {
model_runtime.py DELETED
@@ -1,129 +0,0 @@
1
- import os
2
- import time
3
- import json
4
- from functools import lru_cache
5
- from typing import Any, Dict, Optional
6
-
7
- from huggingface_hub import hf_hub_download
8
- from llama_cpp import Llama
9
-
10
- from prompts import build_analyze_messages, build_rewrite_messages
11
- from utils import extract_first_json_obj
12
-
13
- llm: Optional[Llama] = None
14
- model_path: Optional[str] = None
15
- load_error: Optional[str] = None
16
- loaded_at_ts: Optional[float] = None
17
-
18
- def load_llama(
19
- gguf_repo_id: str,
20
- gguf_filename: str,
21
- n_ctx: int,
22
- n_threads: int,
23
- n_batch: int,
24
- ) -> None:
25
- global llm, model_path, load_error, loaded_at_ts
26
-
27
- print("=== FADES startup ===", flush=True)
28
- print(f"GGUF_REPO_ID={gguf_repo_id}", flush=True)
29
- print(f"GGUF_FILENAME={gguf_filename}", flush=True)
30
- print(f"N_CTX={n_ctx} N_THREADS={n_threads} N_BATCH={n_batch}", flush=True)
31
-
32
- try:
33
- t0 = time.time()
34
- mp = hf_hub_download(
35
- repo_id=gguf_repo_id,
36
- filename=gguf_filename,
37
- token=os.getenv("HF_TOKEN"),
38
- )
39
- t1 = time.time()
40
- print(f"✅ GGUF downloaded: {mp} ({t1 - t0:.1f}s)", flush=True)
41
-
42
- t2 = time.time()
43
- llm_local = Llama(
44
- model_path=mp,
45
- n_ctx=n_ctx,
46
- n_threads=n_threads,
47
- n_batch=n_batch,
48
- n_gpu_layers=0,
49
- verbose=False,
50
- )
51
- t3 = time.time()
52
- print(f"✅ Model loaded: ({t3 - t2:.1f}s) n_ctx={n_ctx} threads={n_threads} batch={n_batch}", flush=True)
53
-
54
- llm = llm_local
55
- model_path = mp
56
- load_error = None
57
- loaded_at_ts = time.time()
58
- print("=== Startup OK ===", flush=True)
59
- except Exception as e:
60
- load_error = repr(e)
61
- llm = None
62
- print(f"❌ Startup FAILED: {load_error}", flush=True)
63
-
64
- def get_health(gguf_repo_id: str, gguf_filename: str, n_ctx: int, n_threads: int, n_batch: int) -> Dict[str, Any]:
65
- return {
66
- "ok": llm is not None and load_error is None,
67
- "model_loaded": llm is not None,
68
- "load_error": load_error,
69
- "gguf_repo": gguf_repo_id,
70
- "gguf_filename": gguf_filename,
71
- "model_path": model_path,
72
- "n_ctx": n_ctx,
73
- "n_threads": n_threads,
74
- "n_batch": n_batch,
75
- "loaded_at_ts": loaded_at_ts,
76
- }
77
-
78
- @lru_cache(maxsize=512)
79
- def cached_chat_completion(
80
- task: str,
81
- payload: str,
82
- max_new_tokens: int,
83
- temperature: float,
84
- top_p: float,
85
- n_batch: int,
86
- ) -> Dict[str, Any]:
87
- """
88
- Cached llama chat completion.
89
- NOTE: GEN_LOCK is managed by FastAPI routes (outside).
90
- """
91
- if llm is None:
92
- return {"ok": False, "error": "model_not_loaded", "detail": load_error}
93
-
94
- try:
95
- llm.n_batch = int(n_batch) # type: ignore[attr-defined]
96
- except Exception:
97
- pass
98
-
99
- try:
100
- data = json.loads(payload)
101
- except Exception:
102
- return {"ok": False, "error": "bad_payload"}
103
-
104
- if task == "analyze":
105
- messages = build_analyze_messages(data["text"])
106
- elif task == "rewrite":
107
- messages = build_rewrite_messages(
108
- data["text"],
109
- data["quote"],
110
- data["fallacy_type"],
111
- data["rationale"],
112
- )
113
- else:
114
- return {"ok": False, "error": "unknown_task"}
115
-
116
- out = llm.create_chat_completion(
117
- messages=messages,
118
- max_tokens=int(max_new_tokens),
119
- temperature=float(temperature),
120
- top_p=float(top_p),
121
- stream=False,
122
- )
123
-
124
- raw = out["choices"][0]["message"]["content"]
125
- obj = extract_first_json_obj(raw)
126
- if obj is None:
127
- return {"ok": False, "error": "json_parse_error", "raw": raw}
128
-
129
- return {"ok": True, "result": obj}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
prompts.py DELETED
@@ -1,113 +0,0 @@
1
- from typing import Dict, List
2
-
3
- ALLOWED_LABELS = [
4
- "none",
5
- "faulty generalization",
6
- "false causality",
7
- "circular reasoning",
8
- "ad populum",
9
- "ad hominem",
10
- "fallacy of logic",
11
- "appeal to emotion",
12
- "false dilemma",
13
- "equivocation",
14
- "fallacy of extension",
15
- "fallacy of relevance",
16
- "fallacy of credibility",
17
- "miscellaneous",
18
- "intentional",
19
- ]
20
-
21
- LABELS_STR = ", ".join([f'"{x}"' for x in ALLOWED_LABELS])
22
-
23
- # Stronger /analyze prompt: forces specificity and forbids the "template" sentence
24
- ANALYZE_PROMPT = f"""You are a fallacy detection assistant.
25
-
26
- You MUST choose labels ONLY from this list (exact string):
27
- {LABELS_STR}
28
-
29
- You MUST return ONLY valid JSON with this schema:
30
- {{
31
- "has_fallacy": boolean,
32
- "fallacies": [
33
- {{
34
- "type": string,
35
- "confidence": number,
36
- "evidence_quotes": [string],
37
- "rationale": string
38
- }}
39
- ],
40
- "overall_explanation": string
41
- }}
42
-
43
- Hard rules:
44
- - Output ONLY JSON. No markdown. No extra text.
45
- - evidence_quotes MUST be verbatim substrings copied from the input text (no paraphrase).
46
- - Keep each evidence quote short (prefer 1–2 sentences; max 240 chars).
47
- - confidence MUST be a real probability between 0.0 and 1.0 (use 2 decimals).
48
- It MUST NOT be always the same across examples. Calibrate it.
49
- - The rationale MUST be specific to the evidence (2–4 sentences):
50
- Explain (1) what the quote claims, (2) why that matches the fallacy label,
51
- (3) what logical step is invalid or missing.
52
- DO NOT use generic filler. Do NOT reuse stock phrases.
53
- - If no fallacy: has_fallacy=false and fallacies=[] and overall_explanation explains briefly why.
54
- INPUT:
55
- {{text}}
56
-
57
- OUTPUT:"""
58
-
59
- # /rewrite prompt: returns ONLY a replacement substring for the quote (server does the replacement)
60
- REWRITE_PROMPT = """You are rewriting a small quoted span inside a larger text.
61
-
62
- Goal:
63
- - You MUST propose a replacement for the QUOTE only.
64
- - The replacement should remove the fallacious reasoning described, while keeping the same tone/style/tense/entities.
65
- - The replacement MUST be plausible in the surrounding context and should be similar length (roughly +/- 40%).
66
- - Do NOT change anything outside the quote. Do NOT add new facts not implied by the original.
67
- - Do NOT introduce new fallacies.
68
-
69
- Return ONLY valid JSON with this schema:
70
- {
71
- "replacement_quote": string,
72
- "why_this_fix": string
73
- }
74
-
75
- Hard rules:
76
- - Output ONLY JSON. No markdown. No extra text.
77
- - replacement_quote should be standalone text (no surrounding quotes).
78
- - why_this_fix: 1–3 sentences, specific.
79
-
80
- INPUT_TEXT:
81
- {text}
82
-
83
- QUOTE_TO_REWRITE:
84
- {quote}
85
-
86
- FALLACY_TYPE:
87
- {fallacy_type}
88
-
89
- WHY_FALLACIOUS:
90
- {rationale}
91
-
92
- OUTPUT:"""
93
-
94
-
95
- def build_analyze_messages(text: str) -> List[Dict[str, str]]:
96
- return [
97
- {"role": "system", "content": "Return only JSON. Exactly one JSON object. No extra text."},
98
- {"role": "user", "content": ANALYZE_PROMPT.replace("{text}", text)},
99
- ]
100
-
101
-
102
- def build_rewrite_messages(text: str, quote: str, fallacy_type: str, rationale: str) -> List[Dict[str, str]]:
103
- prompt = (
104
- REWRITE_PROMPT
105
- .replace("<<TEXT>>", text)
106
- .replace("<<QUOTE>>", quote)
107
- .replace("<<FALLACY_TYPE>>", fallacy_type)
108
- .replace("<<RATIONALE>>", rationale)
109
- )
110
- return [
111
- {"role": "system", "content": "Return only JSON. Exactly one JSON object. No extra text."},
112
- {"role": "user", "content": prompt},
113
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils.py DELETED
@@ -1,171 +0,0 @@
1
- import json
2
- import re
3
- from typing import Any, Dict, Optional, List
4
- from prompts import ALLOWED_LABELS
5
-
6
- # ----------------------------
7
- # Robust JSON extraction
8
- # ----------------------------
9
- def stop_at_complete_json(text: str) -> Optional[str]:
10
- start = text.find("{")
11
- if start == -1:
12
- return None
13
-
14
- depth = 0
15
- in_str = False
16
- esc = False
17
-
18
- for i in range(start, len(text)):
19
- ch = text[i]
20
- if in_str:
21
- if esc:
22
- esc = False
23
- elif ch == "\\":
24
- esc = True
25
- elif ch == '"':
26
- in_str = False
27
- continue
28
-
29
- if ch == '"':
30
- in_str = True
31
- continue
32
- if ch == "{":
33
- depth += 1
34
- elif ch == "}":
35
- depth -= 1
36
- if depth == 0:
37
- return text[start : i + 1]
38
- return None
39
-
40
-
41
- def extract_first_json_obj(s: str) -> Optional[Dict[str, Any]]:
42
- cut = stop_at_complete_json(s) or s
43
- start = cut.find("{")
44
- end = cut.rfind("}")
45
- if start == -1 or end == -1 or end <= start:
46
- return None
47
- cand = cut[start : end + 1].strip()
48
- try:
49
- return json.loads(cand)
50
- except Exception:
51
- return None
52
-
53
-
54
- # ----------------------------
55
- # Post-processing: remove template sentence
56
- # ----------------------------
57
- _TEMPLATE_RE = re.compile(
58
- r"\bthe input contains fallacious reasoning consistent with the predicted type\(s\)\b\.?",
59
- flags=re.IGNORECASE,
60
- )
61
-
62
- def strip_template_sentence(text: str) -> str:
63
- if not isinstance(text, str):
64
- return ""
65
- out = _TEMPLATE_RE.sub("", text)
66
- out = out.replace("..", ".").strip()
67
- out = re.sub(r"\s{2,}", " ", out)
68
- out = re.sub(r"^\s*[\-–—:;,\.\s]+", "", out).strip()
69
- return out
70
-
71
-
72
- # ----------------------------
73
- # Output sanitation / validation
74
- # ----------------------------
75
- def _clamp01(x: Any, default: float = 0.5) -> float:
76
- try:
77
- v = float(x)
78
- except Exception:
79
- return default
80
- return 0.0 if v < 0.0 else (1.0 if v > 1.0 else v)
81
-
82
-
83
- def _is_allowed_label(lbl: Any) -> bool:
84
- return isinstance(lbl, str) and lbl in ALLOWED_LABELS and lbl != "none"
85
-
86
-
87
- def sanitize_analyze_output(obj: Dict[str, Any], input_text: str) -> Dict[str, Any]:
88
- has_fallacy = bool(obj.get("has_fallacy", False))
89
- fallacies_in = obj.get("fallacies", [])
90
- if not isinstance(fallacies_in, list):
91
- fallacies_in = []
92
-
93
- fallacies_out = []
94
- for f in fallacies_in:
95
- if not isinstance(f, dict):
96
- continue
97
- f_type = f.get("type")
98
- if not _is_allowed_label(f_type):
99
- continue
100
-
101
- conf = _clamp01(f.get("confidence", 0.5))
102
- conf = float(f"{conf:.2f}")
103
-
104
- ev = f.get("evidence_quotes", [])
105
- if not isinstance(ev, list):
106
- ev = []
107
-
108
- ev_clean: List[str] = []
109
- for q in ev:
110
- if not isinstance(q, str):
111
- continue
112
- qq = q.strip()
113
- if not qq:
114
- continue
115
- if qq in input_text:
116
- if len(qq) <= 240:
117
- ev_clean.append(qq)
118
- else:
119
- short = qq[:240]
120
- ev_clean.append(short if short in input_text else qq)
121
-
122
- rationale = strip_template_sentence(str(f.get("rationale", "")).strip())
123
-
124
- fallacies_out.append(
125
- {
126
- "type": f_type,
127
- "confidence": conf,
128
- "evidence_quotes": ev_clean[:3],
129
- "rationale": rationale,
130
- }
131
- )
132
-
133
- overall = strip_template_sentence(str(obj.get("overall_explanation", "")).strip())
134
-
135
- if len(fallacies_out) == 0:
136
- has_fallacy = False
137
-
138
- return {
139
- "has_fallacy": has_fallacy,
140
- "fallacies": fallacies_out,
141
- "overall_explanation": overall,
142
- }
143
-
144
-
145
- # ----------------------------
146
- # Replace helpers
147
- # ----------------------------
148
- def occurrence_index(text: str, sub: str, occurrence: int) -> int:
149
- if occurrence < 0:
150
- return -1
151
- start = 0
152
- for _ in range(occurrence + 1):
153
- idx = text.find(sub, start)
154
- if idx == -1:
155
- return -1
156
- start = idx + max(1, len(sub))
157
- return idx
158
-
159
-
160
- def replace_nth(text: str, old: str, new: str, occurrence: int) -> Dict[str, Any]:
161
- idx = occurrence_index(text, old, occurrence)
162
- if idx == -1:
163
- return {"ok": False, "error": "quote_not_found"}
164
- return {
165
- "ok": True,
166
- "rewritten_text": text[:idx] + new + text[idx + len(old) :],
167
- "start_char": idx,
168
- "end_char": idx + len(new),
169
- "old_start_char": idx,
170
- "old_end_char": idx + len(old),
171
- }