Morinash committed on
Commit
6bb5c19
·
verified ·
1 Parent(s): 5811f5d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -86
app.py CHANGED
@@ -15,32 +15,28 @@ import numpy as np
15
  from transformers import pipeline
16
 
17
  # CONFIG
18
- HF_GENERATION_MODEL = os.environ.get("HF_GENERATION_MODEL", "google/flan-t5-large") # change to DeepSeek if ready
19
- EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
20
  INDEX_PATH = "faiss_index.index"
21
  METADATA_PATH = "metadata.json"
22
 
23
- # load embedding model
24
  embed_model = SentenceTransformer(EMBEDDING_MODEL_NAME)
25
 
26
- # helper extractors
27
- def extract_text_from_pdf(file):
28
- reader = PdfReader(file)
29
- pages = []
30
- for p in reader.pages:
31
- text = p.extract_text() or ""
32
- pages.append(text)
33
- return "\n\n".join(pages)
34
-
35
- def extract_text_from_docx(file):
36
- doc = docx.Document(file)
37
  return "\n\n".join(p.text for p in doc.paragraphs)
38
 
39
- def extract_text_from_excel(file):
40
- df_dict = pd.read_excel(file, sheet_name=None)
41
  out = []
42
- for sheet, df in df_dict.items():
43
- out.append(f"Sheet: {sheet}")
44
  out.append(df.fillna("").to_csv(index=False))
45
  return "\n\n".join(out)
46
 
@@ -49,37 +45,30 @@ def extract_text_from_url(url):
49
  soup = BeautifulSoup(r.text, "lxml")
50
  for s in soup(["script", "style", "aside", "nav", "footer"]):
51
  s.decompose()
52
- text = soup.get_text(separator="\n")
53
- return text
54
 
55
- # chunker
56
- splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=200)
57
 
 
58
  def ingest_sources(files, urls):
59
- docs = []
60
- metadata = []
 
 
 
61
 
62
  for f in files:
63
- # make sure we have a temp file
64
  tmp = tempfile.NamedTemporaryFile(delete=False)
65
-
66
- # handle different types of file objects
67
- if hasattr(f, "read"): # normal file
68
  tmp.write(f.read())
69
- name = getattr(f, "name", "uploaded_file")
70
- elif isinstance(f, str): # NamedString or text
71
- tmp.write(f.encode("utf-8"))
72
- name = "uploaded_text.txt"
73
- elif isinstance(f, dict) and "data" in f: # HF file dict
74
- tmp.write(f["data"])
75
- name = f.get("name", "uploaded_file")
76
  else:
77
- raise ValueError(f"Unknown file type: {type(f)}")
78
-
79
  tmp.flush()
80
  tmp.close()
81
 
82
- # extract text depending on file type
 
83
  if name.lower().endswith(".pdf"):
84
  text = extract_text_from_pdf(tmp.name)
85
  elif name.lower().endswith(".docx"):
@@ -89,99 +78,83 @@ def ingest_sources(files, urls):
89
  else:
90
  with open(tmp.name, "r", encoding="utf-8", errors="ignore") as fh:
91
  text = fh.read()
92
-
93
  os.unlink(tmp.name)
94
 
95
- chunks = splitter.split_text(text)
96
- for i, c in enumerate(chunks):
97
  docs.append(c)
98
  metadata.append({"source": name, "chunk": i, "type": "file"})
99
 
100
- # handle URLs
101
  for u in urls:
102
- u = u.strip()
103
- if not u:
104
- continue
105
  try:
106
  text = extract_text_from_url(u)
107
- chunks = splitter.split_text(text)
108
- for i, c in enumerate(chunks):
109
  docs.append(c)
110
  metadata.append({"source": u, "chunk": i, "type": "url"})
111
  except Exception as e:
112
- print("url error", u, e)
113
 
114
  if not docs:
115
- return "No valid documents or URLs found."
116
 
117
  embeddings = embed_model.encode(docs, show_progress_bar=True, convert_to_numpy=True)
