# fades-api / main.py
# Author: maxime-antoine-dev (Hugging Face Space) — commit d9b52fd (verified)
import os
import json
import time
import math
import asyncio
from functools import lru_cache
from typing import Any, Dict, List
from utils import extract_json_obj_robust, sanitize_analyze_output
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
# ----------------------------
# Config (env overridable)
# ----------------------------
def _int_env(name: str, default: int) -> int:
try:
return int(os.getenv(name, str(default)))
except Exception:
return default
def _bool_env(name: str, default: bool) -> bool:
v = os.getenv(name, None)
if v is None:
return default
return v.strip().lower() in {"1", "true", "yes", "y", "on"}
# Feature flags and llama.cpp sizing knobs (all overridable via environment).
ENABLE_FULL_CONFIDENCE = _bool_env("ENABLE_FULL_CONFIDENCE", True)  # request per-token logprobs for confidence scoring
USE_FLASH_ATTN = _bool_env("USE_FLASH_ATTN", True)  # passed to Llama(); older builds without this kwarg fall back in get_model()
N_BATCH = _int_env("N_BATCH", 1024)  # prompt-eval batch size
N_THREADS = _int_env("N_THREADS", 6)  # CPU threads for inference
N_CTX = _int_env("N_CTX", 1024)  # context window size in tokens
# For CPU builds, keep this at 0
N_GPU_LAYERS = _int_env("N_GPU_LAYERS", 0)
# ----------------------------
# Cache dir (portable)
# ----------------------------
# Colab Drive (optional)
DRIVE_CACHE_DIR = "/content/drive/MyDrive/FADES_Models_Cache"
# HF Spaces / Docker-friendly cache (your Dockerfile sets these to /data/...)
HF_CACHE = (
    os.getenv("HUGGINGFACE_HUB_CACHE")
    or (os.path.join(os.getenv("HF_HOME", "/data"), ".cache", "huggingface", "hub"))
)
# Choose best available cache dir: prefer a mounted Google Drive (Colab),
# otherwise the HF hub cache, otherwise /tmp as a last resort.
if os.path.exists("/content/drive"):
    CACHE_DIR = DRIVE_CACHE_DIR
else:
    CACHE_DIR = HF_CACHE or "/tmp/hf_cache"
try:
    os.makedirs(CACHE_DIR, exist_ok=True)
except Exception:
    # Best-effort: if the dir cannot be created, hf_hub_download will
    # surface a clearer error later.
    pass
# GGUF model location on the Hugging Face Hub (env-overridable).
GGUF_REPO_ID = os.getenv("GGUF_REPO_ID", "maxime-antoine-dev/fades-mistral-v02-gguf")
GGUF_FILENAME = os.getenv("GGUF_FILENAME", "mistral_v02_fades.Q4_K_M.gguf")
# Serializes generation: both endpoints take this lock around llm(...) calls,
# so only one llama.cpp generation runs at a time.
GEN_LOCK = asyncio.Lock()
app = FastAPI(title="FADES Fallacy Detector API (Final)")
# ============================
# CORS (for browser front-ends)
# ============================
# CORS_ALLOW_ORIGINS: comma-separated origin list; "*" (the default) allows all.
_CORS_ORIGINS = os.getenv("CORS_ALLOW_ORIGINS", "*").strip()
if _CORS_ORIGINS == "*" or not _CORS_ORIGINS:
    allow_origins = ["*"]
else:
    allow_origins = [o.strip() for o in _CORS_ORIGINS.split(",") if o.strip()]
