Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
import os
|
| 2 |
-
import faiss
|
| 3 |
import numpy as np
|
| 4 |
import gradio as gr
|
| 5 |
from typing import List, Tuple
|
|
@@ -7,69 +6,96 @@ from pypdf import PdfReader
|
|
| 7 |
from sentence_transformers import SentenceTransformer
|
| 8 |
from huggingface_hub import InferenceClient
|
| 9 |
|
| 10 |
-
#
|
| 11 |
# Config
|
| 12 |
-
#
|
| 13 |
-
|
|
|
|
|
|
|
| 14 |
HF_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN") # set in Space Secrets
|
| 15 |
EMB_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
|
| 16 |
CHUNK_SIZE = 900
|
| 17 |
CHUNK_OVERLAP = 150
|
| 18 |
TOP_K = 4
|
| 19 |
|
| 20 |
-
#
|
| 21 |
-
#
|
| 22 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
emb = SentenceTransformer(EMB_MODEL_NAME)
|
| 24 |
-
index = None
|
| 25 |
-
|
| 26 |
-
|
|
|
|
| 27 |
client = InferenceClient(model=GEN_MODEL, token=HF_TOKEN)
|
| 28 |
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
# Helpers
|
| 31 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
def _chunk_text(text: str, size: int, overlap: int) -> List[str]:
|
| 33 |
-
chunks = []
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
start += step
|
| 41 |
-
return [c.strip() for c in chunks if c.strip()]
|
| 42 |
|
| 43 |
def _embed(texts: List[str]) -> np.ndarray:
|
| 44 |
-
# 384-d for MiniLM; normalize for cosine/IP search
|
| 45 |
X = emb.encode(texts, convert_to_numpy=True, normalize_embeddings=True)
|
| 46 |
return np.asarray(X, dtype=np.float32)
|
| 47 |
|
| 48 |
def _ensure_index(dim: int):
|
| 49 |
-
global index
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
# Build index
|
| 62 |
-
#
|
| 63 |
def build_from_pdfs(files) -> str:
|
| 64 |
-
global
|
| 65 |
doc_chunks, doc_meta = [], []
|
| 66 |
|
| 67 |
-
# 1) read PDFs → 2) chunk → collect
|
| 68 |
for f in files:
|
| 69 |
-
|
| 70 |
-
text = _extract_text_from_pdf(f.name)
|
| 71 |
-
except Exception as e:
|
| 72 |
-
return f"Failed to read {os.path.basename(f.name)}: {e}"
|
| 73 |
chunks = _chunk_text(text, CHUNK_SIZE, CHUNK_OVERLAP)
|
| 74 |
for c in chunks:
|
| 75 |
doc_chunks.append(c)
|
|
@@ -78,80 +104,92 @@ def build_from_pdfs(files) -> str:
|
|
| 78 |
if not doc_chunks:
|
| 79 |
return "No text extracted. Check your PDFs."
|
| 80 |
|
| 81 |
-
# 3) embeddings → FAISS
|
| 82 |
E = _embed(doc_chunks)
|
| 83 |
_ensure_index(E.shape[1])
|
| 84 |
-
|
| 85 |
-
|
| 86 |
return f"Indexed {len(doc_chunks)} chunks from {len(files)} file(s)."
|
| 87 |
|
| 88 |
-
#
|
| 89 |
# Retrieval + Generation
|
| 90 |
-
#
|
| 91 |
def _retrieve(query: str, k: int = TOP_K) -> Tuple[List[int], List[str]]:
|
| 92 |
-
qv = _embed([query])
|
| 93 |
-
|
| 94 |
-
ids = idxs[0].tolist()
|
| 95 |
-
# Filter out -1 (in case FAISS returns for empty)
|
| 96 |
-
ids = [i for i in ids if i >= 0]
|
| 97 |
return ids, [doc_chunks[i] for i in ids]
|
| 98 |
|
| 99 |
-
|
| 100 |
-
"
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
|
| 114 |
def answer(question: str) -> str:
|
| 115 |
if not question.strip():
|
| 116 |
return "Ask a question."
|
| 117 |
-
if index is None or not doc_chunks:
|
| 118 |
return "Upload PDFs and click **Build Index** first."
|
| 119 |
|
| 120 |
ids, ctx_chunks = _retrieve(question, TOP_K)
|
| 121 |
-
|
| 122 |
-
previews = []
|
| 123 |
-
contexts = []
|
| 124 |
-
files = []
|
| 125 |
for rank, i in enumerate(ids, start=1):
|
| 126 |
chunk = doc_chunks[i][:1000]
|
| 127 |
fname = doc_meta[i]["file"]
|
| 128 |
contexts.append(f"[{rank}] {fname}\n{chunk}")
|
| 129 |
-
previews.append(f"[{rank}] {fname}")
|
| 130 |
files.append(fname)
|
| 131 |
|
| 132 |
context_str = "\n\n---\n".join(contexts)
|
| 133 |
-
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
try:
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
prompt,
|
| 139 |
-
max_new_tokens=512,
|
| 140 |
-
temperature=0.2,
|
| 141 |
-
top_p=0.95,
|
| 142 |
-
repetition_penalty=1.05,
|
| 143 |
-
do_sample=True,
|
| 144 |
-
return_full_text=False,
|
| 145 |
-
)
|
| 146 |
-
# Ensure sources are visible at the end
|
| 147 |
-
unique_files = ", ".join(sorted(set(files)))
|
| 148 |
return f"{out.strip()}\n\nSources: {unique_files}"
|
| 149 |
except Exception as e:
|
| 150 |
-
return f"Generation error: {e}\n(Verify your HUGGINGFACEHUB_API_TOKEN and model
|
| 151 |
|
| 152 |
-
#
|
| 153 |
# UI
|
| 154 |
-
#
|
| 155 |
with gr.Blocks(title="Mistral 7B PDF-RAG") as demo:
|
| 156 |
gr.Markdown("# 📚 PDF-RAG (Mistral-7B-Instruct)\nUpload PDFs → Build Index → Ask questions. Answers cite sources.")
|
| 157 |
|
|
@@ -167,7 +205,7 @@ with gr.Blocks(title="Mistral 7B PDF-RAG") as demo:
|
|
| 167 |
|
| 168 |
build_btn.click(build_from_pdfs, inputs=[files], outputs=[status])
|
| 169 |
ask_btn.click(answer, inputs=[q], outputs=[a])
|
| 170 |
-
q.submit(answer, inputs=[q], outputs=[a])
|
| 171 |
|
| 172 |
if __name__ == "__main__":
|
| 173 |
demo.launch()
|
|
|
|
| 1 |
import os
|
|
|
|
| 2 |
import numpy as np
|
| 3 |
import gradio as gr
|
| 4 |
from typing import List, Tuple
|
|
|
|
| 6 |
from sentence_transformers import SentenceTransformer
|
| 7 |
from huggingface_hub import InferenceClient
|
| 8 |
|
| 9 |
+
# -------------------------------------------------
|
| 10 |
# Config
|
| 11 |
+
# -------------------------------------------------
|
| 12 |
+
# You can swap to another chat model if needed, e.g.:
|
| 13 |
+
# "mistralai/Mistral-Nemo-Instruct-2407" or "meta-llama/Llama-3.1-8B-Instruct"
|
| 14 |
+
GEN_MODEL = os.getenv("GEN_MODEL", "mistralai/Mistral-7B-Instruct-v0.2")
|
| 15 |
HF_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN") # set in Space Secrets
|
| 16 |
EMB_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
|
| 17 |
CHUNK_SIZE = 900
|
| 18 |
CHUNK_OVERLAP = 150
|
| 19 |
TOP_K = 4
|
| 20 |
|
| 21 |
+
# -------------------------------------------------
|
| 22 |
+
# Try FAISS; fallback to pure NumPy search
|
| 23 |
+
# -------------------------------------------------
|
| 24 |
+
USE_FAISS = True
|
| 25 |
+
try:
|
| 26 |
+
import faiss # type: ignore
|
| 27 |
+
except Exception:
|
| 28 |
+
USE_FAISS = False
|
| 29 |
+
|
| 30 |
+
# -------------------------------------------------
|
| 31 |
+
# Globals
|
| 32 |
+
# -------------------------------------------------
|
| 33 |
emb = SentenceTransformer(EMB_MODEL_NAME)
|
| 34 |
+
index = None # FAISS index (if available)
|
| 35 |
+
matrix = None # fallback: stacked embeddings
|
| 36 |
+
doc_chunks: List[str] = []
|
| 37 |
+
doc_meta: List[dict] = []
|
| 38 |
client = InferenceClient(model=GEN_MODEL, token=HF_TOKEN)
|
| 39 |
|
| 40 |
+
SYSTEM_PROMPT = (
|
| 41 |
+
"You are a helpful assistant. Use the given CONTEXT to answer the QUESTION.\n"
|
| 42 |
+
"If the answer is not in the context, say you don't know.\n"
|
| 43 |
+
"Be concise and list source filenames as [source: file.pdf] at the end."
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
# -------------------------------------------------
|
| 47 |
# Helpers
|
| 48 |
+
# -------------------------------------------------
|
| 49 |
+
def _extract_text_from_pdf(path: str) -> str:
    """Return the concatenated text of every page in the PDF at *path*.

    Pages with no extractable text (scans, images) contribute an empty
    string, so the result is always a str joined by newlines.
    """
    reader = PdfReader(path)
    return "\n".join((page.extract_text() or "") for page in reader.pages)
|
| 53 |
+
|
| 54 |
def _chunk_text(text: str, size: int, overlap: int) -> List[str]:
|
| 55 |
+
chunks, step = [], size - overlap
|
| 56 |
+
i, n = 0, len(text)
|
| 57 |
+
while i < n:
|
| 58 |
+
chunk = text[i:i+size].strip()
|
| 59 |
+
if chunk: chunks.append(chunk)
|
| 60 |
+
i += step
|
| 61 |
+
return chunks
|
|
|
|
|
|
|
| 62 |
|
| 63 |
def _embed(texts: List[str]) -> np.ndarray:
    """Encode *texts* with the module-level SentenceTransformer.

    Embeddings are L2-normalized (so inner product == cosine similarity)
    and returned as a float32 ndarray.
    """
    vectors = emb.encode(
        texts,
        convert_to_numpy=True,
        normalize_embeddings=True,
    )
    return np.asarray(vectors, dtype=np.float32)
|
| 66 |
|
| 67 |
def _ensure_index(dim: int):
    """(Re)initialize the vector store for *dim*-dimensional embeddings.

    With FAISS available, creates a fresh inner-product flat index;
    otherwise clears both globals so the NumPy fallback starts empty.
    """
    global index, matrix
    if not USE_FAISS:
        # Fallback mode: reset both stores; matrix is rebuilt on first add.
        index = None
        matrix = None
        return
    index = faiss.IndexFlatIP(dim)
|
| 74 |
+
|
| 75 |
+
def _add_embeddings(E: np.ndarray):
    """Append embedding rows *E* to the active store.

    Goes into the FAISS index when available, otherwise stacks onto the
    module-level fallback matrix.
    """
    global matrix
    if USE_FAISS:
        index.add(E)
        return
    if matrix is None:
        matrix = E
    else:
        matrix = np.vstack([matrix, E])
|
| 81 |
+
|
| 82 |
+
def _search(qv: np.ndarray, k: int):
    """Return (scores, indices), each shaped (1, k), for query vector *qv*.

    FAISS path returns its native (D, I) pair; the NumPy fallback ranks by
    inner product, which equals cosine similarity because all vectors are
    unit-normalized.
    """
    if USE_FAISS:
        return index.search(qv, k)  # (D, I)
    scores = matrix @ qv[0]
    top = np.argsort(-scores)[:k]
    return scores[top][None, :], top[None, :]
|
| 89 |
+
|
| 90 |
+
# -------------------------------------------------
|
| 91 |
# Build index
|
| 92 |
+
# -------------------------------------------------
|
| 93 |
def build_from_pdfs(files) -> str:
|
| 94 |
+
global doc_chunks, doc_meta
|
| 95 |
doc_chunks, doc_meta = [], []
|
| 96 |
|
|
|
|
| 97 |
for f in files:
|
| 98 |
+
text = _extract_text_from_pdf(f.name)
|
|
|
|
|
|
|
|
|
|
| 99 |
chunks = _chunk_text(text, CHUNK_SIZE, CHUNK_OVERLAP)
|
| 100 |
for c in chunks:
|
| 101 |
doc_chunks.append(c)
|
|
|
|
| 104 |
if not doc_chunks:
|
| 105 |
return "No text extracted. Check your PDFs."
|
| 106 |
|
|
|
|
| 107 |
E = _embed(doc_chunks)
|
| 108 |
_ensure_index(E.shape[1])
|
| 109 |
+
_add_embeddings(E)
|
|
|
|
| 110 |
return f"Indexed {len(doc_chunks)} chunks from {len(files)} file(s)."
|
| 111 |
|
| 112 |
+
# -------------------------------------------------
|
| 113 |
# Retrieval + Generation
|
| 114 |
+
# -------------------------------------------------
|
| 115 |
def _retrieve(query: str, k: int = TOP_K) -> Tuple[List[int], List[str]]:
    """Embed *query*, search the store, and return (chunk ids, chunk texts)."""
    query_vec = _embed([query])
    _scores, hits = _search(query_vec, k)
    ids: List[int] = []
    for i in hits[0].tolist():
        # FAISS pads with -1 when fewer than k vectors are indexed.
        if i >= 0:
            ids.append(i)
    return ids, [doc_chunks[i] for i in ids]
|
| 120 |
|
| 121 |
+
def _call_chat(messages):
    """
    Try several Hugging Face client paths for max compatibility.

    Attempts, in order: (1) the newer chat_completion helper, (2) the
    OpenAI-style chat.completions API, (3) plain text_generation with a
    composed [INST] prompt, (4) the legacy conversational task.
    Returns the generated string from the first path that succeeds, or
    raises the last exception if every path fails.
    """
    # 1) Newer helper
    try:
        resp = client.chat_completion(messages=messages, max_tokens=512, temperature=0.2, top_p=0.95)
        # resp.choices[0].message.content (object or dict) — depends on client version,
        # hence the getattr-or-subscript pattern below.
        choice = resp.choices[0]
        msg = getattr(choice, "message", None) or choice["message"]
        return getattr(msg, "content", None) or msg["content"]
    except Exception as e1:
        last = e1
    # 2) OpenAI-style
    try:
        resp = client.chat.completions.create(model=GEN_MODEL, messages=messages, max_tokens=512, temperature=0.2, top_p=0.95)
        choice = resp.choices[0]
        msg = getattr(choice, "message", None) or choice["message"]
        return getattr(msg, "content", None) or msg["content"]
    except Exception as e2:
        last = e2
    # 3) Text generation with a single composed prompt
    # NOTE(review): [INST] ... [/INST] is the Mistral instruct format — presumably
    # acceptable as a fallback for other models too, but confirm if GEN_MODEL is swapped.
    try:
        prompt = f"[INST] {SYSTEM_PROMPT}\n\n{messages[-1]['content']} [/INST]"
        return client.text_generation(prompt, max_new_tokens=512, temperature=0.2, top_p=0.95,
                                      repetition_penalty=1.05, do_sample=True, return_full_text=False).strip()
    except Exception as e3:
        last = e3
    # 4) Old conversational task
    try:
        conv = client.conversational(
            past_user_inputs=[],
            generated_responses=[],
            text=messages[-1]["content"],
            parameters={"temperature": 0.2, "max_new_tokens": 512},
        )
        # Older clients return a dict, newer ones an object.
        return conv["generated_text"] if isinstance(conv, dict) else conv.generated_text
    except Exception as e4:
        last = e4
    raise last
|
| 162 |
|
| 163 |
def answer(question: str) -> str:
    """Answer *question* against the indexed PDFs.

    Retrieves TOP_K chunks, builds a system+user message pair, and calls the
    chat fallback chain. The reply always ends with a "Sources:" line listing
    the contributing filenames. Returns a user-facing error string (never
    raises) when generation fails or the index is not built yet.

    Fix: the original unpacked ``ids, ctx_chunks = _retrieve(...)`` but never
    used ``ctx_chunks`` (contexts are rebuilt below with filename metadata);
    the unused local is now discarded explicitly.
    """
    if not question.strip():
        return "Ask a question."
    # No vector store yet (neither FAISS index nor fallback matrix), or nothing indexed.
    if (USE_FAISS and index is None) or (not USE_FAISS and matrix is None) or not doc_chunks:
        return "Upload PDFs and click **Build Index** first."

    ids, _ = _retrieve(question, TOP_K)
    contexts, files = [], []
    for rank, i in enumerate(ids, start=1):
        chunk = doc_chunks[i][:1000]  # cap each chunk so the prompt stays bounded
        fname = doc_meta[i]["file"]
        contexts.append(f"[{rank}] {fname}\n{chunk}")
        files.append(fname)

    context_str = "\n\n---\n".join(contexts)
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": f"QUESTION: {question}\n\nCONTEXT:\n{context_str}"},
    ]

    try:
        out = _call_chat(messages)
        unique_files = ", ".join(sorted(set(files))) if files else "N/A"
        return f"{out.strip()}\n\nSources: {unique_files}"
    except Exception as e:
        return f"Generation error: {e}\n(Verify your HUGGINGFACEHUB_API_TOKEN and model availability.)"
|
| 189 |
|
| 190 |
+
# -------------------------------------------------
|
| 191 |
# UI
|
| 192 |
+
# -------------------------------------------------
|
| 193 |
with gr.Blocks(title="Mistral 7B PDF-RAG") as demo:
|
| 194 |
gr.Markdown("# 📚 PDF-RAG (Mistral-7B-Instruct)\nUpload PDFs → Build Index → Ask questions. Answers cite sources.")
|
| 195 |
|
|
|
|
| 205 |
|
| 206 |
build_btn.click(build_from_pdfs, inputs=[files], outputs=[status])
|
| 207 |
ask_btn.click(answer, inputs=[q], outputs=[a])
|
| 208 |
+
q.submit(answer, inputs=[q], outputs=[a])
|
| 209 |
|
| 210 |
if __name__ == "__main__":
|
| 211 |
demo.launch()
|