Ryanfafa committed on
Commit
188f4e4
·
verified ·
1 Parent(s): 8c5b9b8

Update rag_engine.py

Browse files
Files changed (1) hide show
  1. rag_engine.py +310 -129
rag_engine.py CHANGED
@@ -1,58 +1,79 @@
1
  """
2
- RAG Engine - Memory optimized for HuggingFace free tier
3
- Embeddings : all-MiniLM-L6-v2 (CPU, ~90MB)
4
- Vector DB : ChromaDB (local)
5
- LLM : HuggingFace Router API with correct provider suffixes
6
  """
7
 
8
  import os
9
  import re
 
10
  import json
11
  import time
12
  import tempfile
13
  import requests
14
- from typing import Tuple, List
 
 
15
 
16
  from chromadb.config import Settings
17
  from langchain.text_splitter import RecursiveCharacterTextSplitter
18
  from langchain_community.vectorstores import Chroma
19
- from langchain_community.document_loaders import PyPDFLoader, TextLoader
20
  from langchain_community.embeddings import HuggingFaceEmbeddings
 
 
 
21
  import monitor
22
 
 
 
 
 
23
  EMBED_MODEL = "all-MiniLM-L6-v2"
24
  CHUNK_SIZE = 600
25
  CHUNK_OVERLAP = 100
26
- TOP_K = 3
27
- COLLECTION_NAME = "docmind_collection"
28
  CHROMA_DIR = "/tmp/chroma_db"
29
  HF_API_URL = "https://router.huggingface.co/v1/chat/completions"
 
 
 
 
 
 
 
 
30
 
31
- # Correct provider suffixes verified from HuggingFace docs (2025)
32
- # Format: "model-id:provider"
33
- # cerebras = fast free GPU, hf-inference = HF own CPU servers
34
  CANDIDATE_MODELS = [
35
- "meta-llama/Llama-3.1-8B-Instruct:cerebras", # fast, free, no reasoning leak
36
- "meta-llama/Llama-3.3-70B-Instruct:cerebras", # larger, still free on cerebras
37
- "mistralai/Mistral-7B-Instruct-v0.3:fireworks-ai", # fireworks free tier
38
- "HuggingFaceTB/SmolLM3-3B:hf-inference", # HF's own server, always available
39
  ]
40
 
41
 
 
 
 
 
42
  class RAGEngine:
43
  def __init__(self):
44
- self._embeddings = None
45
- self._vectorstore = None
46
- self._splitter = RecursiveCharacterTextSplitter(
47
  chunk_size=CHUNK_SIZE,
48
  chunk_overlap=CHUNK_OVERLAP,
49
  separators=["\n\n", "\n", ". ", " ", ""],
50
  )
 
 
 
51
  monitor.log_startup()
52
 
53
  @property
54
  def embeddings(self):
55
  if self._embeddings is None:
 
56
  self._embeddings = HuggingFaceEmbeddings(
57
  model_name=EMBED_MODEL,
58
  model_kwargs={"device": "cpu"},
@@ -60,35 +81,198 @@ class RAGEngine:
60
  )
61
  return self._embeddings
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  def ingest_file(self, uploaded_file) -> int:
64
- t0 = time.time()
65
- suffix = get_suffix(uploaded_file.name)
66
- error = ""
67
- chunks = 0
 
 
 
 
 
 
 
68
  try:
69
- with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
70
- tmp.write(uploaded_file.read())
71
- tmp_path = tmp.name
72
- chunks = self.ingest_path(tmp_path, uploaded_file.name)
 
 
 
 
 
 
 
 
73
  except Exception as e:
74
  error = str(e)
 
75
  raise
76
  finally:
77
- monitor.log_ingestion(
78
- filename = uploaded_file.name,
79
- chunk_count = chunks,
80
- latency_ms = (time.time() - t0) * 1000,
81
- error = error,
82
- )
83
  return chunks
84
 
85
  def ingest_path(self, path: str, name: str = "") -> int:
