Spaces:
Sleeping
Sleeping
# Standard library
import asyncio
import json
import math
import os
import time
from functools import lru_cache
from typing import Any, Dict, List

# Third-party
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
from pydantic import BaseModel

# Local
from utils import extract_json_obj_robust, sanitize_analyze_output
| # ---------------------------- | |
| # Config (env overridable) | |
| # ---------------------------- | |
def _int_env(name: str, default: int) -> int:
    """Read environment variable *name* as an int, falling back to *default*.

    The fallback is used when the variable is unset (os.getenv substitutes
    str(default), which always parses) or holds a non-integer value.
    """
    try:
        return int(os.getenv(name, str(default)))
    except (TypeError, ValueError):
        # Only conversion failures are expected; anything else should surface.
        return default
def _bool_env(name: str, default: bool) -> bool:
    """Interpret environment variable *name* as a boolean flag.

    Unset -> *default*. Otherwise true iff the value, trimmed and
    lowercased, is one of: "1", "true", "yes", "y", "on".
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    truthy = {"1", "true", "yes", "y", "on"}
    return raw.strip().lower() in truthy
# Runtime tuning knobs; each is overridable via an environment variable of
# the same name.
ENABLE_FULL_CONFIDENCE = _bool_env("ENABLE_FULL_CONFIDENCE", True)  # collect per-token logprobs for confidence scoring
USE_FLASH_ATTN = _bool_env("USE_FLASH_ATTN", True)  # passed to llama.cpp when the build supports it
N_BATCH = _int_env("N_BATCH", 1024)  # llama.cpp prompt-processing batch size
N_THREADS = _int_env("N_THREADS", 6)  # CPU threads used for inference
N_CTX = _int_env("N_CTX", 1024)  # context window size in tokens
# For CPU builds, keep this at 0
N_GPU_LAYERS = _int_env("N_GPU_LAYERS", 0)
# ----------------------------
# Cache dir (portable)
# ----------------------------
# Colab Drive (optional)
DRIVE_CACHE_DIR = "/content/drive/MyDrive/FADES_Models_Cache"
# HF Spaces / Docker-friendly cache (your Dockerfile sets these to /data/...)
HF_CACHE = (
    os.getenv("HUGGINGFACE_HUB_CACHE")
    or (os.path.join(os.getenv("HF_HOME", "/data"), ".cache", "huggingface", "hub"))
)
# Choose best available cache dir: prefer mounted Google Drive (Colab),
# otherwise the HF hub cache.
if os.path.exists("/content/drive"):
    CACHE_DIR = DRIVE_CACHE_DIR
else:
    # NOTE(review): HF_CACHE is always truthy (the `or` above guarantees a
    # path string), so the "/tmp/hf_cache" fallback here is unreachable.
    CACHE_DIR = HF_CACHE or "/tmp/hf_cache"
try:
    os.makedirs(CACHE_DIR, exist_ok=True)
except Exception:
    pass  # best-effort: hf_hub_download will fail loudly later if unusable

# GGUF model location on the Hugging Face Hub (env-overridable).
GGUF_REPO_ID = os.getenv("GGUF_REPO_ID", "maxime-antoine-dev/fades-mistral-v02-gguf")
GGUF_FILENAME = os.getenv("GGUF_FILENAME", "mistral_v02_fades.Q4_K_M.gguf")
# Serializes generation calls — llama.cpp contexts are not safe for
# concurrent use from multiple requests.
GEN_LOCK = asyncio.Lock()
app = FastAPI(title="FADES Fallacy Detector API (Final)")
# ============================
# CORS (for browser front-ends)
# ============================
# CORS_ALLOW_ORIGINS: "*" (default) or a comma-separated origin list.
_CORS_ORIGINS = os.getenv("CORS_ALLOW_ORIGINS", "*").strip()
if _CORS_ORIGINS == "*" or not _CORS_ORIGINS:
    allow_origins = ["*"]
else:
    allow_origins = [o.strip() for o in _CORS_ORIGINS.split(",") if o.strip()]
app.add_middleware(
    CORSMiddleware,
    allow_origins=allow_origins,
    # Credentials are disabled, which is what makes the "*" origin legal.
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Closed set of fallacy labels the model is allowed to emit (injected into
# the analyze system prompt).
ALLOWED_LABELS = [
    "none", "faulty generalization", "false causality", "circular reasoning",
    "ad populum", "ad hominem", "fallacy of logic", "appeal to emotion",
    "false dilemma", "equivocation", "fallacy of extension",
    "fallacy of relevance", "fallacy of credibility", "miscellaneous", "intentional"
]
# Maps the first generated token (prefix) of a label to the label group(s)
# it could begin; used by analyze_alternatives to aggregate top-k token
# probabilities into a label distribution.
LABEL_MAPPING = {
    "none": ["none"],
    "faulty": ["faulty generalization"],
    "false": ["false causality", "false dilemma"],
    "circular": ["circular reasoning"],
    "ad": ["ad populum", "ad hominem"],
    "fallacy": ["fallacy of logic", "extension", "relevance", "credibility"],
    "appeal": ["appeal to emotion"],
    "equivocation": ["equivocation"],
    "miscellaneous": ["miscellaneous"],
    "intentional": ["intentional"]
}
# System prompt for analysis. It is later passed through str.format() with
# {labels}; literal braces in the JSON schema and examples are doubled
# ({{ }}) so format() leaves them intact. Do not edit the doubled braces.
ANALYZE_SYS_PROMPT = """You are a logic expert. Detect logical fallacies in the given text.
OUTPUT JSON ONLY. No markdown. No extra keys. No commentary outside JSON.
IMPORTANT:
- The text can contain MULTIPLE fallacies. Return ALL that apply.
- If there are NO fallacies, set "has_fallacy": false and "fallacies": [].
- "evidence_quotes" MUST be the SHORTEST exact span(s) from the input that justify the fallacy.
Do NOT quote the whole text. Prefer 1 short quote; at most 3 quotes.
LABELS:
Use ONLY these labels: {labels}
Do NOT invent labels.
Do NOT output "none" as a fallacy item.
RATIONALE QUALITY (VERY IMPORTANT):
Each fallacy "rationale" MUST be directly tied to THIS input text, and MUST NOT be generic.
Structure each rationale like this (2–3 sentences max):
1) Restate the specific claim from the input in your own words AND anchor it to the exact quote(s).
2) Explain why that claim matches the fallacy, referencing what is missing or what is assumed.
3) If relevant, mention a concrete cue in the text.
OVERALL EXPLANATION (MUST reference the input):
- First: a quick recap of the detected fallacies (1 short sentence).
- Then: a general explanation of why these fallacies are risky IN THIS TEXT.
- If no fallacy: briefly explain why the reasoning is acceptable / what would be needed to call it fallacious.
CONFIDENCE:
"confidence" is between 0.0 and 1.0.
JSON SCHEMA:
{{
"has_fallacy": boolean,
"fallacies": [
{{
"type": string,
"confidence": number,
"evidence_quotes": [string],
"rationale": string
}}
],
"overall_explanation": string
}}
EXAMPLES (style guide — copy this style):
Input: "If we allow remote work, productivity will collapse and the company will fail."
Output:
{{
"has_fallacy": true,
"fallacies": [{{
"type": "false causality",
"confidence": 0.86,
"evidence_quotes": ["If we allow remote work, productivity will collapse", "the company will fail"],
"rationale": "The input implies that allowing remote work will directly cause productivity to collapse and lead to company failure (quotes: 'productivity will collapse', 'the company will fail') without supporting evidence. It treats a speculative outcome as a guaranteed causal chain, jumping from a policy change to extreme consequences."
}}],
"overall_explanation": "Recap: false causality. Risk: the text presents a shaky cause-and-effect chain as certain, which can push decisions based on fear rather than evidence and ignore alternative explanations."
}}
"""
# System prompt for the rewrite endpoint. Later passed through str.format()
# with {fallacy_type} and {rationale}; literal braces in the JSON template
# are doubled ({{ }}) so format() leaves them intact.
REWRITE_SYS_PROMPT = """You are a careful editor. Rewrite the text to REMOVE the logical fallacy while PRESERVING the original meaning as much as possible.
Context:
- Predicted fallacy type: {fallacy_type}
- Rationale: {rationale}
GOAL:
- Keep the same overall intent, but soften / qualify claims so the reasoning is no longer fallacious.
- Avoid absolute language ("always", "everyone", "no one") unless fully justified in the text.
- Replace overgeneralizations with reasonable qualifiers ("some", "often", "can", "in some cases").
- If the issue is causality, add uncertainty or evidence requirements ("may contribute", "could be related", "without evidence we can't conclude").
- If the issue is a false dilemma, add alternatives and nuance.
- If the issue is ad hominem / credibility attacks, remove personal attacks and focus on claims/evidence.
OUTPUT JSON ONLY (no markdown):
{{
"rewritten_text": string,
"why_this_fix": string
}}
The "why_this_fix" must be short (1-2 sentences) and explain what you changed to remove the fallacy.
EXAMPLE:
Input idea: "All blond women are pretty."
Output:
{{
"rewritten_text": "Some blond women can be very pretty, but attractiveness varies from person to person.",
"why_this_fix": "It removes the absolute generalization and replaces it with a qualified statement that doesn't stereotype an entire group."
}}
"""
def analyze_alternatives(start_index: int, top_logprobs_list: List[Dict[str, float]]) -> Dict[str, float]:
    """Aggregate the top-k token alternatives at *start_index* into a
    probability distribution over fallacy label groups.

    Each candidate token is normalized (spaces removed, lowercased) and
    bucketed by the first LABEL_MAPPING prefix it starts with; tokens that
    match no prefix are pooled under "_other_". Probabilities below 0.1%
    are dropped from the returned (rounded) distribution.
    """
    if not (0 <= start_index < len(top_logprobs_list)):
        return {}

    dist: Dict[str, float] = {}
    for raw_token, lp in top_logprobs_list[start_index].items():
        token = str(raw_token).replace(" ", "").lower().strip()
        prob = math.exp(lp)
        bucket = "_other_"
        for prefix, labels in LABEL_MAPPING.items():
            if token.startswith(prefix):
                if len(labels) > 1:
                    # Ambiguous prefix: show the group as "Prefix (a/b/...)".
                    short_names = "/".join(lbl.split()[-1] for lbl in labels)
                    bucket = f"{prefix.capitalize()} ({short_names})"
                else:
                    bucket = labels[0].title()
                break
        dist[bucket] = dist.get(bucket, 0.0) + prob

    return {name: round(p, 4) for name, p in dist.items() if p > 0.001}
def extract_label_info(target_label: str, tokens: List[str], logprobs: List[float], top_logprobs: List[Dict]) -> Dict:
    """Locate *target_label* in the generated token stream and estimate its
    technical confidence.

    Confidence is the mean probability (exp of logprob) of up to three
    tokens starting at the detected label position. When *top_logprobs* is
    non-empty, the alternative-label distribution at that position is also
    computed via analyze_alternatives.

    Returns {"conf": float, "dist": dict}; zeros/empty when the label is
    empty or never found.
    """
    if not target_label:
        return {"conf": 0.0, "dist": {}}

    wanted = target_label.lower().strip()
    accumulated = ""
    label_pos = -1
    for idx, tok in enumerate(tokens):
        piece = tok.decode("utf-8", errors="ignore") if isinstance(tok, bytes) else str(tok)
        accumulated += piece
        if label_pos == -1 and wanted in accumulated.lower():
            # Found the completion point; walk back a few tokens to find
            # the token where the label plausibly starts.
            label_pos = max(0, idx - 5)
            for back in range(label_pos, idx + 1):
                if wanted and wanted[0] in str(tokens[back]).lower():
                    label_pos = back
                    break
            break

    confidence = 0.0
    alternatives: Dict[str, float] = {}
    if label_pos != -1:
        probs = [
            math.exp(logprobs[k])
            for k in range(label_pos, min(len(logprobs), label_pos + 3))
            if logprobs[k] is not None
        ]
        if probs:
            confidence = round(sum(probs) / len(probs), 4)
        if top_logprobs:
            alternatives = analyze_alternatives(label_pos, top_logprobs)
    return {"conf": confidence, "dist": alternatives}
@lru_cache(maxsize=1)
def get_model():
    """Download (if needed) and load the GGUF model, cached per process.

    Without caching this function re-resolved hf_hub_download and rebuilt a
    Llama instance on EVERY request (health/analyze/rewrite all call it);
    @lru_cache(maxsize=1) — already imported at the top of the file — makes
    the expensive load happen once and returns the same instance thereafter.
    Concurrent generation on that shared instance is serialized by GEN_LOCK
    in the endpoints.

    Returns:
        Llama: the loaded llama.cpp model.
    Raises:
        Exception: propagated after logging if download or load fails.
    """
    print("📦 Loading Model...")
    model_path = hf_hub_download(
        repo_id=GGUF_REPO_ID,
        filename=GGUF_FILENAME,
        cache_dir=CACHE_DIR,
        repo_type="model",
    )
    # Try with flash_attn + gpu layers (if supported), otherwise fallback safely (CPU)
    try:
        return Llama(
            model_path=model_path,
            n_ctx=N_CTX,
            n_threads=N_THREADS,
            n_batch=N_BATCH,
            verbose=False,
            n_gpu_layers=N_GPU_LAYERS,
            flash_attn=USE_FLASH_ATTN,
            logits_all=ENABLE_FULL_CONFIDENCE,
        )
    except TypeError:
        # Older llama-cpp-python builds don't accept flash_attn; retry
        # without it, pinned to CPU (n_gpu_layers=0) as the safe fallback.
        return Llama(
            model_path=model_path,
            n_ctx=N_CTX,
            n_threads=N_THREADS,
            n_batch=N_BATCH,
            verbose=False,
            n_gpu_layers=0,
            logits_all=ENABLE_FULL_CONFIDENCE,
        )
    except Exception as e:
        print(f"❌ Error while loading model: {e}")
        raise
class AnalyzeRequest(BaseModel):
    """Request payload for fallacy analysis."""
    # Text to analyze for logical fallacies.
    text: str
    # Generation budget for the model's JSON answer.
    max_new_tokens: int = 300
    # Low temperature keeps the JSON output stable/deterministic.
    temperature: float = 0.1
class RewriteRequest(BaseModel):
    """Request payload for the fallacy-removing rewrite."""
    # Original text containing the fallacy.
    text: str
    # Fallacy label previously detected (fed back into the prompt).
    fallacy_type: str
    # Rationale from the analysis step (fed back into the prompt).
    rationale: str
    # Generation budget for the rewritten text.
    max_new_tokens: int = 300
def health():
    """Health probe: forces the model to load, then reports OK.

    NOTE(review): no FastAPI route decorator (e.g. @app.get("/health")) is
    visible on this handler — it appears to have been lost; confirm route
    registration against the deployed app.
    """
    get_model()
    return {"status": "ok"}
async def analyze(req: AnalyzeRequest):
    """Detect fallacies in req.text with the local GGUF model.

    Returns {"ok": bool, "result": <sanitized model JSON or error payload>,
    "meta": {"tech_conf", "distribution", "time"}}.

    NOTE(review): no @app.post(...) decorator is visible on this handler —
    presumably lost in extraction; confirm route registration.
    """
    llm = get_model()
    system_prompt = ANALYZE_SYS_PROMPT.format(labels=", ".join(ALLOWED_LABELS))
    # Mistral-instruct style prompt wrapping.
    prompt = f"[INST] {system_prompt}\n\nINPUT TEXT:\n{req.text} [/INST]"
    # Request top-20 logprobs per token only when confidence scoring is on.
    req_logprobs = 20 if ENABLE_FULL_CONFIDENCE else None
    # Serialize generation: the shared llama.cpp instance is not safe for
    # concurrent calls.
    async with GEN_LOCK:
        start_time = time.time()
        output = llm(
            prompt,
            max_tokens=req.max_new_tokens,
            temperature=req.temperature,
            top_p=0.95,
            repeat_penalty=1.15,
            stop=["</s>", "```"],
            echo=False,
            logprobs=req_logprobs,
        )
        gen_time = time.time() - start_time
    raw_text = output["choices"][0]["text"]
    # Token-level data for technical confidence (left empty when disabled
    # or absent from the completion).
    tokens = []
    logprobs = []
    top_logprobs = []
    if ENABLE_FULL_CONFIDENCE and "logprobs" in output["choices"][0]:
        lp_data = output["choices"][0]["logprobs"]
        tokens = lp_data.get("tokens", [])
        logprobs = lp_data.get("token_logprobs", [])
        top_logprobs = lp_data.get("top_logprobs", [])
    parsed_obj = extract_json_obj_robust(raw_text)
    result_json: Dict[str, Any] = {}
    success = False
    technical_confidence = 0.0
    label_distribution: Dict[str, float] = {}
    if parsed_obj is not None:
        # Enforce schema + clean common template artifacts
        result_json = sanitize_analyze_output(parsed_obj, req.text)
        success = True
        if result_json.get("has_fallacy") and result_json.get("fallacies"):
            for fallacy in result_json["fallacies"]:
                d_type = fallacy.get("type", "")
                if ENABLE_FULL_CONFIDENCE:
                    info = extract_label_info(d_type, tokens, logprobs, top_logprobs)
                    spec_conf = info["conf"]
                    # NOTE(review): overwritten on each iteration, so only
                    # the LAST fallacy's distribution reaches "meta" —
                    # confirm this is intended.
                    label_distribution = info["dist"]
                    fallacy["technical_confidence"] = spec_conf
                    fallacy["alternatives"] = label_distribution
                    declared = fallacy.get("confidence", 0.8)
                    # Blend the model's self-declared confidence with the
                    # token-probability estimate (simple average).
                    fallacy["confidence"] = round((float(declared) + float(spec_conf)) / 2, 2)
                    # Keep the first non-zero technical confidence for meta.
                    if technical_confidence == 0.0:
                        technical_confidence = spec_conf
        else:
            if ENABLE_FULL_CONFIDENCE:
                # No fallacies: still report the token distribution around
                # the "has_fallacy" decision point.
                info = extract_label_info("has_fallacy", tokens, logprobs, top_logprobs)
                label_distribution = info["dist"]
    else:
        result_json = {"error": "JSON Error", "raw": raw_text}
        success = False
    return {
        "ok": success,
        "result": result_json,
        "meta": {
            "tech_conf": technical_confidence,
            "distribution": label_distribution,
            "time": round(gen_time, 2),
        },
    }
async def rewrite(req: RewriteRequest):
    """Ask the model to rewrite req.text so the detected fallacy is removed.

    Returns {"ok": bool, "result": <parsed JSON, or {"raw": <model text>}
    when the output could not be parsed>}.
    """
    llm = get_model()
    sys_prompt = REWRITE_SYS_PROMPT.format(fallacy_type=req.fallacy_type, rationale=req.rationale)
    full_prompt = f"[INST] {sys_prompt}\n\nTEXT TO FIX:\n{req.text} [/INST]"
    # Serialize access to the shared llama.cpp instance.
    async with GEN_LOCK:
        completion = llm(
            full_prompt,
            max_tokens=req.max_new_tokens,
            temperature=0.7,
            repeat_penalty=1.1,
            stop=["</s>", "```"],
        )
    generated = completion["choices"][0]["text"]
    try:
        parsed = extract_json_obj_robust(generated)
        if parsed is None:
            raise ValueError("json_parse_failed")
    except Exception:
        # Parsing failed: hand back the raw model text for debugging.
        return {"ok": False, "result": {"raw": generated}}
    return {"ok": True, "result": parsed}
if __name__ == "__main__":
    # Works both locally + HF Spaces
    # HF Spaces injects PORT; 7860 is the Spaces default port convention.
    port = _int_env("PORT", 7860)
    uvicorn.run(app, host="0.0.0.0", port=port)