118
- dim = embeddings.shape[1]
119
-
120
- if os.path.exists(INDEX_PATH):
121
- index = faiss.read_index(INDEX_PATH)
122
- old_meta = json.load(open(METADATA_PATH, "r"))
123
- index.add(embeddings)
124
- old_meta.extend(metadata)
125
- json.dump(old_meta, open(METADATA_PATH, "w"))
126
- else:
127
- index = faiss.IndexFlatL2(dim)
128
- index.add(embeddings)
129
- json.dump(metadata, open(METADATA_PATH, "w"))
130
 
131
  faiss.write_index(index, INDEX_PATH)
132
- return f"Ingested {len(docs)} chunks from {len(files)} files and {len(urls)} urls."
 
 
133
 
 
134
  def retrieve_topk(query, k=5):
135
- q_emb = embed_model.encode([query], convert_to_numpy=True)
136
  if not os.path.exists(INDEX_PATH):
137
  return []
 
138
  index = faiss.read_index(INDEX_PATH)
139
  D, I = index.search(q_emb, k)
140
- metadata = json.load(open(METADATA_PATH, "r"))
141
  results = []
142
  for idx in I[0]:
143
  if idx < len(metadata):
144
- results.append((metadata[idx], idx))
145
  return results
146
 
147
- gen_pipeline = pipeline("text2text-generation", model=HF_GENERATION_MODEL, device=0 if os.environ.get("HF_DEVICE", "cpu") != "cpu" else -1)
 
148
 
149
  def ask_prompt(prompt, top_k=5):
150
  hits = retrieve_topk(prompt, k=top_k)
151
  if not hits:
152
- return "No documents ingested. Use Ingest first."
153
 
154
- context_parts = []
155
- sources = []
156
- for meta, idx in hits:
157
- sources.append(f"{meta['source']} (chunk {meta['chunk']})")
158
- context_parts.append(f"[{meta['source']} - chunk {meta['chunk']}]")
159
 
160
- context = "\n\n".join(context_parts)
161
  system_instruction = (
162
- "You are an AI research assistant. Use the contextual chunks below to answer the user's question. "
163
- "Provide a concise answer, then list sources in order of relevance."
164
  )
 
165
 
166
- prompt_text = f"{system_instruction}\n\nCONTEXT:\n{context}\n\nQUESTION:\n{prompt}\n\nAnswer:"
167
- out = gen_pipeline(prompt_text, max_length=512, do_sample=False)[0]["generated_text"]
168
- out = out + "\n\nSources:\n" + "\n".join(sources)
169
- return out
170
 
171
- # Gradio UI
172
  with gr.Blocks() as demo:
173
- gr.Markdown("# Research Assistant (prototype)\nUpload files and/or provide URLs, click Ingest, then Ask a question.")
 
174
  with gr.Row():
175
  with gr.Column():
176
- file_in = gr.File(label="Upload files (pdf/docx/xlsx/txt)", file_count="multiple")
177
- urls_in = gr.Textbox(label="URLs (one per line)", placeholder="https://example.com/article")
178
  ingest_btn = gr.Button("Ingest")
179
  ingest_output = gr.Textbox(label="Ingest status")
 
180
  with gr.Column():
181
- prompt_in = gr.Textbox(label="Your question", lines=4)
182
  ask_btn = gr.Button("Ask")
183
- answer_out = gr.Textbox(label="Answer", lines=12)
184
- ingest_btn.click(lambda files, urls: ingest_sources(files or [], (urls or "").splitlines()), inputs=[file_in, urls_in], outputs=ingest_output)
 
185
  ask_btn.click(lambda p: ask_prompt(p, top_k=5), inputs=prompt_in, outputs=answer_out)
186
 
187
  if __name__ == "__main__":
 
15
  from transformers import pipeline
16
 
17
  # CONFIG
18
+ HF_GENERATION_MODEL = os.environ.get("HF_GENERATION_MODEL", "google/flan-t5-large") # or another HF model
19
+ EMBEDDING_MODEL_NAME = "sentence-transformers/paraphrase-MiniLM-L3-v2" # smaller + faster
20
  INDEX_PATH = "faiss_index.index"
