maxime-antoine-dev committed on
Commit
5a8ecdf
·
1 Parent(s): 6e3fd51
Files changed (3) hide show
  1. Dockerfile +10 -0
  2. main.py +229 -0
  3. requirements.txt +9 -0
Dockerfile ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
# Lightweight Python base image; slim variant keeps the CPU-only image small.
FROM python:3.11-slim

WORKDIR /app

# Install dependencies first so Docker layer caching survives source-only changes.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY . .

# 7860 matches the --port below; the app's own comments indicate a Hugging Face
# Space deployment, which serves on this port.
EXPOSE 7860

CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
main.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # main.py
2
+ import os
3
+ import json
4
+ from typing import Any, Dict, Optional
5
+ from functools import lru_cache
6
+ import asyncio
7
+
8
+ import torch
9
+ from fastapi import FastAPI
10
+ from pydantic import BaseModel
11
+ from transformers import AutoModelForCausalLM, AutoTokenizer
12
+ from peft import PeftModel
13
+
14
# ----------------------------
# Config (CPU-friendly defaults)
# ----------------------------
BASE_MODEL = os.getenv("BASE_MODEL", "mistralai/Mistral-7B-Instruct-v0.2")
ADAPTER_REPO = os.getenv("ADAPTER_REPO", "maxime-antoine-dev/fades")

# Keep smaller on CPU to stay usable
MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", "896"))
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "140"))

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# One request at a time on CPU Spaces (prevents OOM / huge latency spikes)
GEN_LOCK = asyncio.Lock()

# ----------------------------
# Prompt (aligned with training, but compact for CPU)
# ----------------------------
# Closed set of labels the model must answer with, verbatim.
ALLOWED_LABELS = [
    "none",
    "faulty generalization",
    "false causality",
    "circular reasoning",
    "ad populum",
    "ad hominem",
    "fallacy of logic",
    "appeal to emotion",
    "false dilemma",
    "equivocation",
    "fallacy of extension",
    "fallacy of relevance",
    "fallacy of credibility",
    "miscellaneous",
    "intentional",
]

def labels_block_compact() -> str:
    """Render the allowed labels as a compact bullet list for prompt interpolation."""
    # Compact form keeps prompt token count down on CPU.
    bullets = (f'- "{label}"' for label in ALLOWED_LABELS)
    return "\n".join(bullets)
41
+
42
# Prompt template interpolated by build_prompt(); {labels_list} is the only
# str.format placeholder — doubled braces ({{ }}) render as literal braces in
# the JSON schema example.
INSTRUCTION = """You are a logical fallacy detection assistant.

You MUST choose labels ONLY from this list (use the exact string):
{labels_list}

Return ONLY ONE valid JSON object with this schema:
{{
"has_fallacy": boolean,
"fallacies": [
{{
"type": string,
"confidence": number, // 0.0..1.0
"evidence_quotes": [string], // exact substring(s) copied from the input text
"rationale": string // specific to this fallacy + quote
}}
],
"overall_explanation": string // short summary across the whole input
}}

Hard rules:
- Output ONLY the JSON object. No markdown. No extra text.
- Produce exactly ONE JSON object, then STOP. Do NOT repeat the input. Do NOT create new examples.
- evidence_quotes MUST be exact substrings from the input text.
- If has_fallacy=false:
- fallacies MUST be []
- overall_explanation MUST explicitly say there is no fallacy
- overall_explanation MUST NOT mention any fallacy label/category names.
- If has_fallacy=true:
- fallacies MUST contain at least 1 item
- EACH fallacies[i].type MUST be one of the allowed labels (NOT a synonym)
- overall_explanation may summarize the detected fallacy(ies).
"""
74
+
75
def build_prompt(tokenizer: AutoTokenizer, text: str) -> str:
    """Assemble the chat-templated prompt asking for a JSON verdict on *text*."""
    filled = INSTRUCTION.format(labels_list=labels_block_compact())
    user_turn = f"{filled}\n\nTEXT:\n{text}\n\nJSON:"
    conversation = [
        {"role": "system", "content": "You are a careful JSON-only assistant."},
        {"role": "user", "content": user_turn},
    ]
    return tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
82
+
83
+ # ----------------------------
84
+ # JSON extraction + stop guard
85
+ # ----------------------------
86
def extract_first_json_obj(s: str) -> Optional[Dict[str, Any]]:
    """Return the first JSON object found anywhere in *s*, or None.

    Tries a ``json.JSONDecoder.raw_decode`` at every ``{`` position, so both
    leading chatter and trailing text after the object are tolerated.  The
    previous first-``{``/last-``}`` slice failed whenever any text (containing
    a ``}`` or not) followed the object, e.g. ``'{"a": 1} trailing }'``.

    Args:
        s: Arbitrary model output text.

    Returns:
        The first decodable JSON *object* (dict), or None if none exists.
    """
    decoder = json.JSONDecoder()
    idx = s.find("{")
    while idx != -1:
        try:
            obj, _ = decoder.raw_decode(s, idx)
        except ValueError:
            # Not valid JSON starting here; try the next opening brace.
            idx = s.find("{", idx + 1)
            continue
        if isinstance(obj, dict):
            return obj
        idx = s.find("{", idx + 1)
    return None