app.add_middleware(
    CORSMiddleware,
    allow_origins=allow_origins,
    allow_credentials=False,  # credentials must stay off while origins may be "*"
    allow_methods=["*"],
    allow_headers=["*"],
)
# Closed label set the model may emit; interpolated into ANALYZE_SYS_PROMPT.
ALLOWED_LABELS = [
    "none", "faulty generalization", "false causality", "circular reasoning",
    "ad populum", "ad hominem", "fallacy of logic", "appeal to emotion",
    "false dilemma", "equivocation", "fallacy of extension",
    "fallacy of relevance", "fallacy of credibility", "miscellaneous", "intentional"
]
# Maps a leading token prefix to the label(s) it could begin. Used by
# analyze_alternatives() to bucket top-logprob token candidates into label
# groups (e.g. a token starting with "ad" could be "ad populum" or "ad hominem").
LABEL_MAPPING = {
    "none": ["none"],
    "faulty": ["faulty generalization"],
    "false": ["false causality", "false dilemma"],
    "circular": ["circular reasoning"],
    "ad": ["ad populum", "ad hominem"],
    "fallacy": ["fallacy of logic", "extension", "relevance", "credibility"],
    "appeal": ["appeal to emotion"],
    "equivocation": ["equivocation"],
    "miscellaneous": ["miscellaneous"],
    "intentional": ["intentional"]
}
# System prompt for /analyze. "{labels}" is filled via str.format with
# ALLOWED_LABELS; the doubled braces ({{ }}) survive formatting as literal
# JSON braces in the schema/examples.
ANALYZE_SYS_PROMPT = """You are a logic expert. Detect logical fallacies in the given text.
OUTPUT JSON ONLY. No markdown. No extra keys. No commentary outside JSON.
IMPORTANT:
- The text can contain MULTIPLE fallacies. Return ALL that apply.
- If there are NO fallacies, set "has_fallacy": false and "fallacies": [].
- "evidence_quotes" MUST be the SHORTEST exact span(s) from the input that justify the fallacy.
Do NOT quote the whole text. Prefer 1 short quote; at most 3 quotes.
LABELS:
Use ONLY these labels: {labels}
Do NOT invent labels.
Do NOT output "none" as a fallacy item.
RATIONALE QUALITY (VERY IMPORTANT):
Each fallacy "rationale" MUST be directly tied to THIS input text, and MUST NOT be generic.
Structure each rationale like this (2–3 sentences max):
1) Restate the specific claim from the input in your own words AND anchor it to the exact quote(s).
2) Explain why that claim matches the fallacy, referencing what is missing or what is assumed.
3) If relevant, mention a concrete cue in the text.
OVERALL EXPLANATION (MUST reference the input):
- First: a quick recap of the detected fallacies (1 short sentence).
- Then: a general explanation of why these fallacies are risky IN THIS TEXT.
- If no fallacy: briefly explain why the reasoning is acceptable / what would be needed to call it fallacious.
CONFIDENCE:
"confidence" is between 0.0 and 1.0.
JSON SCHEMA:
{{
"has_fallacy": boolean,
"fallacies": [
{{
"type": string,
"confidence": number,
"evidence_quotes": [string],
"rationale": string
}}
],
"overall_explanation": string
}}
EXAMPLES (style guide — copy this style):
Input: "If we allow remote work, productivity will collapse and the company will fail."
Output:
{{
"has_fallacy": true,
"fallacies": [{{
"type": "false causality",
"confidence": 0.86,
"evidence_quotes": ["If we allow remote work, productivity will collapse", "the company will fail"],
"rationale": "The input implies that allowing remote work will directly cause productivity to collapse and lead to company failure (quotes: 'productivity will collapse', 'the company will fail') without supporting evidence. It treats a speculative outcome as a guaranteed causal chain, jumping from a policy change to extreme consequences."
}}],
"overall_explanation": "Recap: false causality. Risk: the text presents a shaky cause-and-effect chain as certain, which can push decisions based on fear rather than evidence and ignore alternative explanations."
}}
"""
# System prompt for /rewrite. "{fallacy_type}" and "{rationale}" are filled
# via str.format from the request; doubled braces stay literal JSON braces.
REWRITE_SYS_PROMPT = """You are a careful editor. Rewrite the text to REMOVE the logical fallacy while PRESERVING the original meaning as much as possible.
Context:
- Predicted fallacy type: {fallacy_type}
- Rationale: {rationale}
GOAL:
- Keep the same overall intent, but soften / qualify claims so the reasoning is no longer fallacious.
- Avoid absolute language ("always", "everyone", "no one") unless fully justified in the text.
- Replace overgeneralizations with reasonable qualifiers ("some", "often", "can", "in some cases").
- If the issue is causality, add uncertainty or evidence requirements ("may contribute", "could be related", "without evidence we can't conclude").
- If the issue is a false dilemma, add alternatives and nuance.
- If the issue is ad hominem / credibility attacks, remove personal attacks and focus on claims/evidence.
OUTPUT JSON ONLY (no markdown):
{{
"rewritten_text": string,
"why_this_fix": string
}}
The "why_this_fix" must be short (1-2 sentences) and explain what you changed to remove the fallacy.
EXAMPLE:
Input idea: "All blond women are pretty."
Output:
{{
"rewritten_text": "Some blond women can be very pretty, but attractiveness varies from person to person.",
"why_this_fix": "It removes the absolute generalization and replaces it with a qualified statement that doesn't stereotype an entire group."
}}
"""
def analyze_alternatives(start_index: int, top_logprobs_list: List[Dict[str, float]]) -> Dict[str, float]:
    """Bucket the top-logprob candidates at one position into label groups.

    Each candidate token is normalized (spaces stripped, lowercased) and
    assigned to the first LABEL_MAPPING prefix it starts with; unmatched
    tokens accumulate under "_other_". Probabilities within a bucket are
    summed, rounded to 4 decimals, and buckets below 0.001 are dropped.
    Returns {} when start_index is out of range.
    """
    if not (0 <= start_index < len(top_logprobs_list)):
        return {}
    buckets: Dict[str, float] = {}
    for raw_token, lp in top_logprobs_list[start_index].items():
        token = str(raw_token).replace(" ", "").lower().strip()
        prob = math.exp(lp)
        name = "_other_"
        for prefix, labels in LABEL_MAPPING.items():
            if token.startswith(prefix):
                if len(labels) > 1:
                    # Compact display name, e.g. "Ad (populum/hominem)".
                    tails = "/".join(lbl.split()[-1] for lbl in labels)
                    name = f"{prefix.capitalize()} ({tails})"
                else:
                    name = labels[0].title()
                break
        buckets[name] = buckets.get(name, 0.0) + prob
    return {label: round(mass, 4) for label, mass in buckets.items() if mass > 0.001}
