"""FADES Fallacy Detector API — model loading, analysis and rewrite endpoints."""

# --- stdlib ---
import asyncio
import json
import math
import os
import time
from functools import lru_cache
from typing import Any, Dict, List

# --- third-party ---
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
from pydantic import BaseModel

# --- local ---
from utils import extract_json_obj_robust, sanitize_analyze_output

# ----------------------------
# Config (env overridable)
# ----------------------------


def _int_env(name: str, default: int) -> int:
    """Read an integer from env var *name*; fall back to *default* on bad input.

    ``os.getenv`` always returns a string here (a string default is supplied),
    so ``int()`` can only raise ``ValueError`` — catch exactly that.
    """
    try:
        return int(os.getenv(name, str(default)))
    except ValueError:
        return default


def _bool_env(name: str, default: bool) -> bool:
    """Read a boolean flag from env var *name*.

    "1"/"true"/"yes"/"y"/"on" (case-insensitive) are truthy; an unset
    variable yields *default*; anything else is False.
    """
    v = os.getenv(name, None)
    if v is None:
        return default
    return v.strip().lower() in {"1", "true", "yes", "y", "on"}


# When True, request per-token logprobs so label confidence can be computed.
ENABLE_FULL_CONFIDENCE = _bool_env("ENABLE_FULL_CONFIDENCE", True)
USE_FLASH_ATTN = _bool_env("USE_FLASH_ATTN", True)
N_BATCH = _int_env("N_BATCH", 1024)
N_THREADS = _int_env("N_THREADS", 6)
N_CTX = _int_env("N_CTX", 1024)
# For CPU builds, keep this at 0
N_GPU_LAYERS = _int_env("N_GPU_LAYERS", 0)

# ----------------------------
# Cache dir (portable)
# ----------------------------
# Colab Drive (optional)
DRIVE_CACHE_DIR = "/content/drive/MyDrive/FADES_Models_Cache"
# HF Spaces / Docker-friendly cache (your Dockerfile sets these to /data/...)
# Resolve the Hugging Face hub cache: an explicit HUGGINGFACE_HUB_CACHE wins,
# otherwise derive the hub path from HF_HOME (defaults to /data in Docker).
HF_CACHE = (
    os.getenv("HUGGINGFACE_HUB_CACHE")
    or (os.path.join(os.getenv("HF_HOME", "/data"), ".cache", "huggingface", "hub"))
)

# Choose best available cache dir: a mounted Colab Drive beats the HF cache,
# which beats a throwaway /tmp location.
if os.path.exists("/content/drive"):
    CACHE_DIR = DRIVE_CACHE_DIR
else:
    CACHE_DIR = HF_CACHE or "/tmp/hf_cache"

try:
    os.makedirs(CACHE_DIR, exist_ok=True)
except OSError:
    # Best-effort: a read-only filesystem must not prevent startup;
    # hf_hub_download will surface a clearer error later if the dir is unusable.
    pass

GGUF_REPO_ID = os.getenv("GGUF_REPO_ID", "maxime-antoine-dev/fades-mistral-v02-gguf")
GGUF_FILENAME = os.getenv("GGUF_FILENAME", "mistral_v02_fades.Q4_K_M.gguf")

# Serializes generation: a llama.cpp context is not safe for concurrent calls.
GEN_LOCK = asyncio.Lock()

app = FastAPI(title="FADES Fallacy Detector API (Final)")

# ============================
# CORS (for browser front-ends)
# ============================
_CORS_ORIGINS = os.getenv("CORS_ALLOW_ORIGINS", "*").strip()
if _CORS_ORIGINS == "*" or not _CORS_ORIGINS:
    allow_origins = ["*"]
else:
    allow_origins = [o.strip() for o in _CORS_ORIGINS.split(",") if o.strip()]