98
+
99
def stop_at_complete_json(text: str) -> Optional[str]:
    """Return the first brace-balanced JSON-like span in *text*, or None.

    Walks the text from the first ``{``, counting brace depth while tracking
    string literals (including backslash escapes) so that braces inside quoted
    strings never affect the count.  Returns the substring up to and including
    the brace that closes the first object; None if no ``{`` exists or the
    object is never closed.
    """
    first = text.find("{")
    if first < 0:
        return None

    depth = 0
    inside_string = False
    escaped = False

    for pos, ch in enumerate(text[first:], start=first):
        if inside_string:
            if escaped:
                escaped = False          # this char was escaped; consume it
            elif ch == "\\":
                escaped = True           # next char is escaped
            elif ch == '"':
                inside_string = False    # closing quote
        elif ch == '"':
            inside_string = True
        elif ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return text[first:pos + 1]
    return None
129
+
130
# ----------------------------
# Load model (tries CPU 8-bit if available; falls back to FP32)
# ----------------------------
def load_model() -> tuple[AutoTokenizer, Any]:
    """Load the tokenizer and base model, then attach the LoRA adapter.

    Returns:
        (tokenizer, model) with the PEFT adapter applied and eval mode set.
    """
    # Tokenizer comes from the adapter repo so any added/special tokens match training.
    tokenizer = AutoTokenizer.from_pretrained(ADAPTER_REPO, use_fast=True)
    if tokenizer.pad_token is None:
        # Models like Mistral ship without a pad token; reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token

    # Try CPU quantization to reduce RAM; fallback if not supported
    try:
        import bitsandbytes  # noqa: F401

        # NOTE(review): bitsandbytes 8-bit has historically required CUDA, so on a
        # CPU-only host this load likely raises and we take the FP32 path — confirm
        # against the installed bitsandbytes/transformers versions.
        base = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            device_map="auto",  # CPU on Spaces free
            load_in_8bit=True,
        )
    except Exception:
        # Plain FP32 load; presumably very large in RAM for a 7B base — TODO confirm fit.
        base = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            device_map=None,
            torch_dtype=torch.float32,
        )

    # Wrap the base model with the fine-tuned LoRA adapter weights.
    model = PeftModel.from_pretrained(base, ADAPTER_REPO)
    model.eval()
    return tokenizer, model
157
+
158
# Keep them global, but load on startup for clearer errors/logs
# Both remain None until the FastAPI startup hook calls load_model().
tokenizer: Optional[AutoTokenizer] = None
model: Optional[Any] = None

# ----------------------------
# FastAPI
# ----------------------------
app = FastAPI(title="FADES Fallacy Detector")
166
+
167
class AnalyzeRequest(BaseModel):
    """Request body for POST /analyze."""

    # Raw input text to scan for logical fallacies.
    text: str
    # Per-request generation budget; defaults to the env-configured MAX_NEW_TOKENS.
    max_new_tokens: int = MAX_NEW_TOKENS
170
+
171
+ @app.get("/health")
172
+ def health():
173
+ return {
174
+ "ok": True,
175
+ "device": DEVICE,
176
+ "base_model": BASE_MODEL,
177
+ "adapter_repo": ADAPTER_REPO,
178
+ "model_loaded": model is not None and tokenizer is not None,
179
+ }
180
+
181
+ @app.on_event("startup")
182
+ def _startup():
183
+ global tokenizer, model
184
+ tokenizer, model = load_model()
185
+
186
@lru_cache(maxsize=256)
def _cached_generate(text: str, max_new_tokens: int) -> Dict[str, Any]:
    """Run one deterministic generation and parse the model's JSON answer.

    Cached on (text, max_new_tokens) so repeated identical requests skip
    generation entirely.

    Args:
        text: Input text to analyze.
        max_new_tokens: Generation budget for this call.

    Returns:
        {"ok": True, "result": <dict>} when a valid JSON object was produced,
        otherwise {"ok": False, "raw": <decoded text>} for debugging.
    """
    assert tokenizer is not None and model is not None

    prompt = build_prompt(tokenizer, text)

    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_INPUT_TOKENS,
    )
    # Remember the prompt length BEFORE inputs may be rewrapped below.
    prompt_len = inputs["input_ids"].shape[-1]

    # move to device if CUDA exists
    if DEVICE == "cuda":
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            # Greedy decoding; `temperature` was dropped because it is ignored
            # when do_sample=False and newer transformers warns/errors on it.
            do_sample=False,
            use_cache=True,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )

    # BUG FIX: decode ONLY the newly generated tokens.  The previous code
    # decoded the full sequence (prompt included); the prompt's JSON schema
    # example contains the first "{" in the text, so stop_at_complete_json /
    # extract_first_json_obj latched onto the (invalid) schema block instead
    # of the model's answer and every request came back {"ok": False}.
    decoded = tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)

    # Trim any runaway continuation after the first complete JSON object.
    cut = stop_at_complete_json(decoded)
    decoded_cut = cut if cut is not None else decoded

    obj = extract_first_json_obj(decoded_cut)
    if obj is None:
        return {"ok": False, "raw": decoded_cut}

    return {"ok": True, "result": obj}
224
+
225
+ @app.post("/analyze")
226
+ async def analyze(req: AnalyzeRequest) -> Dict[str, Any]:
227
+ # One-at-a-time generation on CPU (prevents stalls/OOM)
228
+ async with GEN_LOCK:
229
+ return _cached_generate(req.text, int(req.max_new_tokens))
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ torch
4
+ transformers
5
+ peft
6
+ accelerate
7
+ sentencepiece
8
+ safetensors
9
+ huggingface_hub