def extract_label_info(target_label: str, tokens: List[str], logprobs: List[float], top_logprobs: List[Dict]) -> Dict:
if not target_label:
return {"conf": 0.0, "dist": {}}
target_clean = target_label.lower().strip()
current_text = ""
start_index = -1
for i, token in enumerate(tokens):
tok_str = str(token) if not isinstance(token, bytes) else token.decode("utf-8", errors="ignore")
current_text += tok_str
if target_clean in current_text.lower() and start_index == -1:
start_index = max(0, i - 5)
for j in range(start_index, i + 1):
t_s = str(tokens[j]).lower()
if target_clean and target_clean[0] in t_s:
start_index = j
break
break
conf = 0.0
dist: Dict[str, float] = {}
if start_index != -1:
valid = [
math.exp(logprobs[k])
for k in range(start_index, min(len(logprobs), start_index + 3))
if logprobs[k] is not None
]
conf = round(sum(valid) / len(valid), 4) if valid else 0.0
if top_logprobs:
dist = analyze_alternatives(start_index, top_logprobs)
return {"conf": conf, "dist": dist}
@lru_cache(maxsize=1)
def get_model():
    """Download (if needed) and load the GGUF model.

    Cached with lru_cache(maxsize=1) so the model is fetched and loaded at
    most once per process. Raises whatever hf_hub_download or Llama raise
    on unrecoverable failure.
    """
    print("📦 Loading Model...")
    model_path = hf_hub_download(
        repo_id=GGUF_REPO_ID,
        filename=GGUF_FILENAME,
        cache_dir=CACHE_DIR,
        repo_type="model",
    )
    # Try with flash_attn + gpu layers (if supported), otherwise fallback safely (CPU)
    try:
        llm = Llama(
            model_path=model_path,
            n_ctx=N_CTX,
            n_threads=N_THREADS,
            n_batch=N_BATCH,
            verbose=False,
            n_gpu_layers=N_GPU_LAYERS,
            flash_attn=USE_FLASH_ATTN,
            logits_all=ENABLE_FULL_CONFIDENCE,  # required to get per-token logprobs back
        )
        return llm
    except TypeError:
        # Older builds may not accept flash_attn — retry without it, CPU-only.
        llm = Llama(
            model_path=model_path,
            n_ctx=N_CTX,
            n_threads=N_THREADS,
            n_batch=N_BATCH,
            verbose=False,
            n_gpu_layers=0,
            logits_all=ENABLE_FULL_CONFIDENCE,
        )
        return llm
    except Exception as e:
        print(f"❌ Error while loading model: {e}")
        raise
class AnalyzeRequest(BaseModel):
    """Request body for POST /analyze."""
    text: str  # text to scan for fallacies
    max_new_tokens: int = 300  # generation budget for the model's JSON answer
    temperature: float = 0.1  # low default keeps JSON output stable
class RewriteRequest(BaseModel):
    """Request body for POST /rewrite."""
    text: str  # text to rewrite
    fallacy_type: str  # fallacy label previously detected by /analyze
    rationale: str  # rationale from /analyze, fed back into the prompt
    max_new_tokens: int = 300  # generation budget for the rewrite
@app.get("/health")
def health():
    """Liveness probe that also warms the model cache."""
    _ = get_model()  # first call triggers download/load; later calls hit lru_cache
    return {"status": "ok"}
