import os
import time
import math
import asyncio
from functools import lru_cache
from typing import Any, Dict, List

import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from huggingface_hub import hf_hub_download
from llama_cpp import Llama

from utils import extract_json_obj_robust, sanitize_analyze_output

# ----------------------------
# Config (env overridable)
# ----------------------------
def _int_env(name: str, default: int) -> int:
    try:
        return int(os.getenv(name, str(default)))
    except Exception:
        return default

def _bool_env(name: str, default: bool) -> bool:
    v = os.getenv(name, None)
    if v is None:
        return default
    return v.strip().lower() in {"1", "true", "yes", "y", "on"}

ENABLE_FULL_CONFIDENCE = _bool_env("ENABLE_FULL_CONFIDENCE", True)
USE_FLASH_ATTN = _bool_env("USE_FLASH_ATTN", True)

N_BATCH = _int_env("N_BATCH", 1024)
N_THREADS = _int_env("N_THREADS", 6)
N_CTX = _int_env("N_CTX", 1024)

# For CPU builds, keep this at 0
N_GPU_LAYERS = _int_env("N_GPU_LAYERS", 0)
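
# Example environment overrides (illustrative values — tune per deployment):
#   export N_CTX=2048 N_THREADS=8 N_BATCH=512
#   export ENABLE_FULL_CONFIDENCE=false  # skip logprob bookkeeping for speed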

# ----------------------------
# Cache dir (portable)
# ----------------------------
# Colab Drive (optional)
DRIVE_CACHE_DIR = "/content/drive/MyDrive/FADES_Models_Cache"

# HF Spaces / Docker-friendly cache (your Dockerfile sets these to /data/...)
HF_CACHE = (
    os.getenv("HUGGINGFACE_HUB_CACHE")
    or (os.path.join(os.getenv("HF_HOME", "/data"), ".cache", "huggingface", "hub"))
)

# Choose best available cache dir
if os.path.exists("/content/drive"):
    CACHE_DIR = DRIVE_CACHE_DIR
else:
    CACHE_DIR = HF_CACHE or "/tmp/hf_cache"

try:
    os.makedirs(CACHE_DIR, exist_ok=True)
except Exception:
    pass

GGUF_REPO_ID = os.getenv("GGUF_REPO_ID", "maxime-antoine-dev/fades-mistral-v02-gguf")
GGUF_FILENAME = os.getenv("GGUF_FILENAME", "mistral_v02_fades.Q4_K_M.gguf")

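# llama.cpp contexts are not safe for concurrent generation, so every
# inference call is serialized through this lock.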
GEN_LOCK = asyncio.Lock()
app = FastAPI(title="FADES Fallacy Detector API (Final)")

# ============================
# CORS (for browser front-ends)
# ============================
_CORS_ORIGINS = os.getenv("CORS_ALLOW_ORIGINS", "*").strip()
if _CORS_ORIGINS == "*" or not _CORS_ORIGINS:
    allow_origins = ["*"]
else:
    allow_origins = [o.strip() for o in _CORS_ORIGINS.split(",") if o.strip()]

app.add_middleware(
    CORSMiddleware,
    allow_origins=allow_origins,
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

ALLOWED_LABELS = [
    "none", "faulty generalization", "false causality", "circular reasoning",
    "ad populum", "ad hominem", "fallacy of logic", "appeal to emotion",
    "false dilemma", "equivocation", "fallacy of extension",
    "fallacy of relevance", "fallacy of credibility", "miscellaneous", "intentional"
]

LABEL_MAPPING = {
    "none": ["none"],
    "faulty": ["faulty generalization"],
    "false": ["false causality", "false dilemma"],
    "circular": ["circular reasoning"],
    "ad": ["ad populum", "ad hominem"],
    "fallacy": ["fallacy of logic", "extension", "relevance", "credibility"],
    "appeal": ["appeal to emotion"],
    "equivocation": ["equivocation"],
    "miscellaneous": ["miscellaneous"],
    "intentional": ["intentional"]
}
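
# Grouping example (hypothetical token): a sampled token such as "false" is
# ambiguous between two labels, so analyze_alternatives() below reports its
# probability mass under the merged bucket "False (causality/dilemma)".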


ANALYZE_SYS_PROMPT = """You are a logic expert. Detect logical fallacies in the given text.
OUTPUT JSON ONLY. No markdown. No extra keys. No commentary outside JSON.

IMPORTANT:
- The text can contain MULTIPLE fallacies. Return ALL that apply.
- If there are NO fallacies, set "has_fallacy": false and "fallacies": [].
- "evidence_quotes" MUST be the SHORTEST exact span(s) from the input that justify the fallacy.
  Do NOT quote the whole text. Prefer 1 short quote; at most 3 quotes.

LABELS:
Use ONLY these labels: {labels}
Do NOT invent labels.
Do NOT output "none" as a fallacy item.

RATIONALE QUALITY (VERY IMPORTANT):
Each fallacy "rationale" MUST be directly tied to THIS input text, and MUST NOT be generic.
Structure each rationale like this (2–3 sentences max):
1) Restate the specific claim from the input in your own words AND anchor it to the exact quote(s).
2) Explain why that claim matches the fallacy, referencing what is missing or what is assumed.
3) If relevant, mention a concrete cue in the text.

OVERALL EXPLANATION (MUST reference the input):
- First: a quick recap of the detected fallacies (1 short sentence).
- Then: a general explanation of why these fallacies are risky IN THIS TEXT.
- If no fallacy: briefly explain why the reasoning is acceptable / what would be needed to call it fallacious.

CONFIDENCE:
"confidence" is between 0.0 and 1.0.

JSON SCHEMA:
{{
  "has_fallacy": boolean,
  "fallacies": [
    {{
      "type": string,
      "confidence": number,
      "evidence_quotes": [string],
      "rationale": string
    }}
  ],
  "overall_explanation": string
}}

EXAMPLES (style guide — copy this style):

Input: "If we allow remote work, productivity will collapse and the company will fail."
Output:
{{
  "has_fallacy": true,
  "fallacies": [{{
    "type": "false causality",
    "confidence": 0.86,
    "evidence_quotes": ["If we allow remote work, productivity will collapse", "the company will fail"],
    "rationale": "The input implies that allowing remote work will directly cause productivity to collapse and lead to company failure (quotes: 'productivity will collapse', 'the company will fail') without supporting evidence. It treats a speculative outcome as a guaranteed causal chain, jumping from a policy change to extreme consequences."
  }}],
  "overall_explanation": "Recap: false causality. Risk: the text presents a shaky cause-and-effect chain as certain, which can push decisions based on fear rather than evidence and ignore alternative explanations."
}}
"""

REWRITE_SYS_PROMPT = """You are a careful editor. Rewrite the text to REMOVE the logical fallacy while PRESERVING the original meaning as much as possible.

Context:
- Predicted fallacy type: {fallacy_type}
- Rationale: {rationale}

GOAL:
- Keep the same overall intent, but soften / qualify claims so the reasoning is no longer fallacious.
- Avoid absolute language ("always", "everyone", "no one") unless fully justified in the text.
- Replace overgeneralizations with reasonable qualifiers ("some", "often", "can", "in some cases").
- If the issue is causality, add uncertainty or evidence requirements ("may contribute", "could be related", "without evidence we can't conclude").
- If the issue is a false dilemma, add alternatives and nuance.
- If the issue is ad hominem / credibility attacks, remove personal attacks and focus on claims/evidence.

OUTPUT JSON ONLY (no markdown):
{{
  "rewritten_text": string,
  "why_this_fix": string
}}

The "why_this_fix" must be short (1-2 sentences) and explain what you changed to remove the fallacy.

EXAMPLE:
Input idea: "All blond women are pretty."
Output:
{{
  "rewritten_text": "Some blond women can be very pretty, but attractiveness varies from person to person.",
  "why_this_fix": "It removes the absolute generalization and replaces it with a qualified statement that doesn't stereotype an entire group."
}}
"""

def analyze_alternatives(start_index: int, top_logprobs_list: List[Dict[str, float]]) -> Dict[str, float]:
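    """Fold the top-logprob candidates at `start_index` into probability mass
    per label group defined by LABEL_MAPPING.

    Sketch of the expected shape (logprob values below are illustrative):
        analyze_alternatives(0, [{" false": -0.4, " ad": -1.6}])
        # -> {"False (causality/dilemma)": 0.6703, "Ad (populum/hominem)": 0.2019}
    """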
    if start_index < 0 or start_index >= len(top_logprobs_list):
        return {}
    candidates = top_logprobs_list[start_index]

    distribution: Dict[str, float] = {}
    for token, logprob in candidates.items():
        clean_tok = str(token).replace(" ", "").lower().strip()
        prob = math.exp(logprob)

        matched = False
        for key, group in LABEL_MAPPING.items():
            if clean_tok.startswith(key):
                group_name = (
                    f"{key.capitalize()} ({'/'.join([g.split()[-1] for g in group])})"
                    if len(group) > 1
                    else group[0].title()
                )
                distribution[group_name] = distribution.get(group_name, 0.0) + prob
                matched = True
                break

        if not matched:
            distribution["_other_"] = distribution.get("_other_", 0.0) + prob

    return {k: round(v, 4) for k, v in distribution.items() if v > 0.001}

def extract_label_info(target_label: str, tokens: List[str], logprobs: List[float], top_logprobs: List[Dict]) -> Dict:
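    """Locate where `target_label` was generated and estimate its confidence.

    Returns {"conf": ..., "dist": ...}: "conf" is the mean probability of up
    to the first three label tokens, "dist" comes from analyze_alternatives().
    Shape sketch (values are illustrative, not real model output):
        extract_label_info("ad hominem", tokens, logprobs, top_logprobs)
        # -> {"conf": 0.81, "dist": {"Ad (populum/hominem)": 0.9}}
    """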
    if not target_label:
        return {"conf": 0.0, "dist": {}}

    target_clean = target_label.lower().strip()
    current_text = ""
    start_index = -1

    for i, token in enumerate(tokens):
        tok_str = str(token) if not isinstance(token, bytes) else token.decode("utf-8", errors="ignore")
        current_text += tok_str
        if target_clean in current_text.lower() and start_index == -1:
            start_index = max(0, i - 5)
            for j in range(start_index, i + 1):
                t_s = str(tokens[j]).lower()
                if target_clean and target_clean[0] in t_s:
                    start_index = j
                    break
            break

    conf = 0.0
    dist: Dict[str, float] = {}

    if start_index != -1:
        valid = [
            math.exp(logprobs[k])
            for k in range(start_index, min(len(logprobs), start_index + 3))
            if logprobs[k] is not None
        ]
        conf = round(sum(valid) / len(valid), 4) if valid else 0.0
        if top_logprobs:
            dist = analyze_alternatives(start_index, top_logprobs)

    return {"conf": conf, "dist": dist}

@lru_cache(maxsize=1)
def get_model():
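    """Download the GGUF weights (cached) and build a single Llama instance.

    lru_cache(maxsize=1) makes this a process-wide singleton: the model loads
    once, on the first call (typically /health during warm-up).
    """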
    print("📦 Loading Model...")
    model_path = hf_hub_download(
        repo_id=GGUF_REPO_ID,
        filename=GGUF_FILENAME,
        cache_dir=CACHE_DIR,
        repo_type="model",
    )

    # Try flash_attn + GPU offload first; if the installed build does not
    # support those kwargs, fall back to a safe CPU-only configuration.
    try:
        llm = Llama(
            model_path=model_path,
            n_ctx=N_CTX,
            n_threads=N_THREADS,
            n_batch=N_BATCH,
            verbose=False,
            n_gpu_layers=N_GPU_LAYERS,
            flash_attn=USE_FLASH_ATTN,
            logits_all=ENABLE_FULL_CONFIDENCE,
        )
        return llm
    except TypeError:
        # Older builds may not accept flash_attn
        llm = Llama(
            model_path=model_path,
            n_ctx=N_CTX,
            n_threads=N_THREADS,
            n_batch=N_BATCH,
            verbose=False,
            n_gpu_layers=0,
            logits_all=ENABLE_FULL_CONFIDENCE,
        )
        return llm
    except Exception as e:
        print(f"❌ Error while loading model: {e}")
        raise

class AnalyzeRequest(BaseModel):
    text: str
    max_new_tokens: int = 300
    temperature: float = 0.1

class RewriteRequest(BaseModel):
    text: str
    fallacy_type: str
    rationale: str
    max_new_tokens: int = 300

@app.get("/health")
def health():
    get_model()
    return {"status": "ok"}

@app.post("/analyze")
async def analyze(req: AnalyzeRequest):
    llm = get_model()
    system_prompt = ANALYZE_SYS_PROMPT.format(labels=", ".join(ALLOWED_LABELS))
    prompt = f"[INST] {system_prompt}\n\nINPUT TEXT:\n{req.text} [/INST]"

    req_logprobs = 20 if ENABLE_FULL_CONFIDENCE else None

    async with GEN_LOCK:
        start_time = time.time()
        output = llm(
            prompt,
            max_tokens=req.max_new_tokens,
            temperature=req.temperature,
            top_p=0.95,
            repeat_penalty=1.15,
            stop=["</s>", "```"],
            echo=False,
            logprobs=req_logprobs,
        )
        gen_time = time.time() - start_time

    raw_text = output["choices"][0]["text"]

    tokens = []
    logprobs = []
    top_logprobs = []

    if ENABLE_FULL_CONFIDENCE:
        # Guard with .get(): the "logprobs" key can be present but None.
        lp_data = output["choices"][0].get("logprobs") or {}
        tokens = lp_data.get("tokens", [])
        logprobs = lp_data.get("token_logprobs", [])
        top_logprobs = lp_data.get("top_logprobs", [])

    parsed_obj = extract_json_obj_robust(raw_text)
    result_json: Dict[str, Any] = {}
    success = False
    technical_confidence = 0.0
    label_distribution: Dict[str, float] = {}

    if parsed_obj is not None:
        # Enforce schema + clean common template artifacts
        result_json = sanitize_analyze_output(parsed_obj, req.text)
        success = True

        if result_json.get("has_fallacy") and result_json.get("fallacies"):
            for fallacy in result_json["fallacies"]:
                d_type = fallacy.get("type", "")
                if ENABLE_FULL_CONFIDENCE:
                    info = extract_label_info(d_type, tokens, logprobs, top_logprobs)
                    spec_conf = info["conf"]

                    fallacy["technical_confidence"] = spec_conf
                    fallacy["alternatives"] = info["dist"]

                    # Blend the model's self-declared confidence with the
                    # token-level probability estimate.
                    declared = fallacy.get("confidence", 0.8)
                    fallacy["confidence"] = round((float(declared) + float(spec_conf)) / 2, 2)

                    # Report the first fallacy's numbers at the top level so
                    # meta.tech_conf and meta.distribution describe the same label.
                    if technical_confidence == 0.0:
                        technical_confidence = spec_conf
                        label_distribution = info["dist"]
        else:
            if ENABLE_FULL_CONFIDENCE:
                info = extract_label_info("has_fallacy", tokens, logprobs, top_logprobs)
                label_distribution = info["dist"]

    else:
        result_json = {"error": "JSON Error", "raw": raw_text}
        success = False

    return {
        "ok": success,
        "result": result_json,
        "meta": {
            "tech_conf": technical_confidence,
            "distribution": label_distribution,
            "time": round(gen_time, 2),
        },
    }
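
# Example request (illustrative — actual output depends on the model):
#   curl -s -X POST http://localhost:7860/analyze \
#     -H "Content-Type: application/json" \
#     -d '{"text": "Everyone agrees with me, so I must be right."}'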

@app.post("/rewrite")
async def rewrite(req: RewriteRequest):
    llm = get_model()
    system_prompt = REWRITE_SYS_PROMPT.format(fallacy_type=req.fallacy_type, rationale=req.rationale)
    prompt = f"[INST] {system_prompt}\n\nTEXT TO FIX:\n{req.text} [/INST]"
    async with GEN_LOCK:
        output = llm(
            prompt,
            max_tokens=req.max_new_tokens,
            temperature=0.7,
            repeat_penalty=1.1,
            stop=["</s>", "```"],
        )
    raw_text = output["choices"][0]["text"]
    parsed = extract_json_obj_robust(raw_text)
    if parsed is not None:
        return {"ok": True, "result": parsed}
    return {"ok": False, "result": {"raw": raw_text}}
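
# Example request (illustrative):
#   curl -s -X POST http://localhost:7860/rewrite \
#     -H "Content-Type: application/json" \
#     -d '{"text": "All politicians lie.", "fallacy_type": "faulty generalization", "rationale": "Absolute claim about an entire group."}'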

if __name__ == "__main__":
    # Works both locally + HF Spaces
    port = _int_env("PORT", 7860)
    uvicorn.run(app, host="0.0.0.0", port=port)