86
- suffix = get_suffix(name or path)
87
- loader = PyPDFLoader(path) if suffix == ".pdf" else TextLoader(path, encoding="utf-8")
88
- raw_docs = loader.load()
89
- for doc in raw_docs:
90
- doc.metadata["source"] = name or os.path.basename(path)
91
- chunks = self._splitter.split_documents(raw_docs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  if self._vectorstore is not None:
93
  try:
94
  self._vectorstore._client.reset()
@@ -96,160 +280,157 @@ class RAGEngine:
96
  pass
97
  self._vectorstore = None
98
  self._vectorstore = Chroma.from_documents(
99
- documents = chunks,
100
- embedding = self.embeddings,
101
- collection_name = COLLECTION_NAME,
102
- persist_directory = CHROMA_DIR,
103
- client_settings = Settings(anonymized_telemetry=False),
104
  )
 
105
  return len(chunks)
106
 
 
 
107
  def query(self, question: str) -> Tuple[str, List[str]]:
108
  if self._vectorstore is None:
109
  return "Please upload a document first.", []
110
 
111
- t0 = time.time()
112
- error = ""
113
- answer = ""
114
- sources = []
115
- model_used = ""
116
 
117
  try:
118
  retriever = self._vectorstore.as_retriever(
119
  search_type="mmr",
120
- search_kwargs={"k": TOP_K, "fetch_k": TOP_K * 2},
121
  )
122
- docs = retriever.invoke(question)
123
  context = "\n\n---\n\n".join(
124
- "[Chunk {}]\n{}".format(i + 1, d.page_content) for i, d in enumerate(docs)
 
125
  )
126
- sources = list({d.metadata.get("source", "Document") for d in docs})
127
  answer, model_used = self._generate(question, context)
 
 
128
  except Exception as e:
129
  error = str(e)
130
- answer = "Error: " + error
 
131
  finally:
132
- monitor.log_query(
133
- question = question,
134
- answer = answer,
135
- sources = sources,
136
- latency_ms = (time.time() - t0) * 1000,
137
- model_used = model_used,
138
- chunk_count = TOP_K,
139
- error = error,
140
- )
141
 
142
  return answer, sources
143
 
 
 
144
  def _generate(self, question: str, context: str) -> Tuple[str, str]:
145
  hf_token = os.environ.get("HF_TOKEN", "")
146
  if not hf_token:
147
  return (
148
  "HF_TOKEN not set. Add it as a Secret in Space Settings.\n\n"
149
- "Best matching excerpt:\n\n" + extract_best(question, context),
150
  "none"
151
  )
152
 
 
 
 
 
 
 
153
  system_prompt = (
154
- "You are DocMind, a document Q&A assistant. "
155
- "Answer the question using only the document context. "
156
- "Be short and direct. No preamble. No reasoning. Just answer."
 
 
 
157
  )
158
- user_message = (
159
- "Context:\n" + context +
160
- "\n\n---\nQuestion: " + question +
161
- "\nAnswer:"
162
- )
163
- headers = {
164
- "Authorization": "Bearer " + hf_token,
165
- "Content-Type": "application/json",
166
- }
167
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  last_error = ""
 
169
  for model_id in CANDIDATE_MODELS:
170
  try:
171
- payload = {
172
- "model": model_id,
173
- "messages": [
174
- {"role": "system", "content": system_prompt},
175
- {"role": "user", "content": user_message},
176
- ],
177
- "max_tokens": 400,
178
- "temperature": 0.05,
179
- "stream": False,
180
- }
181
  resp = requests.post(
182
  HF_API_URL,
183
  headers=headers,
184
- data=json.dumps(payload),
 
 
 
 
 
 
185
  timeout=60,
186
  )
187
  if resp.status_code == 200:
188
  raw = resp.json()["choices"][0]["message"]["content"].strip()
189
- answer = strip_thinking(raw)
190
  if answer:
191
  return answer, model_id
192
  else:
193
- last_error = "Model {} -> {}: {}".format(
194
- model_id, resp.status_code, resp.text[:200]
195
- )
196
- print("[DocMind] " + last_error)
197
  except Exception as e:
198
  last_error = str(e)
199
- print("[DocMind] Exception on {}: {}".format(model_id, last_error))
200
  continue
201
 
202
- fallback = (
203
  "AI unavailable. Most relevant excerpt:\n\n"
204
- + extract_best(question, context)
205
- + "\n\n(Error: " + last_error + ")"
 
206
  )
207
- return fallback, "fallback"
208
 
209
 
210
def strip_thinking(text: str) -> str:
    """Remove model "thinking" from a raw completion.

    ``<think>...</think>`` spans are deleted first. Leading lines that are
    blank or begin with a known reasoning opener are then skipped and
    everything from the first other line onward is kept. If that leaves
    nothing, or more than 1500 characters, the last blank-line-separated
    paragraph is returned instead when it is shorter than 800 characters.
    """
    text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()
    openers = (
        "okay", "ok,", "alright", "let me", "let's", "i need", "i will",
        "i'll", "first,", "so,", "the user", "looking at", "going through",
        "based on the chunk", "parsing", "to answer", "in order to",
    )

    rows = text.split("\n")
    result = ""
    for pos, row in enumerate(rows):
        probe = row.strip().lower()
        if probe and not probe.startswith(openers):
            # First line that is neither blank nor a reasoning opener:
            # keep it and everything after it.
            result = "\n".join(rows[pos:]).strip()
            break

    if not result or len(result) > 1500:
        paragraphs = [part.strip() for part in text.split("\n\n") if part.strip()]
        if paragraphs and len(paragraphs[-1]) < 800:
            return paragraphs[-1]
    return result or text
237
 
238
 
239
def extract_best(question: str, context: str) -> str:
    """Return the context chunk sharing the most 4+-letter words with the question.

    Chunks are delimited by lines consisting solely of ``---``. Splitting on
    the bare substring would also cut a chunk apart wherever its own text
    contains ``---`` (for example a markdown table rule such as
    ``|---|---|``), so only stand-alone separator lines are honoured.

    Args:
        question: The user's question.
        context: Retrieved chunks joined by ``---`` separator lines.

    Returns:
        The best chunk, truncated to 600 characters plus ``...`` when longer,
        or ``"No relevant content found."`` when no chunk shares a keyword.
    """
    keywords = set(re.findall(r'\b\w{4,}\b', question.lower()))
    best_chunk, best_score = "", 0
    for chunk in re.split(r'^[ \t]*---[ \t]*$', context, flags=re.MULTILINE):
        score = len(keywords & set(re.findall(r'\b\w{4,}\b', chunk.lower())))
        if score > best_score:
            best_score, best_chunk = score, chunk.strip()
    if not best_chunk:
        return "No relevant content found."
    return best_chunk[:600] + ("..." if len(best_chunk) > 600 else "")
252
-
253
-
254
def get_suffix(name: str) -> str:
    """Return the lower-cased extension of *name*, or ``.txt`` if it has none."""
    _stem, extension = os.path.splitext(name)
    extension = extension.lower()
    return extension if extension else ".txt"
 
1
  """
2
+ rag_engine.py Multimodal RAG Engine with Conversation Memory
3
+ Supports: PDF, TXT, DOCX, CSV, XLSX, Images (JPG/PNG/WEBP)
4
+ Memory: sliding window of last 6 exchanges
 
5
  """
6
 
7
  import os
8
  import re
9
+ import io
10
  import json
11
  import time
12
  import tempfile
13
  import requests
14
+ import logging
15
+ from pathlib import Path
16
+ from typing import Tuple, List, Optional
17
 
18
  from chromadb.config import Settings
19
  from langchain.text_splitter import RecursiveCharacterTextSplitter
20
  from langchain_community.vectorstores import Chroma
 
21
  from langchain_community.embeddings import HuggingFaceEmbeddings
22
+ from langchain_community.document_loaders import PyPDFLoader, TextLoader
23
+ from langchain.schema import Document
24
+
25
  import monitor
26
 
27
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ── Constants ────────────────────────────────────────────────────────────────

# Embedding / chunking
EMBED_MODEL = "all-MiniLM-L6-v2"
CHUNK_SIZE = 600
CHUNK_OVERLAP = 100

# Retrieval / vector store
TOP_K = 4
COLLECTION_NAME = "docmind_multimodal"
CHROMA_DIR = "/tmp/chroma_db"

# Chat completion endpoint
HF_API_URL = "https://router.huggingface.co/v1/chat/completions"

# Conversation memory: number of past Q&A pairs to keep
MEMORY_WINDOW = 6

# File extensions accepted for ingestion
SUPPORTED_EXTENSIONS = {
    ".pdf",
    ".txt",
    ".docx",
    ".doc",
    ".csv",
    ".xlsx",
    ".xls",
    ".jpg",
    ".jpeg",
    ".png",
    ".webp",
}

# Chat models, written as "model-id:provider"; tried in this order
CANDIDATE_MODELS = [
    "meta-llama/Llama-3.1-8B-Instruct:cerebras",
    "meta-llama/Llama-3.3-70B-Instruct:cerebras",
    "mistralai/Mistral-7B-Instruct-v0.3:fireworks-ai",
    "HuggingFaceTB/SmolLM3-3B:hf-inference",
]
53
 
54
 
55
def get_suffix(name: str) -> str:
    """Return the lower-cased final extension of *name*, defaulting to ``.txt``."""
    extension = Path(name).suffix.lower()
    return extension if extension else ".txt"
57
+
58
+
59
  class RAGEngine:
60
  def __init__(self):
61
+ self._embeddings: Optional[HuggingFaceEmbeddings] = None
62
+ self._vectorstore: Optional[Chroma] = None
63
+ self._splitter = RecursiveCharacterTextSplitter(
64
  chunk_size=CHUNK_SIZE,
65
  chunk_overlap=CHUNK_OVERLAP,
66
  separators=["\n\n", "\n", ". ", " ", ""],
67
  )
68
+ self._memory: List[dict] = []
69
+ self._doc_name: str = ""
70
+ self._doc_type: str = ""
71
  monitor.log_startup()
72
 
73
  @property
74
  def embeddings(self):
75
  if self._embeddings is None:
76
+ logger.info("Loading embedding model...")
77
  self._embeddings = HuggingFaceEmbeddings(
78
  model_name=EMBED_MODEL,
79
  model_kwargs={"device": "cpu"},
 
81
  )
82
  return self._embeddings
83
 
84
+ # ── Memory ───────────────────────────────────────────────────────────────
85
+
86
+ def clear_memory(self):
87
+ self._memory = []
88
+
89
def add_to_memory(self, question: str, answer: str):
    """Record one exchange and trim the history to the sliding window.

    The question is stored as a ``user`` message and the answer as an
    ``assistant`` message; only the newest ``MEMORY_WINDOW * 2`` messages
    are kept.
    """
    self._memory.extend((
        {"role": "user", "content": question},
        {"role": "assistant", "content": answer},
    ))
    # Two messages per exchange, hence twice the window size.
    message_cap = MEMORY_WINDOW * 2
    if len(self._memory) > message_cap:
        self._memory = self._memory[-message_cap:]
95
+
96
+ def get_memory_messages(self) -> List[dict]:
97
+ return self._memory.copy()
98
+
99
+ def get_memory_count(self) -> int:
100
+ return len(self._memory) // 2
101
+
102
+ # ── Ingestion ────────────────────────────────────────────────────────────
103
+
104
  def ingest_file(self, uploaded_file) -> int:
105
+ """Accept FastAPI UploadFile or Streamlit UploadedFile."""
106
+ t0 = time.time()
107
+ filename = getattr(uploaded_file, "name", None) or getattr(uploaded_file, "filename", "file")
108
+ suffix = get_suffix(filename)
109
+ error = ""
110
+ chunks = 0
111
+
112
+ if suffix not in SUPPORTED_EXTENSIONS:
113
+ raise ValueError(
114
+ f"Unsupported: {suffix}. Supported: {', '.join(sorted(SUPPORTED_EXTENSIONS))}"
115
+ )
116
  try:
117
+ if hasattr(uploaded_file, "read"):
118
+ data = uploaded_file.read()
119
+ if hasattr(uploaded_file, "seek"):
120
+ uploaded_file.seek(0)
121
+ else:
122
+ data = uploaded_file.file.read()
123
+
124
+ docs = self._route(data, filename, suffix)
125
+ chunks = self._index(docs, filename)
126
+ self._doc_name = filename
127
+ self._doc_type = suffix
128
+ self.clear_memory()
129
  except Exception as e:
130
  error = str(e)
131
+ logger.error(f"Ingestion error: {e}")
132
  raise
133
  finally:
134
+ monitor.log_ingestion(filename, chunks, (time.time()-t0)*1000, error)
 
 
 
 
 
135
  return chunks
136
 
137
  def ingest_path(self, path: str, name: str = "") -> int:
138
+ filename = name or Path(path).name
139
+ suffix = get_suffix(filename)
140
+ with open(path, "rb") as f:
141
+ data = f.read()
142
+ docs = self._route(data, filename, suffix)
143
+ chunks = self._index(docs, filename)
144
+ self._doc_name = filename
145
+ self._doc_type = suffix
146
+ self.clear_memory()
147
+ return chunks
148
+
149
+ def _route(self, data: bytes, filename: str, suffix: str) -> List[Document]:
150
+ if suffix == ".pdf":
151
+ return self._load_pdf(data, filename)
152
+ elif suffix == ".txt":
153
+ return self._load_text(data, filename)
154
+ elif suffix in {".docx", ".doc"}:
155
+ return self._load_docx(data, filename)
156
+ elif suffix == ".csv":
157
+ return self._load_csv(data, filename)
158
+ elif suffix in {".xlsx", ".xls"}:
159
+ return self._load_excel(data, filename)
160
+ elif suffix in {".jpg", ".jpeg", ".png", ".webp"}:
161
+ return self._load_image(data, filename)
162
+ raise ValueError(f"No loader for {suffix}")
163
+
164
+ # ── Loaders ──────────────────────────────────────────────────────────────
165
+
166
def _load_pdf(self, data: bytes, filename: str) -> List[Document]:
    """Load a PDF from bytes via ``PyPDFLoader``.

    The loader is given a file path, so the bytes are spooled to a
    temporary ``.pdf`` file that is always deleted afterwards. Every
    returned document is tagged with ``source`` and ``type`` metadata.
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as scratch:
        scratch.write(data)
        scratch_path = scratch.name

    try:
        loaded = PyPDFLoader(scratch_path).load()
        for entry in loaded:
            entry.metadata.update(source=filename, type="pdf")
        return loaded
    finally:
        os.unlink(scratch_path)
177
+
178
def _load_text(self, data: bytes, filename: str) -> List[Document]:
    """Wrap plain-text bytes in a single document.

    Bytes are decoded as UTF-8; undecodable bytes become the replacement
    character rather than raising.
    """
    decoded = data.decode("utf-8", errors="replace")
    tags = {"source": filename, "type": "text"}
    return [Document(page_content=decoded, metadata=tags)]
183
+
184
def _load_docx(self, data: bytes, filename: str) -> List[Document]:
    """Extract the text of a Word document into a single document.

    ``docx2txt`` is used when it is installed. When it is missing, the text
    is read straight from the DOCX package (a ZIP archive whose body lives
    in ``word/document.xml``) using only the standard library. Decoding the
    archive bytes as UTF-8, as was done before, would index compressed
    binary noise. That lossy decode is kept only as a last resort for
    payloads that are not a DOCX package at all.

    Args:
        data: Raw file bytes.
        filename: Name recorded as the ``source`` metadata.

    Returns:
        A one-element list holding the extracted text.
    """
    try:
        import docx2txt
    except ImportError:
        docx2txt = None

    if docx2txt is not None:
        # docx2txt is handed a path here, so spool the bytes to a temp file.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".docx") as tmp:
            tmp.write(data)
            tmp_path = tmp.name
        try:
            text = docx2txt.process(tmp_path)
        finally:
            os.unlink(tmp_path)
    else:
        import zipfile
        from xml.etree import ElementTree

        w_ns = "{http://schemas.openxmlformats.org/wordprocessingml/2006/main}"
        try:
            with zipfile.ZipFile(io.BytesIO(data)) as package:
                body_xml = package.read("word/document.xml")
            # One output line per <w:p>, concatenating its <w:t> text runs.
            text = "\n".join(
                "".join(run.text or "" for run in para.iter(w_ns + "t"))
                for para in ElementTree.fromstring(body_xml).iter(w_ns + "p")
            )
        except (zipfile.BadZipFile, KeyError, ElementTree.ParseError) as exc:
            logger.warning(f"Could not parse {filename} as DOCX: {exc}")
            text = data.decode("utf-8", errors="replace")

    return [Document(page_content=text, metadata={"source": filename, "type": "docx"})]
197
+
198
def _load_csv(self, data: bytes, filename: str) -> List[Document]:
    """Turn a CSV file into summary, statistics and row-window documents.

    Produces, in order:
      * a summary (shape, column names, first 10 rows),
      * ``DataFrame.describe(include="all")`` output, when it can be computed,
      * windows of 50 rows each, covering at most the first 500 rows.

    Args:
        data: Raw CSV bytes.
        filename: Name recorded as the ``source`` metadata.

    Returns:
        The generated documents.
    """
    import pandas as pd

    df = pd.read_csv(io.BytesIO(data))
    docs = []

    # str() each label so non-string column names cannot break the join
    # (mirrors what _load_excel does).
    column_names = ", ".join(str(col) for col in df.columns)
    summary = (
        f"File: {filename}\n"
        f"Shape: {df.shape[0]} rows × {df.shape[1]} columns\n"
        f"Columns: {column_names}\n\n"
        f"First 10 rows:\n{df.head(10).to_string(index=False)}"
    )
    docs.append(Document(page_content=summary, metadata={"source": filename, "type": "csv_summary"}))

    try:
        stats = "Statistical summary:\n" + df.describe(include="all").to_string()
    except Exception as exc:
        # Statistics are best-effort; record why they were skipped.
        logger.warning(f"Skipping statistics for {filename}: {exc}")
    else:
        docs.append(Document(page_content=stats, metadata={"source": filename, "type": "csv_stats"}))

    row_limit = min(len(df), 500)
    for start in range(0, row_limit, 50):
        # Clamp the end so the label of the final window matches the rows
        # it actually contains.
        end = min(start + 50, row_limit)
        chunk = f"Rows {start}–{end}:\n{df.iloc[start:end].to_string(index=False)}"
        docs.append(Document(page_content=chunk, metadata={"source": filename, "type": "csv_rows"}))

    return docs
222
+
223
def _load_excel(self, data: bytes, filename: str) -> List[Document]:
    """Turn each sheet of an Excel workbook into one summary document.

    Each document holds the sheet name, its shape, its column names and its
    first 10 rows. The workbook is opened as a context manager so the
    underlying reader is closed once all sheets are parsed.

    Args:
        data: Raw workbook bytes.
        filename: Name recorded as the ``source`` metadata.

    Returns:
        One document per sheet, in workbook order.
    """
    import pandas as pd

    docs = []
    with pd.ExcelFile(io.BytesIO(data)) as xl:
        for sheet in xl.sheet_names:
            df = xl.parse(sheet)
            text = (
                f"Sheet: {sheet} | {df.shape[0]} rows × {df.shape[1]} cols\n"
                f"Columns: {', '.join(str(c) for c in df.columns)}\n\n"
                f"{df.head(10).to_string(index=False)}"
            )
            docs.append(Document(page_content=text, metadata={"source": filename, "type": "excel", "sheet": sheet}))
    return docs
236
+
237
def _load_image(self, data: bytes, filename: str) -> List[Document]:
    """Represent an image as one text document built around its caption.

    The caption comes from ``_caption_image`` and is stored both in the
    page text and in the ``caption`` metadata field.
    """
    caption = self._caption_image(data, filename)
    sections = [
        f"Image file: {filename}\n\n",
        f"AI-generated image description:\n{caption}\n\n",
        "The above description represents the full visual content of this image.",
    ]
    image_doc = Document(
        page_content="".join(sections),
        metadata={"source": filename, "type": "image", "caption": caption},
    )
    return [image_doc]
248
+
249
+ def _caption_image(self, data: bytes, filename: str) -> str:
250
+ hf_token = os.environ.get("HF_TOKEN", "")
251
+ if not hf_token:
252
+ return f"[Image: {filename}] — Add HF_TOKEN secret to enable AI image captioning."
253
+ try:
254
+ import base64
255
+ resp = requests.post(
256
+ "https://api-inference.huggingface.co/models/Salesforce/blip-image-captioning-large",
257
+ headers={"Authorization": f"Bearer {hf_token}"},
258
+ json={"inputs": base64.b64encode(data).decode()},
259
+ timeout=30,
260
+ )
261
+ if resp.status_code == 200:
262
+ result = resp.json()
263
+ if isinstance(result, list) and result:
264
+ caption = result[0].get("generated_text", "")
265
+ if caption:
266
+ logger.info(f"Image caption: {caption[:80]}")
267
+ return caption
268
+ except Exception as e:
269
+ logger.warning(f"Caption failed: {e}")
270
+ return f"[Image: {filename}] — Visual content uploaded (captioning unavailable)"
271
+
272
+ # ── Indexing ─────────────────────────────────────────────────────────────
273
+
274
+ def _index(self, docs: List[Document], filename: str) -> int:
275
+ chunks = self._splitter.split_documents(docs)
276
  if self._vectorstore is not None:
277
  try:
278
  self._vectorstore._client.reset()
 
280
  pass
281
  self._vectorstore = None
282
  self._vectorstore = Chroma.from_documents(
283
+ documents=chunks,
284
+ embedding=self.embeddings,
285
+ collection_name=COLLECTION_NAME,
286
+ persist_directory=CHROMA_DIR,
287
+ client_settings=Settings(anonymized_telemetry=False),
288
  )
289
+ logger.info(f"Indexed {len(chunks)} chunks from {filename}")
290
  return len(chunks)
291
 
292
+ # ── Query ────────────────────────────────────────────────────────────────
293
+
294
  def query(self, question: str) -> Tuple[str, List[str]]:
295
  if self._vectorstore is None:
296
  return "Please upload a document first.", []
297
 
298
+ t0 = time.time()
299
+ error = answer = model_used = ""
300
+ sources = []
 
 
301
 
302
  try:
303
  retriever = self._vectorstore.as_retriever(
304
  search_type="mmr",
305
+ search_kwargs={"k": TOP_K, "fetch_k": TOP_K * 3},
306
  )
307
+ docs = retriever.invoke(question)
308
  context = "\n\n---\n\n".join(
309
+ f"[Chunk {i+1} | {d.metadata.get('type','text')}]\n{d.page_content}"
310
+ for i, d in enumerate(docs)
311
  )
312
+ sources = list({d.metadata.get("source", "Document") for d in docs})
313
  answer, model_used = self._generate(question, context)
314
+ self.add_to_memory(question, answer)
315
+
316
  except Exception as e:
317
  error = str(e)
318
+ answer = f"Error: {error}"
319
+ logger.error(f"Query error: {e}")
320
  finally:
321
+ monitor.log_query(question, answer, sources, (time.time()-t0)*1000, model_used, TOP_K, error)
 
 
 
 
 
 
 
 
322
 
323
  return answer, sources
324
 
325
+ # ── LLM ──────────────────────────────────────────────────────────────────
326
+
327
  def _generate(self, question: str, context: str) -> Tuple[str, str]:
328
  hf_token = os.environ.get("HF_TOKEN", "")
329
  if not hf_token:
330
  return (
331
  "HF_TOKEN not set. Add it as a Secret in Space Settings.\n\n"
332
+ "Best matching excerpt:\n\n" + _extract_best(question, context),
333
  "none"
334
  )
335
 
336
+ doc_type_hint = ""
337
+ if self._doc_type in {".jpg", ".jpeg", ".png", ".webp"}:
338
+ doc_type_hint = "The document is an IMAGE described by an AI caption. Base your answer on the caption."
339
+ elif self._doc_type in {".csv", ".xlsx", ".xls"}:
340
+ doc_type_hint = "The document is tabular data (spreadsheet/CSV). Refer to column names and values precisely."
341
+
342
  system_prompt = (
343
+ f"You are DocMind AI, an expert document analyst built by Ryan Farahani.\n"
344
+ f"You are analyzing: '{self._doc_name}'.\n"
345
+ f"{doc_type_hint}\n"
346
+ "Answer using ONLY the provided document context. "
347
+ "Be concise and precise. No preamble. No reasoning out loud. Just answer.\n"
348
+ "If asked a follow-up question, use the conversation history for context."
349
  )
 
 
 
 
 
 
 
 
 
350
 
351
+ # Build messages with memory
352
+ messages = [{"role": "system", "content": system_prompt}]
353
+ memory = self.get_memory_messages()
354
+
355
+ if memory:
356
+ # Context injection before history
357
+ messages.append({
358
+ "role": "system",
359
+ "content": f"Current document context:\n{context}"
360
+ })
361
+ messages.extend(memory)
362
+ messages.append({"role": "user", "content": question})
363
+ else:
364
+ messages.append({
365
+ "role": "user",
366
+ "content": f"Document context:\n{context}\n\n---\nQuestion: {question}"
367
+ })
368
+
369
+ headers = {"Authorization": f"Bearer {hf_token}", "Content-Type": "application/json"}
370
  last_error = ""
371
+
372
  for model_id in CANDIDATE_MODELS:
373
  try:
 
 
 
 
 
 
 
 
 
 
374
  resp = requests.post(
375
  HF_API_URL,
376
  headers=headers,
377
+ data=json.dumps({
378
+ "model": model_id,
379
+ "messages": messages,
380
+ "max_tokens": 500,
381
+ "temperature": 0.1,
382
+ "stream": False,
383
+ }),
384
  timeout=60,
385
  )
386
  if resp.status_code == 200:
387
  raw = resp.json()["choices"][0]["message"]["content"].strip()
388
+ answer = _strip_thinking(raw)
389
  if answer:
390
  return answer, model_id
391
  else:
392
+ last_error = f"{model_id} {resp.status_code}: {resp.text[:150]}"
393
+ logger.warning(last_error)
 
 
394
  except Exception as e:
395
  last_error = str(e)
396
+ logger.warning(f"Exception on {model_id}: {e}")
397
  continue
398
 
399
+ return (
400
  "AI unavailable. Most relevant excerpt:\n\n"
401
+ + _extract_best(question, context)
402
+ + f"\n\n(Error: {last_error})",
403
+ "fallback"
404
  )
 
405
 
406
 
407
+ # ── Helpers ──────────────────────────────────────────────────────────────────
408
+
409
+ def _strip_thinking(text: str) -> str:
410
  text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()
411
+ starters = [
412
  "okay", "ok,", "alright", "let me", "let's", "i need", "i will",
413
  "i'll", "first,", "so,", "the user", "looking at", "going through",
414
  "based on the chunk", "parsing", "to answer", "in order to",
415
  ]
416
+ lines = text.split("\n")
417
+ clean, found = [], False
 
418
  for line in lines:
419
+ lower = line.strip().lower()
420
+ if not found:
421
+ if line.strip() and not any(lower.startswith(p) for p in starters):
422
+ found = True
 
423
  clean.append(line)
424
  else:
425
  clean.append(line)
426
+ return "\n".join(clean).strip() or text
 
 
 
 
 
 
 
427
 
428
 
429
+ def _extract_best(question: str, context: str) -> str:
430
  keywords = set(re.findall(r'\b\w{4,}\b', question.lower()))
431
+ best, score = "", 0
 
432
  for chunk in context.split("---"):
433
+ s = len(keywords & set(re.findall(r'\b\w{4,}\b', chunk.lower())))
434
+ if s > score:
435
+ score, best = s, chunk.strip()
436
+ return (best[:600] + "...") if len(best) > 600 else best or "No relevant content found."