21
  METADATA_PATH = "metadata.json"
22
 
23
+ # Load embedding model (small + CPU efficient)
24
  embed_model = SentenceTransformer(EMBEDDING_MODEL_NAME)
25
 
26
+ # --- Helpers ---
27
def extract_text_from_pdf(file_path):
    """Return the text of every page of the PDF at *file_path*, pages separated by blank lines."""
    page_texts = []
    for page in PdfReader(file_path).pages:
        # extract_text() can return None (e.g. image-only pages); substitute ""
        page_texts.append(page.extract_text() or "")
    return "\n\n".join(page_texts)
30
+
31
def extract_text_from_docx(file_path):
    """Return all paragraph text from the .docx document at *file_path*."""
    document = docx.Document(file_path)
    parts = [para.text for para in document.paragraphs]
    return "\n\n".join(parts)
34
 
35
def extract_text_from_excel(file_path):
    """Serialize every sheet of the workbook at *file_path* as labelled CSV text."""
    parts = []
    # sheet_name=None loads all sheets as an ordered {name: DataFrame} mapping
    for sheet_name, frame in pd.read_excel(file_path, sheet_name=None).items():
        parts.append(f"Sheet: {sheet_name}")
        parts.append(frame.fillna("").to_csv(index=False))
    return "\n\n".join(parts)
42
 
 
45
  soup = BeautifulSoup(r.text, "lxml")
46
  for s in soup(["script", "style", "aside", "nav", "footer"]):
47
  s.decompose()
48
+ return soup.get_text(separator="\n")
 
49
 
50
+ # --- Chunker (larger chunks = fewer embeddings) ---
51
+ splitter = RecursiveCharacterTextSplitter(chunk_size=3000, chunk_overlap=100)
52
 
53
# --- Ingest sources ---
def ingest_sources(files, urls):
    """Extract text from uploaded files and URLs, chunk it, embed it, and write a FAISS index.

    files: iterable of upload objects (anything with .read()) or raw strings.
    urls:  iterable of URL strings; blank entries are skipped.
    Returns a human-readable status string.
    """
    docs, metadata = [], []

    # Skip if already indexed
    if os.path.exists(INDEX_PATH) and os.path.exists(METADATA_PATH):
        return "Already have an index. Delete existing files to re-ingest."

    for f in files:
        # Spool the upload to a real file so the path-based extractors can open it.
        tmp = tempfile.NamedTemporaryFile(delete=False)
        if hasattr(f, "read"):
            tmp.write(f.read())
        else:
            # plain string payload (e.g. a gradio NamedString)
            tmp.write(f.encode("utf-8"))
        tmp.flush()
        tmp.close()

        # basename keeps the source label readable when .name is a full temp path
        # (presumably gradio supplies one — confirm against the gradio version used)
        name = os.path.basename(getattr(f, "name", "uploaded_file"))

        if name.lower().endswith(".pdf"):
            text = extract_text_from_pdf(tmp.name)
        elif name.lower().endswith(".docx"):
            text = extract_text_from_docx(tmp.name)
        elif name.lower().endswith((".xlsx", ".xls")):
            text = extract_text_from_excel(tmp.name)
        else:
            with open(tmp.name, "r", encoding="utf-8", errors="ignore") as fh:
                text = fh.read()
        os.unlink(tmp.name)

        for i, c in enumerate(splitter.split_text(text)):
            docs.append(c)
            # Store the chunk text itself so retrieval can hand real context to the LLM.
            metadata.append({"source": name, "chunk": i, "type": "file", "text": c})

    for u in urls:
        u = u.strip()
        if not u:
            continue  # blank lines from the URL textbox
        try:
            text = extract_text_from_url(u)
            for i, c in enumerate(splitter.split_text(text)):
                docs.append(c)
                metadata.append({"source": u, "chunk": i, "type": "url", "text": c})
        except Exception as e:
            # best effort: one bad URL must not abort the whole ingest
            print(f"URL error for {u}: {e}")

    if not docs:
        return "No valid content found."

    embeddings = embed_model.encode(docs, show_progress_bar=True, convert_to_numpy=True)
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)

    faiss.write_index(index, INDEX_PATH)
    # `with` guarantees the metadata file is flushed and closed
    # (the previous json.dump(metadata, open(...)) leaked the handle)
    with open(METADATA_PATH, "w") as fh:
        json.dump(metadata, fh)

    return f"Ingested {len(docs)} text chunks."