app.add_middleware(
    CORSMiddleware,
    allow_origins=allow_origins,
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Closed label set the model is allowed to emit.
ALLOWED_LABELS = [
    "none",
    "faulty generalization",
    "false causality",
    "circular reasoning",
    "ad populum",
    "ad hominem",
    "fallacy of logic",
    "appeal to emotion",
    "false dilemma",
    "equivocation",
    "fallacy of extension",
    "fallacy of relevance",
    "fallacy of credibility",
    "miscellaneous",
    "intentional",
]

# Maps a token prefix (as the model tends to emit it) to the full label(s) it
# could begin; used to aggregate top-logprob alternatives into label groups.
LABEL_MAPPING = {
    "none": ["none"],
    "faulty": ["faulty generalization"],
    "false": ["false causality", "false dilemma"],
    "circular": ["circular reasoning"],
    "ad": ["ad populum", "ad hominem"],
    "fallacy": ["fallacy of logic", "extension", "relevance", "credibility"],
    "appeal": ["appeal to emotion"],
    "equivocation": ["equivocation"],
    "miscellaneous": ["miscellaneous"],
    "intentional": ["intentional"],
}

# System prompt for /analyze. Doubled braces escape str.format(); only {labels}
# is substituted.
ANALYZE_SYS_PROMPT = """You are a logic expert. Detect logical fallacies in the given text.
OUTPUT JSON ONLY. No markdown. No extra keys. No commentary outside JSON.

IMPORTANT:
- The text can contain MULTIPLE fallacies. Return ALL that apply.
- If there are NO fallacies, set "has_fallacy": false and "fallacies": [].
- "evidence_quotes" MUST be the SHORTEST exact span(s) from the input that justify the fallacy. Do NOT quote the whole text. Prefer 1 short quote; at most 3 quotes.

LABELS:
Use ONLY these labels: {labels}
Do NOT invent labels. Do NOT output "none" as a fallacy item.

RATIONALE QUALITY (VERY IMPORTANT):
Each fallacy "rationale" MUST be directly tied to THIS input text, and MUST NOT be generic.
Structure each rationale like this (2–3 sentences max):
1) Restate the specific claim from the input in your own words AND anchor it to the exact quote(s).
2) Explain why that claim matches the fallacy, referencing what is missing or what is assumed.
3) If relevant, mention a concrete cue in the text.

OVERALL EXPLANATION (MUST reference the input):
- First: a quick recap of the detected fallacies (1 short sentence).
- Then: a general explanation of why these fallacies are risky IN THIS TEXT.
- If no fallacy: briefly explain why the reasoning is acceptable / what would be needed to call it fallacious.

CONFIDENCE:
"confidence" is between 0.0 and 1.0.

JSON SCHEMA:
{{
  "has_fallacy": boolean,
  "fallacies": [
    {{
      "type": string,
      "confidence": number,
      "evidence_quotes": [string],
      "rationale": string
    }}
  ],
  "overall_explanation": string
}}

EXAMPLES (style guide — copy this style):

Input: "If we allow remote work, productivity will collapse and the company will fail."
Output:
{{
  "has_fallacy": true,
  "fallacies": [{{
    "type": "false causality",
    "confidence": 0.86,
    "evidence_quotes": ["If we allow remote work, productivity will collapse", "the company will fail"],
    "rationale": "The input implies that allowing remote work will directly cause productivity to collapse and lead to company failure (quotes: 'productivity will collapse', 'the company will fail') without supporting evidence. It treats a speculative outcome as a guaranteed causal chain, jumping from a policy change to extreme consequences."
  }}],
  "overall_explanation": "Recap: false causality. Risk: the text presents a shaky cause-and-effect chain as certain, which can push decisions based on fear rather than evidence and ignore alternative explanations."
}}
"""

# System prompt for /rewrite. {fallacy_type} and {rationale} are substituted;
# doubled braces escape str.format().
REWRITE_SYS_PROMPT = """You are a careful editor. Rewrite the text to REMOVE the logical fallacy while PRESERVING the original meaning as much as possible.

Context:
- Predicted fallacy type: {fallacy_type}
- Rationale: {rationale}

GOAL:
- Keep the same overall intent, but soften / qualify claims so the reasoning is no longer fallacious.
- Avoid absolute language ("always", "everyone", "no one") unless fully justified in the text.
- Replace overgeneralizations with reasonable qualifiers ("some", "often", "can", "in some cases").
- If the issue is causality, add uncertainty or evidence requirements ("may contribute", "could be related", "without evidence we can't conclude").
- If the issue is a false dilemma, add alternatives and nuance.
- If the issue is ad hominem / credibility attacks, remove personal attacks and focus on claims/evidence.

OUTPUT JSON ONLY (no markdown):
{{
  "rewritten_text": string,
  "why_this_fix": string
}}

The "why_this_fix" must be short (1-2 sentences) and explain what you changed to remove the fallacy.

EXAMPLE:
Input idea: "All blond women are pretty."
Output:
{{
  "rewritten_text": "Some blond women can be very pretty, but attractiveness varies from person to person.",
  "why_this_fix": "It removes the absolute generalization and replaces it with a qualified statement that doesn't stereotype an entire group."
}}
"""


def analyze_alternatives(start_index: int, top_logprobs_list: List[Dict[str, float]]) -> Dict[str, float]:
    """Build a probability distribution over label groups from the top-logprob
    candidates observed at *start_index* in the generated token stream.

    Candidate tokens are matched by prefix against LABEL_MAPPING; unmatched
    mass accumulates under "_other_". Entries with <= 0.1% mass are dropped.
    Returns a dict of display name -> probability (rounded to 4 decimals).
    """
    if start_index < 0 or start_index >= len(top_logprobs_list):
        return {}
    candidates = top_logprobs_list[start_index]
    distribution: Dict[str, float] = {}
    for token, logprob in candidates.items():
        clean_tok = str(token).replace(" ", "").lower().strip()
        prob = math.exp(logprob)
        matched = False
        for key, group in LABEL_MAPPING.items():
            if clean_tok.startswith(key):
                # Multi-label groups get a combined display name,
                # e.g. "Ad (populum/hominem)".
                group_name = (
                    f"{key.capitalize()} ({'/'.join([g.split()[-1] for g in group])})"
                    if len(group) > 1
                    else group[0].title()
                )
                distribution[group_name] = distribution.get(group_name, 0.0) + prob
                matched = True
                break
        if not matched:
            distribution["_other_"] = distribution.get("_other_", 0.0) + prob
    return {k: round(v, 4) for k, v in distribution.items() if v > 0.001}


def extract_label_info(target_label: str, tokens: List[str], logprobs: List[float], top_logprobs: List[Dict]) -> Dict:
    """Locate *target_label* in the generated tokens and estimate confidence.

    Returns ``{"conf": float, "dist": dict}`` where "conf" is the mean
    probability of up to 3 tokens starting at the label position, and "dist"
    is the alternative-label distribution from analyze_alternatives().
    """
    if not target_label:
        return {"conf": 0.0, "dist": {}}
    target_clean = target_label.lower().strip()
    current_text = ""
    start_index = -1
    for i, token in enumerate(tokens):
        tok_str = str(token) if not isinstance(token, bytes) else token.decode("utf-8", errors="ignore")
        current_text += tok_str
        if target_clean in current_text.lower() and start_index == -1:
            # The label just became fully visible; walk back a few tokens to
            # find the one that actually starts it.
            start_index = max(0, i - 5)
            for j in range(start_index, i + 1):
                t_s = str(tokens[j]).lower()
                if target_clean and target_clean[0] in t_s:
                    start_index = j
                    break
            break
    conf = 0.0
    dist: Dict[str, float] = {}
    if start_index != -1:
        valid = [
            math.exp(logprobs[k])
            for k in range(start_index, min(len(logprobs), start_index + 3))
            if logprobs[k] is not None
        ]
        conf = round(sum(valid) / len(valid), 4) if valid else 0.0
        if top_logprobs:
            dist = analyze_alternatives(start_index, top_logprobs)
    return {"conf": conf, "dist": dist}


@lru_cache(maxsize=1)
def get_model():
    """Download (once) and load the GGUF model; cached for the process lifetime."""
    print("📦 Loading Model...")
    model_path = hf_hub_download(
        repo_id=GGUF_REPO_ID,
        filename=GGUF_FILENAME,
        cache_dir=CACHE_DIR,
        repo_type="model",
    )
    # Try with flash_attn + gpu layers (if supported), otherwise fallback safely (CPU)
    try:
        llm = Llama(
            model_path=model_path,
            n_ctx=N_CTX,
            n_threads=N_THREADS,
            n_batch=N_BATCH,
            verbose=False,
            n_gpu_layers=N_GPU_LAYERS,
            flash_attn=USE_FLASH_ATTN,
            logits_all=ENABLE_FULL_CONFIDENCE,
        )
        return llm
    except TypeError:
        # Older builds may not accept flash_attn
        llm = Llama(
            model_path=model_path,
            n_ctx=N_CTX,
            n_threads=N_THREADS,
            n_batch=N_BATCH,
            verbose=False,
            n_gpu_layers=0,
            logits_all=ENABLE_FULL_CONFIDENCE,
        )
        return llm
    except Exception as e:
        print(f"❌ Error while loading model: {e}")
        raise


class AnalyzeRequest(BaseModel):
    # Text to analyze for fallacies.
    text: str
    max_new_tokens: int = 300
    temperature: float = 0.1


class RewriteRequest(BaseModel):
    # Original text plus the previously detected fallacy context.
    text: str
    fallacy_type: str
    rationale: str
    max_new_tokens: int = 300


@app.get("/health")
def health():
    """Liveness probe; also warms the model cache on first call.

    Sync endpoint: FastAPI runs it in a threadpool, so the (possibly long)
    initial model load does not block the event loop.
    """
    get_model()
    return {"status": "ok"}


@app.post("/analyze")
async def analyze(req: AnalyzeRequest):
    """Detect fallacies in ``req.text``.

    Returns ``{"ok", "result", "meta"}`` where result is the sanitized model
    JSON (or an error payload) and meta carries timing and, when
    ENABLE_FULL_CONFIDENCE is set, logprob-derived confidence data.
    """
    llm = get_model()
    system_prompt = ANALYZE_SYS_PROMPT.format(labels=", ".join(ALLOWED_LABELS))
    prompt = f"[INST] {system_prompt}\n\nINPUT TEXT:\n{req.text} [/INST]"
    req_logprobs = 20 if ENABLE_FULL_CONFIDENCE else None

    async with GEN_LOCK:
        start_time = time.time()
        # llama-cpp generation is blocking; run it in a worker thread so the
        # event loop stays responsive (calls remain serialized by GEN_LOCK).
        output = await asyncio.to_thread(
            llm,
            prompt,
            max_tokens=req.max_new_tokens,
            temperature=req.temperature,
            top_p=0.95,
            repeat_penalty=1.15,
            stop=["", "```"],
            echo=False,
            logprobs=req_logprobs,
        )
        gen_time = time.time() - start_time

    # Post-processing below is pure CPU/JSON work; no need to hold the lock.
    raw_text = output["choices"][0]["text"]

    tokens = []
    logprobs = []
    top_logprobs = []
    if ENABLE_FULL_CONFIDENCE and "logprobs" in output["choices"][0]:
        lp_data = output["choices"][0]["logprobs"]
        tokens = lp_data.get("tokens", [])
        logprobs = lp_data.get("token_logprobs", [])
        top_logprobs = lp_data.get("top_logprobs", [])

    parsed_obj = extract_json_obj_robust(raw_text)

    result_json: Dict[str, Any] = {}
    success = False
    technical_confidence = 0.0
    label_distribution: Dict[str, float] = {}

    if parsed_obj is not None:
        # Enforce schema + clean common template artifacts
        result_json = sanitize_analyze_output(parsed_obj, req.text)
        success = True
        if result_json.get("has_fallacy") and result_json.get("fallacies"):
            for fallacy in result_json["fallacies"]:
                d_type = fallacy.get("type", "")
                if ENABLE_FULL_CONFIDENCE:
                    info = extract_label_info(d_type, tokens, logprobs, top_logprobs)
                    spec_conf = info["conf"]
                    label_distribution = info["dist"]
                    fallacy["technical_confidence"] = spec_conf
                    fallacy["alternatives"] = label_distribution
                    # Blend the model's self-declared confidence with the
                    # logprob-derived one.
                    declared = fallacy.get("confidence", 0.8)
                    fallacy["confidence"] = round((float(declared) + float(spec_conf)) / 2, 2)
                    # Keep the first fallacy's confidence as the headline number.
                    if technical_confidence == 0.0:
                        technical_confidence = spec_conf
        else:
            if ENABLE_FULL_CONFIDENCE:
                # No fallacies: report the distribution at the "has_fallacy" token.
                info = extract_label_info("has_fallacy", tokens, logprobs, top_logprobs)
                label_distribution = info["dist"]
    else:
        result_json = {"error": "JSON Error", "raw": raw_text}
        success = False

    return {
        "ok": success,
        "result": result_json,
        "meta": {
            "tech_conf": technical_confidence,
            "distribution": label_distribution,
            "time": round(gen_time, 2),
        },
    }


@app.post("/rewrite")
async def rewrite(req: RewriteRequest):
    """Rewrite ``req.text`` to remove the given fallacy while keeping its meaning.

    Falls back to ``{"ok": false, "result": {"raw": ...}}`` when the model
    output is not parseable JSON.
    """
    llm = get_model()
    system_prompt = REWRITE_SYS_PROMPT.format(fallacy_type=req.fallacy_type, rationale=req.rationale)
    prompt = f"[INST] {system_prompt}\n\nTEXT TO FIX:\n{req.text} [/INST]"

    async with GEN_LOCK:
        # Blocking generation moved off the event loop (same pattern as /analyze).
        output = await asyncio.to_thread(
            llm,
            prompt,
            max_tokens=req.max_new_tokens,
            temperature=0.7,
            repeat_penalty=1.1,
            stop=["", "```"],
        )

    try:
        parsed = extract_json_obj_robust(output["choices"][0]["text"])
        if parsed is None:
            raise ValueError("json_parse_failed")
        res = parsed
        ok = True
    except Exception:
        # Deliberate best-effort: return the raw text rather than a 500.
        res = {"raw": output["choices"][0]["text"]}
        ok = False
    return {"ok": ok, "result": res}


if __name__ == "__main__":
    # Works both locally + HF Spaces
    port = _int_env("PORT", 7860)
    uvicorn.run(app, host="0.0.0.0", port=port)