ShanenThomas commited on
Commit
685d934
·
verified ·
1 Parent(s): 3459268

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +126 -88
app.py CHANGED
@@ -1,5 +1,4 @@
1
  import os
2
- import faiss
3
  import numpy as np
4
  import gradio as gr
5
  from typing import List, Tuple
@@ -7,69 +6,96 @@ from pypdf import PdfReader
7
  from sentence_transformers import SentenceTransformer
8
  from huggingface_hub import InferenceClient
9
 
10
- # ==============================
11
  # Config
12
- # ==============================
13
- GEN_MODEL = "mistralai/Mistral-7B-Instruct-v0.2"
 
 
14
  HF_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN") # set in Space Secrets
15
  EMB_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
16
  CHUNK_SIZE = 900
17
  CHUNK_OVERLAP = 150
18
  TOP_K = 4
19
 
20
- # ==============================
21
- # Globals (lifetime of the Space)
22
- # ==============================
 
 
 
 
 
 
 
 
 
23
  emb = SentenceTransformer(EMB_MODEL_NAME)
24
- index = None # FAISS index (inner product / cosine)
25
- doc_chunks: List[str] = [] # text chunks
26
- doc_meta: List[dict] = [] # {"file": "..."}
 
27
  client = InferenceClient(model=GEN_MODEL, token=HF_TOKEN)
28
 
29
- # ==============================
 
 
 
 
 
 
30
  # Helpers
31
- # ==============================
 
 
 
 
 
32
  def _chunk_text(text: str, size: int, overlap: int) -> List[str]:
33
- chunks = []
34
- start = 0
35
- n = len(text)
36
- step = size - overlap
37
- while start < n:
38
- end = min(start + size, n)
39
- chunks.append(text[start:end])
40
- start += step
41
- return [c.strip() for c in chunks if c.strip()]
42
 
43
  def _embed(texts: List[str]) -> np.ndarray:
44
- # 384-d for MiniLM; normalize for cosine/IP search
45
  X = emb.encode(texts, convert_to_numpy=True, normalize_embeddings=True)
46
  return np.asarray(X, dtype=np.float32)
47
 
48
  def _ensure_index(dim: int):
49
- global index
50
- index = faiss.IndexFlatIP(dim) # cosine via normalized vectors
51
-
52
- def _extract_text_from_pdf(path: str) -> str:
53
- reader = PdfReader(path)
54
- pages = []
55
- for p in reader.pages:
56
- t = p.extract_text() or ""
57
- pages.append(t)
58
- return "\n".join(pages)
59
-
60
- # ==============================
 
 
 
 
 
 
 
 
 
 
 
61
  # Build index
62
- # ==============================
63
  def build_from_pdfs(files) -> str:
64
- global index, doc_chunks, doc_meta
65
  doc_chunks, doc_meta = [], []
66
 
67
- # 1) read PDFs → 2) chunk → collect
68
  for f in files:
69
- try:
70
- text = _extract_text_from_pdf(f.name)
71
- except Exception as e:
72
- return f"Failed to read {os.path.basename(f.name)}: {e}"
73
  chunks = _chunk_text(text, CHUNK_SIZE, CHUNK_OVERLAP)
74
  for c in chunks:
75
  doc_chunks.append(c)
@@ -78,80 +104,92 @@ def build_from_pdfs(files) -> str:
78
  if not doc_chunks:
79
  return "No text extracted. Check your PDFs."
80
 
81
- # 3) embeddings → FAISS
82
  E = _embed(doc_chunks)
83
  _ensure_index(E.shape[1])
84
- index.add(E)
85
-
86
  return f"Indexed {len(doc_chunks)} chunks from {len(files)} file(s)."
87
 
88
- # ==============================
89
  # Retrieval + Generation
90
- # ==============================
91
  def _retrieve(query: str, k: int = TOP_K) -> Tuple[List[int], List[str]]:
92
- qv = _embed([query]) # shape (1, d)
93
- sims, idxs = index.search(qv, k) # inner product similarity
94
- ids = idxs[0].tolist()
95
- # Filter out -1 (in case FAISS returns for empty)
96
- ids = [i for i in ids if i >= 0]
97
  return ids, [doc_chunks[i] for i in ids]
98
 