107
 
108
# --- Retrieval ---
def retrieve_topk(query, k=5):
    """Return the metadata dicts of the k chunks nearest to *query* (empty list if no index)."""
    # Guard both files: a stray index without metadata would crash json.load below.
    if not (os.path.exists(INDEX_PATH) and os.path.exists(METADATA_PATH)):
        return []
    q_emb = embed_model.encode([query], convert_to_numpy=True)
    index = faiss.read_index(INDEX_PATH)
    D, I = index.search(q_emb, k)
    with open(METADATA_PATH) as fh:
        metadata = json.load(fh)
    results = []
    for idx in I[0]:
        # FAISS pads missing neighbours with -1; the old `idx < len(metadata)`
        # check let -1 through and silently returned metadata[-1].
        if 0 <= idx < len(metadata):
            results.append(metadata[idx])
    return results
121
 
122
# --- QA pipeline ---
# Load the HF text2text model once at import time; reused for every question.
gen_pipeline = pipeline(
    "text2text-generation",
    model=HF_GENERATION_MODEL,
)
124
 
125
def ask_prompt(prompt, top_k=5):
    """Answer *prompt* with the generation model, grounded on the top_k retrieved chunks.

    Returns the model's answer followed by a "Sources:" list of the chunks used.
    """
    hits = retrieve_topk(prompt, k=top_k)
    if not hits:
        return "No documents ingested yet."

    sources = [f"{h['source']} (chunk {h['chunk']})" for h in hits]
    # BUG FIX: the context previously contained only the source labels, so the
    # model never saw any document text. Use the stored chunk text when present;
    # fall back to the bare label for indexes built before "text" was stored.
    context = "\n\n".join(
        f"[{label}]\n{h.get('text', '')}".rstrip() for label, h in zip(sources, hits)
    )

    system_instruction = (
        "You are a research assistant. Use the context below to answer the question clearly and briefly.\n"
    )
    full_prompt = f"{system_instruction}\nCONTEXT:\n{context}\n\nQUESTION:\n{prompt}\n\nAnswer:"

    out = gen_pipeline(full_prompt, max_length=400, do_sample=False)[0]["generated_text"]
    return out + "\n\nSources:\n" + "\n".join(sources)
 
 
140
 
141
# --- Gradio UI ---
with gr.Blocks() as demo:
    gr.Markdown("# 🧠 Research Assistant (light version)\nUpload PDFs, Docs, Excel, or URLs. Click **Ingest**, then ask your question.")

    with gr.Row():
        with gr.Column():
            file_in = gr.File(label="Upload files", file_count="multiple")
            urls_in = gr.Textbox(label="URLs (one per line)", placeholder="https://example.com")
            ingest_btn = gr.Button("Ingest")
            ingest_output = gr.Textbox(label="Ingest status")

        with gr.Column():
            prompt_in = gr.Textbox(label="Your question", lines=3)
            ask_btn = gr.Button("Ask")
            answer_out = gr.Textbox(label="Answer", lines=10)

    def _on_ingest(uploaded, raw_urls):
        # Normalize the Nones that empty widgets produce before delegating.
        return ingest_sources(uploaded or [], (raw_urls or "").splitlines())

    ingest_btn.click(_on_ingest, inputs=[file_in, urls_in], outputs=ingest_output)
    ask_btn.click(lambda question: ask_prompt(question, top_k=5), inputs=prompt_in, outputs=answer_out)
159
 
160
  if __name__ == "__main__":