maxime-antoine-dev committed on
Commit
e19317d
·
1 Parent(s): 5a8ecdf

Updated compatibility with quantized model

Browse files
Files changed (3) hide show
  1. Dockerfile +15 -4
  2. main.py +99 -90
  3. requirements.txt +4 -9
Dockerfile CHANGED
@@ -1,10 +1,21 @@
1
  FROM python:3.11-slim
2
 
 
 
 
 
 
3
  WORKDIR /app
4
- COPY requirements.txt .
5
- RUN pip install --no-cache-dir -r requirements.txt
6
 
7
- COPY . .
8
- EXPOSE 7860
 
 
 
 
 
9
 
 
 
 
10
  CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
 
1
  FROM python:3.11-slim
2
 
3
+ ENV PYTHONUNBUFFERED=1 \
4
+ PIP_NO_CACHE_DIR=1 \
5
+ HF_HOME=/data/.huggingface \
6
+ HUGGINGFACE_HUB_CACHE=/data/.cache/huggingface/hub
7
+
8
  WORKDIR /app
 
 
9
 
10
+ # Optional but safer (some environments may need build tools)
11
+ RUN apt-get update && apt-get install -y --no-install-recommends \
12
+ git \
13
+ && rm -rf /var/lib/apt/lists/*
14
+
15
+ COPY requirements.txt /app/requirements.txt
16
+ RUN pip install -r /app/requirements.txt
17
 
18
+ COPY main.py /app/main.py
19
+
20
+ EXPOSE 7860
21
  CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
main.py CHANGED
@@ -1,42 +1,57 @@
1
  # main.py
2
  import os
3
  import json
 
 
4
  from typing import Any, Dict, Optional
5
  from functools import lru_cache
6
- import asyncio
7
 
8
- import torch
9
  from fastapi import FastAPI
10
  from pydantic import BaseModel
11
- from transformers import AutoModelForCausalLM, AutoTokenizer
12
- from peft import PeftModel
13
 
14
  # ----------------------------
15
- # Config (CPU-friendly defaults)
16
  # ----------------------------
17
- BASE_MODEL = os.getenv("BASE_MODEL", "mistralai/Mistral-7B-Instruct-v0.2")
18
- ADAPTER_REPO = os.getenv("ADAPTER_REPO", "maxime-antoine-dev/fades")
19
 
20
- # Keep smaller on CPU to stay usable
21
- MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", "896"))
22
- MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "140"))
 
23
 
24
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
25
 
26
- # One request at a time on CPU Spaces (prevents OOM / huge latency spikes)
27
  GEN_LOCK = asyncio.Lock()
28
 
29
  # ----------------------------
30
- # Prompt (aligned with training, but compact for CPU)
31
  # ----------------------------
32
  ALLOWED_LABELS = [
33
- "none","faulty generalization","false causality","circular reasoning","ad populum","ad hominem",
34
- "fallacy of logic","appeal to emotion","false dilemma","equivocation","fallacy of extension",
35
- "fallacy of relevance","fallacy of credibility","miscellaneous","intentional",
 
 
 
 
 
 
 
 
 
 
 
 
36
  ]
37
 
38
  def labels_block_compact() -> str:
39
- # Compact list (removes long hints to reduce prompt tokens on CPU)
40
  return "\n".join([f'- "{k}"' for k in ALLOWED_LABELS])
41
 
42
  INSTRUCTION = """You are a logical fallacy detection assistant.
@@ -60,7 +75,7 @@ Return ONLY ONE valid JSON object with this schema:
60
 
61
  Hard rules:
62
  - Output ONLY the JSON object. No markdown. No extra text.
63
- - Produce exactly ONE JSON object, then STOP. Do NOT repeat the input. Do NOT create new examples.
64
  - evidence_quotes MUST be exact substrings from the input text.
65
  - If has_fallacy=false:
66
  - fallacies MUST be []
@@ -69,19 +84,19 @@ Hard rules:
69
  - If has_fallacy=true:
70
  - fallacies MUST contain at least 1 item
71
  - EACH fallacies[i].type MUST be one of the allowed labels (NOT a synonym)
72
- - overall_explanation may summarize the detected fallacy(ies).
73
  """
74
 
75
- def build_prompt(tokenizer: AutoTokenizer, text: str) -> str:
 
 
76
  instruction = INSTRUCTION.format(labels_list=labels_block_compact())
77
- messages = [
78
- {"role": "system", "content": "You are a careful JSON-only assistant."},
79
  {"role": "user", "content": f"{instruction}\n\nTEXT:\n{text}\n\nJSON:"},
80
  ]
81
- return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
82
 
83
  # ----------------------------
84
- # JSON extraction + stop guard
85
  # ----------------------------
86
  def extract_first_json_obj(s: str) -> Optional[Dict[str, Any]]:
87
  start = s.find("{")
@@ -90,7 +105,7 @@ def extract_first_json_obj(s: str) -> Optional[Dict[str, Any]]:
90
  end = s.rfind("}")
91
  if end == -1 or end <= start:
92
  return None
93
- cand = s[start:end + 1].strip()
94
  try:
95
  return json.loads(cand)
96
  except Exception:
@@ -124,106 +139,100 @@ def stop_at_complete_json(text: str) -> Optional[str]:
124
  elif ch == "}":
125
  depth -= 1
126
  if depth == 0:
127
- return text[start:i + 1]
128
  return None
129
 
130
  # ----------------------------
131
- # Load model (tries CPU 8-bit if available; falls back to FP32)
132
  # ----------------------------
133
- def load_model() -> tuple[AutoTokenizer, Any]:
134
- tokenizer = AutoTokenizer.from_pretrained(ADAPTER_REPO, use_fast=True)
135
- if tokenizer.pad_token is None:
136
- tokenizer.pad_token = tokenizer.eos_token
137
-
138
- # Try CPU quantization to reduce RAM; fallback if not supported
139
- try:
140
- import bitsandbytes # noqa: F401
141
 
142
- base = AutoModelForCausalLM.from_pretrained(
143
- BASE_MODEL,
144
- device_map="auto", # CPU on Spaces free
145
- load_in_8bit=True,
146
- )
147
- except Exception:
148
- base = AutoModelForCausalLM.from_pretrained(
149
- BASE_MODEL,
150
- device_map=None,
151
- torch_dtype=torch.float32,
152
- )
153
 
154
- model = PeftModel.from_pretrained(base, ADAPTER_REPO)
155
- model.eval()
156
- return tokenizer, model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
 
158
- # Keep them global, but load on startup for clearer errors/logs
159
- tokenizer: Optional[AutoTokenizer] = None
160
- model: Optional[Any] = None
 
161
 
162
  # ----------------------------
163
  # FastAPI
164
  # ----------------------------
165
- app = FastAPI(title="FADES Fallacy Detector")
166
 
167
  class AnalyzeRequest(BaseModel):
168
  text: str
169
- max_new_tokens: int = MAX_NEW_TOKENS
 
 
170
 
171
  @app.get("/health")
172
  def health():
173
  return {
174
  "ok": True,
175
- "device": DEVICE,
176
- "base_model": BASE_MODEL,
177
- "adapter_repo": ADAPTER_REPO,
178
- "model_loaded": model is not None and tokenizer is not None,
 
 
 
 
179
  }
180
 
181
  @app.on_event("startup")
182
  def _startup():
183
- global tokenizer, model
184
- tokenizer, model = load_model()
 
185
 
186
  @lru_cache(maxsize=256)
187
- def _cached_generate(text: str, max_new_tokens: int) -> Dict[str, Any]:
188
- assert tokenizer is not None and model is not None
189
 
190
- prompt = build_prompt(tokenizer, text)
191
 
192
- inputs = tokenizer(
193
- prompt,
194
- return_tensors="pt",
195
- truncation=True,
196
- max_length=MAX_INPUT_TOKENS,
 
197
  )
198
 
199
- # move to device if CUDA exists
200
- if DEVICE == "cuda":
201
- inputs = {k: v.to(model.device) for k, v in inputs.items()}
202
-
203
- with torch.no_grad():
204
- out = model.generate(
205
- **inputs,
206
- max_new_tokens=int(max_new_tokens),
207
- do_sample=False,
208
- temperature=0.0,
209
- use_cache=True,
210
- eos_token_id=tokenizer.eos_token_id,
211
- pad_token_id=tokenizer.pad_token_id,
212
- )
213
-
214
- decoded = tokenizer.decode(out[0], skip_special_tokens=True)
215
 
216
- cut = stop_at_complete_json(decoded)
217
- decoded_cut = cut if cut is not None else decoded
218
 
219
- obj = extract_first_json_obj(decoded_cut)
220
  if obj is None:
221
- return {"ok": False, "raw": decoded_cut}
222
 
223
  return {"ok": True, "result": obj}
224
 
225
  @app.post("/analyze")
226
  async def analyze(req: AnalyzeRequest) -> Dict[str, Any]:
227
- # One-at-a-time generation on CPU (prevents stalls/OOM)
228
  async with GEN_LOCK:
229
- return _cached_generate(req.text, int(req.max_new_tokens))
 
1
  # main.py
2
  import os
3
  import json
4
+ import time
5
+ import asyncio
6
  from typing import Any, Dict, Optional
7
  from functools import lru_cache
 
8
 
 
9
  from fastapi import FastAPI
10
  from pydantic import BaseModel
11
+ from huggingface_hub import hf_hub_download
12
+ from llama_cpp import Llama
13
 
14
  # ----------------------------
15
+ # Config
16
  # ----------------------------
17
+ GGUF_REPO_ID = os.getenv("GGUF_REPO_ID", "maxime-antoine-dev/fades-mistral-v02-gguf")
18
+ GGUF_FILENAME = os.getenv("GGUF_FILENAME", "mistral_v02_fades.Q4_K_M.gguf")
19
 
20
+ # llama.cpp params (CPU Space)
21
+ N_CTX = int(os.getenv("N_CTX", "2048"))
22
+ N_THREADS = int(os.getenv("N_THREADS", str(max(1, (os.cpu_count() or 2) - 1))))
23
+ N_BATCH = int(os.getenv("N_BATCH", "256"))
24
 
25
+ # generation defaults
26
+ MAX_NEW_TOKENS_DEFAULT = int(os.getenv("MAX_NEW_TOKENS", "180"))
27
+ TEMPERATURE_DEFAULT = float(os.getenv("TEMPERATURE", "0.0"))
28
+ TOP_P_DEFAULT = float(os.getenv("TOP_P", "0.95"))
29
 
30
+ # One request at a time on CPU (prevents stalls / extreme latency)
31
  GEN_LOCK = asyncio.Lock()
32
 
33
  # ----------------------------
34
+ # Prompt (aligned with your training target)
35
  # ----------------------------
36
  ALLOWED_LABELS = [
37
+ "none",
38
+ "faulty generalization",
39
+ "false causality",
40
+ "circular reasoning",
41
+ "ad populum",
42
+ "ad hominem",
43
+ "fallacy of logic",
44
+ "appeal to emotion",
45
+ "false dilemma",
46
+ "equivocation",
47
+ "fallacy of extension",
48
+ "fallacy of relevance",
49
+ "fallacy of credibility",
50
+ "miscellaneous",
51
+ "intentional",
52
  ]
53
 
54
  def labels_block_compact() -> str:
 
55
  return "\n".join([f'- "{k}"' for k in ALLOWED_LABELS])
56
 
57
  INSTRUCTION = """You are a logical fallacy detection assistant.
 
75
 
76
  Hard rules:
77
  - Output ONLY the JSON object. No markdown. No extra text.
78
+ - Produce exactly ONE JSON object, then STOP.
79
  - evidence_quotes MUST be exact substrings from the input text.
80
  - If has_fallacy=false:
81
  - fallacies MUST be []
 
84
  - If has_fallacy=true:
85
  - fallacies MUST contain at least 1 item
86
  - EACH fallacies[i].type MUST be one of the allowed labels (NOT a synonym)
 
87
  """
88
 
89
+ SYSTEM_PROMPT = "You are a careful JSON-only assistant. Output only JSON."
90
+
91
+ def build_messages(text: str) -> list[dict]:
92
  instruction = INSTRUCTION.format(labels_list=labels_block_compact())
93
+ return [
94
+ {"role": "system", "content": SYSTEM_PROMPT},
95
  {"role": "user", "content": f"{instruction}\n\nTEXT:\n{text}\n\nJSON:"},
96
  ]
 
97
 
98
  # ----------------------------
99
+ # Robust JSON extraction
100
  # ----------------------------
101
  def extract_first_json_obj(s: str) -> Optional[Dict[str, Any]]:
102
  start = s.find("{")
 
105
  end = s.rfind("}")
106
  if end == -1 or end <= start:
107
  return None
108
+ cand = s[start : end + 1].strip()
109
  try:
110
  return json.loads(cand)
111
  except Exception:
 
139
  elif ch == "}":
140
  depth -= 1
141
  if depth == 0:
142
+ return text[start : i + 1]
143
  return None
144
 
145
  # ----------------------------
146
+ # Load GGUF model (global)
147
  # ----------------------------
148
+ llm: Optional[Llama] = None
149
+ model_path: Optional[str] = None
 
 
 
 
 
 
150
 
151
+ def load_llama() -> tuple[str, Llama]:
152
+ global model_path
 
 
 
 
 
 
 
 
 
153
 
154
+ t0 = time.time()
155
+ mp = hf_hub_download(
156
+ repo_id=GGUF_REPO_ID,
157
+ filename=GGUF_FILENAME,
158
+ token=os.getenv("HF_TOKEN"), # optional (only if repo is private)
159
+ )
160
+ t1 = time.time()
161
+
162
+ # CPU Space -> n_gpu_layers = 0
163
+ llama = Llama(
164
+ model_path=mp,
165
+ n_ctx=N_CTX,
166
+ n_threads=N_THREADS,
167
+ n_batch=N_BATCH,
168
+ n_gpu_layers=0,
169
+ verbose=True,
170
+ )
171
+ t2 = time.time()
172
 
173
+ print(f"✅ GGUF downloaded: {mp} ({t1 - t0:.1f}s)")
174
+ print(f"✅ Model loaded: ({t2 - t1:.1f}s) n_ctx={N_CTX} threads={N_THREADS} batch={N_BATCH}")
175
+ model_path = mp
176
+ return mp, llama
177
 
178
  # ----------------------------
179
  # FastAPI
180
  # ----------------------------
181
+ app = FastAPI(title="FADES Fallacy Detector (GGUF / llama.cpp)")
182
 
183
  class AnalyzeRequest(BaseModel):
184
  text: str
185
+ max_new_tokens: int = MAX_NEW_TOKENS_DEFAULT
186
+ temperature: float = TEMPERATURE_DEFAULT
187
+ top_p: float = TOP_P_DEFAULT
188
 
189
  @app.get("/health")
190
  def health():
191
  return {
192
  "ok": True,
193
+ "engine": "llama.cpp (llama-cpp-python)",
194
+ "gguf_repo": GGUF_REPO_ID,
195
+ "gguf_filename": GGUF_FILENAME,
196
+ "model_loaded": llm is not None,
197
+ "model_path": model_path,
198
+ "n_ctx": N_CTX,
199
+ "n_threads": N_THREADS,
200
+ "n_batch": N_BATCH,
201
  }
202
 
203
  @app.on_event("startup")
204
  def _startup():
205
+ global llm
206
+ _, llm_loaded = load_llama()
207
+ llm = llm_loaded
208
 
209
  @lru_cache(maxsize=256)
210
+ def _cached_generate(text: str, max_new_tokens: int, temperature: float, top_p: float) -> Dict[str, Any]:
211
+ assert llm is not None
212
 
213
+ messages = build_messages(text)
214
 
215
+ out = llm.create_chat_completion(
216
+ messages=messages,
217
+ max_tokens=int(max_new_tokens),
218
+ temperature=float(temperature),
219
+ top_p=float(top_p),
220
+ stream=False,
221
  )
222
 
223
+ raw = out["choices"][0]["message"]["content"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
 
225
+ cut = stop_at_complete_json(raw)
226
+ raw_cut = cut if cut is not None else raw
227
 
228
+ obj = extract_first_json_obj(raw_cut)
229
  if obj is None:
230
+ return {"ok": False, "raw": raw_cut}
231
 
232
  return {"ok": True, "result": obj}
233
 
234
  @app.post("/analyze")
235
  async def analyze(req: AnalyzeRequest) -> Dict[str, Any]:
236
+ # CPU: serialize requests to keep stable latency
237
  async with GEN_LOCK:
238
+ return _cached_generate(req.text, int(req.max_new_tokens), float(req.temperature), float(req.top_p))
requirements.txt CHANGED
@@ -1,9 +1,4 @@
1
- fastapi
2
- uvicorn
3
- torch
4
- transformers
5
- peft
6
- accelerate
7
- sentencepiece
8
- safetensors
9
- huggingface_hub
 
1
+ fastapi>=0.110
2
+ uvicorn[standard]>=0.27
3
+ huggingface_hub>=0.23
4
+ llama-cpp-python>=0.2.90