99
- SYSTEM_PROMPT = (
100
- "You are a helpful assistant. Use the given CONTEXT to answer the QUESTION.\n"
101
- "If the answer is not in the context, say you don't know.\n"
102
- "Provide a concise answer and list source filenames as [source: file.pdf] at the end."
103
- )
104
-
105
- def _mistral_prompt(question: str, context: str) -> str:
106
- # Simple Mistral-instruct prompt format
107
- return (
108
- f"[INST] {SYSTEM_PROMPT}\n\n"
109
- f"QUESTION: {question}\n\n"
110
- f"CONTEXT:\n{context}\n"
111
- f"[/INST]"
112
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
  def answer(question: str) -> str:
115
  if not question.strip():
116
  return "Ask a question."
117
- if index is None or not doc_chunks:
118
  return "Upload PDFs and click **Build Index** first."
119
 
120
  ids, ctx_chunks = _retrieve(question, TOP_K)
121
- # keep contexts reasonably short per chunk
122
- previews = []
123
- contexts = []
124
- files = []
125
  for rank, i in enumerate(ids, start=1):
126
  chunk = doc_chunks[i][:1000]
127
  fname = doc_meta[i]["file"]
128
  contexts.append(f"[{rank}] {fname}\n{chunk}")
129
- previews.append(f"[{rank}] {fname}")
130
  files.append(fname)
131
 
132
  context_str = "\n\n---\n".join(contexts)
133
- prompt = _mistral_prompt(question, context_str)
 
 
 
134
 
135
  try:
136
- # Use hosted Inference API; returns a single string
137
- out = client.text_generation(
138
- prompt,
139
- max_new_tokens=512,
140
- temperature=0.2,
141
- top_p=0.95,
142
- repetition_penalty=1.05,
143
- do_sample=True,
144
- return_full_text=False,
145
- )
146
- # Ensure sources are visible at the end
147
- unique_files = ", ".join(sorted(set(files)))
148
  return f"{out.strip()}\n\nSources: {unique_files}"
149
  except Exception as e:
150
- return f"Generation error: {e}\n(Verify your HUGGINGFACEHUB_API_TOKEN and model name.)"
151
 
152
- # ==============================
153
  # UI
154
- # ==============================
155
  with gr.Blocks(title="Mistral 7B PDF-RAG") as demo:
156
  gr.Markdown("# 📚 PDF-RAG (Mistral-7B-Instruct)\nUpload PDFs → Build Index → Ask questions. Answers cite sources.")
157
 
@@ -167,7 +205,7 @@ with gr.Blocks(title="Mistral 7B PDF-RAG") as demo:
167
 
168
  build_btn.click(build_from_pdfs, inputs=[files], outputs=[status])
169
  ask_btn.click(answer, inputs=[q], outputs=[a])
170
- q.submit(answer, inputs=[q], outputs=[a]) # hit Enter to ask
171
 
172
  if __name__ == "__main__":
173
  demo.launch()
 
1
  import os
 
2
  import numpy as np
3
  import gradio as gr
4
  from typing import List, Tuple
 
6
  from sentence_transformers import SentenceTransformer
7
  from huggingface_hub import InferenceClient
8
 
9
+ # -------------------------------------------------
10
  # Config
11
+ # -------------------------------------------------
12
+ # You can swap to another chat model if needed, e.g.:
13
+ # "mistralai/Mistral-Nemo-Instruct-2407" or "meta-llama/Llama-3.1-8B-Instruct"
14
+ GEN_MODEL = os.getenv("GEN_MODEL", "mistralai/Mistral-7B-Instruct-v0.2")
15
  HF_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN") # set in Space Secrets
16
  EMB_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
17
  CHUNK_SIZE = 900
18
  CHUNK_OVERLAP = 150
19
  TOP_K = 4
20
 
21
# -------------------------------------------------
# Vector search backend: prefer FAISS when installed,
# otherwise fall back to a brute-force NumPy scan.
# -------------------------------------------------
try:
    import faiss  # type: ignore
    USE_FAISS = True
except Exception:
    USE_FAISS = False
29
+
30
+ # -------------------------------------------------
31
+ # Globals
32
+ # -------------------------------------------------
33
  emb = SentenceTransformer(EMB_MODEL_NAME)
34
+ index = None # FAISS index (if available)
35
+ matrix = None # fallback: stacked embeddings
36
+ doc_chunks: List[str] = []
37
+ doc_meta: List[dict] = []
38
  client = InferenceClient(model=GEN_MODEL, token=HF_TOKEN)
39
 
40
# Instruction text sent as the "system" message with every request.
SYSTEM_PROMPT = "\n".join([
    "You are a helpful assistant. Use the given CONTEXT to answer the QUESTION.",
    "If the answer is not in the context, say you don't know.",
    "Be concise and list source filenames as [source: file.pdf] at the end.",
])
45
+
46
+ # -------------------------------------------------
47
  # Helpers
48
+ # -------------------------------------------------
49
def _extract_text_from_pdf(path: str) -> str:
    """Return the text of every page in the PDF at `path`, newline-joined."""
    reader = PdfReader(path)
    page_texts = []
    for page in reader.pages:
        # extract_text() can return None (e.g. image-only pages) — treat as empty.
        page_texts.append(page.extract_text() or "")
    return "\n".join(page_texts)
53
+
54
  def _chunk_text(text: str, size: int, overlap: int) -> List[str]:
55
+ chunks, step = [], size - overlap
56
+ i, n = 0, len(text)
57
+ while i < n:
58
+ chunk = text[i:i+size].strip()
59
+ if chunk: chunks.append(chunk)
60
+ i += step
61
+ return chunks
 
 
62
 
63
def _embed(texts: List[str]) -> np.ndarray:
    """Encode `texts` into L2-normalized float32 vectors.

    Normalization makes inner-product search equivalent to cosine similarity.
    """
    vectors = emb.encode(texts, convert_to_numpy=True, normalize_embeddings=True)
    return np.asarray(vectors, dtype=np.float32)
66
 
67
def _ensure_index(dim: int):
    """(Re)create the vector store for embeddings of dimension `dim`."""
    global index, matrix
    if USE_FAISS:
        # Inner-product index; with normalized vectors this is cosine search.
        index = faiss.IndexFlatIP(dim)
    else:
        # NumPy fallback: clear both stores; rows accumulate via _add_embeddings.
        index = None
        matrix = None
74
+
75
def _add_embeddings(E: np.ndarray):
    """Append embedding rows `E` to whichever backend store is active."""
    global matrix
    if USE_FAISS:
        index.add(E)
        return
    # Fallback path: grow the stacked embedding matrix in place.
    matrix = E if matrix is None else np.vstack([matrix, E])
81
+
82
def _search(qv: np.ndarray, k: int):
    """Return (scores, ids), each shaped (1, k), for query vectors `qv`."""
    if USE_FAISS:
        # FAISS returns (D, I) with the same (1, k) shapes.
        return index.search(qv, k)
    # Brute force: vectors are normalized, so inner product == cosine.
    scores = matrix @ qv[0]
    order = np.argsort(-scores)[:k]
    return scores[order][None, :], order[None, :]
89
+
90
+ # -------------------------------------------------
91
  # Build index
92
+ # -------------------------------------------------
93
def build_from_pdfs(files) -> str:
    """Read uploaded PDFs, chunk their text, embed, and (re)build the index.

    Returns a human-readable status string for the Gradio UI.
    """
    global doc_chunks, doc_meta
    doc_chunks, doc_meta = [], []

    for f in files:
        # Restore per-file error handling (the old version had it): one
        # unreadable PDF should produce a message, not crash the build.
        try:
            text = _extract_text_from_pdf(f.name)
        except Exception as e:
            return f"Failed to read {os.path.basename(f.name)}: {e}"
        chunks = _chunk_text(text, CHUNK_SIZE, CHUNK_OVERLAP)
        # NOTE(review): answer() reads doc_meta[i]["file"]; basename assumed
        # from the old version's {"file": ...} comment — confirm exact key value.
        fname = os.path.basename(f.name)
        for c in chunks:
            doc_chunks.append(c)
            doc_meta.append({"file": fname})

    if not doc_chunks:
        return "No text extracted. Check your PDFs."

    E = _embed(doc_chunks)
    _ensure_index(E.shape[1])
    _add_embeddings(E)
    return f"Indexed {len(doc_chunks)} chunks from {len(files)} file(s)."
111
 
112
+ # -------------------------------------------------
113
  # Retrieval + Generation
114
+ # -------------------------------------------------
115
def _retrieve(query: str, k: int = TOP_K) -> Tuple[List[int], List[str]]:
    """Embed `query` and return (chunk ids, chunk texts) for the top-k hits."""
    query_vec = _embed([query])
    _, hit_ids = _search(query_vec, k)
    # FAISS pads with -1 when fewer than k vectors are indexed — drop those.
    valid = [i for i in hit_ids[0].tolist() if i >= 0]
    texts = [doc_chunks[i] for i in valid]
    return valid, texts
120
 
121
def _call_chat(messages):
    """
    Try several Hugging Face client paths for max compatibility.
    Returns generated string or raises last exception.

    The four attempts below are ordered newest-API-first; each failure is
    remembered in `last` and the next path is tried. Order matters — do not
    reorder without testing against the installed huggingface_hub version.
    """
    # 1) Newer helper
    try:
        resp = client.chat_completion(messages=messages, max_tokens=512, temperature=0.2, top_p=0.95)
        # resp.choices[0].message.content (object or dict)
        # getattr-or-subscript handles both dataclass-like and dict responses.
        choice = resp.choices[0]
        msg = getattr(choice, "message", None) or choice["message"]
        return getattr(msg, "content", None) or msg["content"]
    except Exception as e1:
        last = e1
    # 2) OpenAI-style
    try:
        resp = client.chat.completions.create(model=GEN_MODEL, messages=messages, max_tokens=512, temperature=0.2, top_p=0.95)
        choice = resp.choices[0]
        msg = getattr(choice, "message", None) or choice["message"]
        return getattr(msg, "content", None) or msg["content"]
    except Exception as e2:
        last = e2
    # 3) Text generation with a single composed prompt
    # NOTE(review): [INST]...[/INST] is the Mistral-instruct format; assumes the
    # last message holds the user content — confirm if message layout changes.
    try:
        prompt = f"[INST] {SYSTEM_PROMPT}\n\n{messages[-1]['content']} [/INST]"
        return client.text_generation(prompt, max_new_tokens=512, temperature=0.2, top_p=0.95,
                                      repetition_penalty=1.05, do_sample=True, return_full_text=False).strip()
    except Exception as e3:
        last = e3
    # 4) Old conversational task
    try:
        conv = client.conversational(
            past_user_inputs=[],
            generated_responses=[],
            text=messages[-1]["content"],
            parameters={"temperature": 0.2, "max_new_tokens": 512},
        )
        return conv["generated_text"] if isinstance(conv, dict) else conv.generated_text
    except Exception as e4:
        last = e4
    # All four paths failed — surface the most recent failure to the caller.
    raise last
162
 
163
def answer(question: str) -> str:
    """Retrieve context for `question`, query the LLM, and append source names."""
    if not question.strip():
        return "Ask a question."
    # An index must exist (FAISS object or fallback matrix) and chunks loaded.
    if (USE_FAISS and index is None) or (not USE_FAISS and matrix is None) or not doc_chunks:
        return "Upload PDFs and click **Build Index** first."

    ids, _ = _retrieve(question, TOP_K)
    files = [doc_meta[i]["file"] for i in ids]
    # Cap each chunk at 1000 chars to keep the prompt compact.
    contexts = [
        f"[{rank}] {fname}\n{doc_chunks[i][:1000]}"
        for rank, (i, fname) in enumerate(zip(ids, files), start=1)
    ]

    context_str = "\n\n---\n".join(contexts)
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": f"QUESTION: {question}\n\nCONTEXT:\n{context_str}"},
    ]

    try:
        out = _call_chat(messages)
        unique_files = ", ".join(sorted(set(files))) if files else "N/A"
        return f"{out.strip()}\n\nSources: {unique_files}"
    except Exception as e:
        return f"Generation error: {e}\n(Verify your HUGGINGFACEHUB_API_TOKEN and model availability.)"
189
 
190
+ # -------------------------------------------------
191
  # UI
192
+ # -------------------------------------------------
193
  with gr.Blocks(title="Mistral 7B PDF-RAG") as demo:
194
  gr.Markdown("# 📚 PDF-RAG (Mistral-7B-Instruct)\nUpload PDFs → Build Index → Ask questions. Answers cite sources.")
195
 
 
205
 
206
  build_btn.click(build_from_pdfs, inputs=[files], outputs=[status])
207
  ask_btn.click(answer, inputs=[q], outputs=[a])
208
+ q.submit(answer, inputs=[q], outputs=[a])
209
 
210
  if __name__ == "__main__":
211
  demo.launch()