@app.post("/analyze")
async def analyze(req: AnalyzeRequest):
    """Detect fallacies in req.text; return sanitized JSON plus confidence metadata."""
    llm = get_model()
    system_prompt = ANALYZE_SYS_PROMPT.format(labels=", ".join(ALLOWED_LABELS))
    # Mistral [INST] chat template.
    prompt = f"[INST] {system_prompt}\n\nINPUT TEXT:\n{req.text} [/INST]"
    # Ask for the top-20 candidate tokens per step so labels can be scored later.
    req_logprobs = 20 if ENABLE_FULL_CONFIDENCE else None
    async with GEN_LOCK:  # one llama.cpp generation at a time
        start_time = time.time()
        output = llm(
            prompt,
            max_tokens=req.max_new_tokens,
            temperature=req.temperature,
            top_p=0.95,
            repeat_penalty=1.15,
            stop=["</s>", "```"],
            echo=False,
            logprobs=req_logprobs,
        )
        gen_time = time.time() - start_time
    raw_text = output["choices"][0]["text"]
    tokens = []
    logprobs = []
    top_logprobs = []
    if ENABLE_FULL_CONFIDENCE and "logprobs" in output["choices"][0]:
        lp_data = output["choices"][0]["logprobs"]
        tokens = lp_data.get("tokens", [])
        logprobs = lp_data.get("token_logprobs", [])
        top_logprobs = lp_data.get("top_logprobs", [])
    # Robustly extract the first JSON object from the raw generation.
    parsed_obj = extract_json_obj_robust(raw_text)
    result_json: Dict[str, Any] = {}
    success = False
    technical_confidence = 0.0
    label_distribution: Dict[str, float] = {}
    if parsed_obj is not None:
        # Enforce schema + clean common template artifacts
        result_json = sanitize_analyze_output(parsed_obj, req.text)
        success = True
        if result_json.get("has_fallacy") and result_json.get("fallacies"):
            for fallacy in result_json["fallacies"]:
                d_type = fallacy.get("type", "")
                if ENABLE_FULL_CONFIDENCE:
                    # Token-level probability around where this label was generated.
                    info = extract_label_info(d_type, tokens, logprobs, top_logprobs)
                    spec_conf = info["conf"]
                    label_distribution = info["dist"]  # NOTE(review): overwritten each loop — meta shows the last fallacy's distribution
                    fallacy["technical_confidence"] = spec_conf
                    fallacy["alternatives"] = label_distribution
                    # Blend the model's self-declared confidence with the token-level one.
                    declared = fallacy.get("confidence", 0.8)
                    fallacy["confidence"] = round((float(declared) + float(spec_conf)) / 2, 2)
                    if technical_confidence == 0.0:
                        technical_confidence = spec_conf  # keep the first fallacy's score
        else:
            if ENABLE_FULL_CONFIDENCE:
                # No fallacies: still report alternatives around the "has_fallacy" tokens.
                info = extract_label_info("has_fallacy", tokens, logprobs, top_logprobs)
                label_distribution = info["dist"]
    else:
        result_json = {"error": "JSON Error", "raw": raw_text}
        success = False
    return {
        "ok": success,
        "result": result_json,
        "meta": {
            "tech_conf": technical_confidence,
            "distribution": label_distribution,
            "time": round(gen_time, 2),
        },
    }
@app.post("/rewrite")
async def rewrite(req: RewriteRequest):
    """Rewrite req.text to remove the given fallacy while preserving its meaning."""
    llm = get_model()
    system_prompt = REWRITE_SYS_PROMPT.format(fallacy_type=req.fallacy_type, rationale=req.rationale)
    prompt = f"[INST] {system_prompt}\n\nTEXT TO FIX:\n{req.text} [/INST]"
    async with GEN_LOCK:  # serialize generation with /analyze
        output = llm(
            prompt,
            max_tokens=req.max_new_tokens,
            temperature=0.7,  # higher than /analyze: rewriting benefits from variety
            repeat_penalty=1.1,
            stop=["</s>", "```"],
        )
    try:
        parsed = extract_json_obj_robust(output["choices"][0]["text"])
        if parsed is None:
            raise ValueError("json_parse_failed")
        res = parsed
        ok = True
    except Exception:
        # Parsing failed: return the raw generation so the client can inspect it.
        res = {"raw": output["choices"][0]["text"]}
        ok = False
    return {"ok": ok, "result": res}
if __name__ == "__main__":
    # Works both locally + HF Spaces: the port comes from $PORT (default 7860).
    port = _int_env("PORT", 7860)
    uvicorn.run(app, host="0.0.0.0", port=port)