Morinash committed on
Commit
7efb501
·
verified ·
1 Parent(s): baeb9d2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +112 -98
app.py CHANGED
@@ -1,7 +1,7 @@
 
1
  import os
2
  import tempfile
3
  import gradio as gr
4
- from typing import List
5
  import json
6
  import pandas as pd
7
  import requests
@@ -14,33 +14,36 @@ import faiss
14
  import numpy as np
15
  from transformers import pipeline
16
 
17
- # -----------------------------
18
  # CONFIG
19
- # -----------------------------
20
- HF_GENERATION_MODEL = os.environ.get("HF_GENERATION_MODEL", "google/flan-t5-large") # You can switch later to DeepSeek
21
- EMBEDDING_MODEL_NAME = "sentence-transformers/paraphrase-MiniLM-L3-v2" # Faster, smaller
22
  INDEX_PATH = "faiss_index.index"
23
  METADATA_PATH = "metadata.json"
24
 
 
25
  # Load embedding model
 
26
  embed_model = SentenceTransformer(EMBEDDING_MODEL_NAME)
27
 
28
- # -----------------------------
29
- # FILE HELPERS
30
- # -----------------------------
31
  def extract_text_from_pdf(file_path):
32
  reader = PdfReader(file_path)
33
- return "\n\n".join(page.extract_text() or "" for page in reader.pages)
 
34
 
35
  def extract_text_from_docx(file_path):
36
  doc = docx.Document(file_path)
37
  return "\n\n".join(p.text for p in doc.paragraphs)
38
 
39
  def extract_text_from_excel(file_path):
40
- dfs = pd.read_excel(file_path, sheet_name=None)
41
  out = []
42
- for name, df in dfs.items():
43
- out.append(f"Sheet: {name}")
44
  out.append(df.fillna("").to_csv(index=False))
45
  return "\n\n".join(out)
46
 
@@ -49,158 +52,169 @@ def extract_text_from_url(url):
49
  soup = BeautifulSoup(r.text, "lxml")
50
  for s in soup(["script", "style", "aside", "nav", "footer"]):
51
  s.decompose()
52
- return soup.get_text(separator="\n")
 
53
 
54
- # -----------------------------
55
- # CHUNKER (larger = faster)
56
- # -----------------------------
57
- splitter = RecursiveCharacterTextSplitter(chunk_size=3000, chunk_overlap=100)
58
 
59
- # -----------------------------
60
- # INGESTION
61
- # -----------------------------
62
  def ingest_sources(files, urls):
63
- docs, metadata = [], []
64
-
65
- if os.path.exists(INDEX_PATH) and os.path.exists(METADATA_PATH):
66
- return "Index already exists. Delete the files to re-ingest."
67
 
 
68
  for f in files:
 
69
  tmp = tempfile.NamedTemporaryFile(delete=False)
70
  try:
71
  if hasattr(f, "read"):
72
- data = f.read()
73
- if isinstance(data, str):
74
- data = data.encode("utf-8")
75
- tmp.write(data)
76
- name = getattr(f, "name", "uploaded_file")
77
- elif isinstance(f, dict) and "data" in f:
78
- data = f["data"]
79
- if isinstance(data, str):
80
- data = data.encode("utf-8")
81
- tmp.write(data)
82
- name = f.get("name", "uploaded_file")
83
- elif isinstance(f, str):
84
- tmp.write(f.encode("utf-8"))
85
- name = "uploaded_text.txt"
86
  else:
87
- tmp.close()
88
- os.unlink(tmp.name)
89
- return f"Unknown upload type: {type(f)}"
90
- finally:
91
  tmp.flush()
92
  tmp.close()
93
 
94
- try:
95
- low = name.lower()
96
- if low.endswith(".pdf"):
97
  text = extract_text_from_pdf(tmp.name)
98
- elif low.endswith(".docx"):
99
  text = extract_text_from_docx(tmp.name)
100
- elif low.endswith((".xls", ".xlsx")):
101
  text = extract_text_from_excel(tmp.name)
102
  else:
103
  with open(tmp.name, "r", encoding="utf-8", errors="ignore") as fh:
104
  text = fh.read()
105
- except Exception as e:
106
- print(f"Extraction error for {name}: {e}")
107
  os.unlink(tmp.name)
108
- continue
109
-
110
- os.unlink(tmp.name)
111
 
112
- for i, c in enumerate(splitter.split_text(text)):
 
113
  docs.append(c)
114
- metadata.append({"source": name, "chunk": i, "type": "file"})
115
 
116
- for u in urls or []:
117
- u = (u or "").strip()
118
- if not u:
119
  continue
120
  try:
121
  text = extract_text_from_url(u)
122
- for i, c in enumerate(splitter.split_text(text)):
 
123
  docs.append(c)
124
- metadata.append({"source": u, "chunk": i, "type": "url"})
125
  except Exception as e:
126
- print(f"URL fetch error for {u}: {e}")
127
 
128
  if not docs:
129
- return "No content ingested (empty or failed files)."
130
 
131
- try:
132
- embeddings = embed_model.encode(docs, show_progress_bar=True, convert_to_numpy=True)
133
- except Exception as e:
134
- return f"Embedding error: {e}"
135
 
136
- try:
137
- dim = embeddings.shape[1]
 
 
 
 
 
138
  index = faiss.IndexFlatL2(dim)
139
  index.add(embeddings)
140
- faiss.write_index(index, INDEX_PATH)
141
- with open(METADATA_PATH, "w", encoding="utf-8") as fh:
142
- json.dump(metadata, fh)
143
- except Exception as e:
144
- return f"Indexing error: {e}"
145
 
146
- return f"Ingested {len(docs)} chunks from {len(files)} files and {len(urls)} URLs."
 
147
 
148
- # -----------------------------
149
- # RETRIEVAL
150
- # -----------------------------
151
  def retrieve_topk(query, k=5):
152
  if not os.path.exists(INDEX_PATH):
153
  return []
154
  q_emb = embed_model.encode([query], convert_to_numpy=True)
155
  index = faiss.read_index(INDEX_PATH)
156
  D, I = index.search(q_emb, k)
157
- metadata = json.load(open(METADATA_PATH))
158
  results = []
159
  for idx in I[0]:
160
  if idx < len(metadata):
161
  results.append(metadata[idx])
162
  return results
163
 
164
- # -----------------------------
165
- # GENERATION PIPELINE
166
- # -----------------------------
167
- gen_pipeline = pipeline("text2text-generation", model=HF_GENERATION_MODEL, device=-1)
168
-
 
 
 
 
 
 
 
169
  def ask_prompt(prompt, top_k=5):
 
 
 
170
  hits = retrieve_topk(prompt, k=top_k)
171
  if not hits:
172
- return "No documents ingested yet."
173
 
 
 
174
  sources = [f"{h['source']} (chunk {h['chunk']})" for h in hits]
175
- context = "\n\n".join(sources)
 
 
 
176
 
177
  system_instruction = (
178
- "You are a research assistant. Use the context below to answer the question clearly and briefly.\n"
 
179
  )
180
- full_prompt = f"{system_instruction}\nCONTEXT:\n{context}\n\nQUESTION:\n{prompt}\n\nAnswer:"
181
 
182
- out = gen_pipeline(full_prompt, max_length=400, do_sample=False)[0]["generated_text"]
 
 
 
 
 
 
183
  return out + "\n\nSources:\n" + "\n".join(sources)
184
 
185
- # -----------------------------
186
- # GRADIO UI
187
- # -----------------------------
188
  with gr.Blocks() as demo:
189
- gr.Markdown("# 🧠 Research Assistant (light version)\nUpload PDFs, Word, Excel, or URLs. Click **Ingest**, then ask your question.")
190
-
 
191
  with gr.Row():
192
  with gr.Column():
193
- file_in = gr.File(label="Upload files", file_count="multiple")
194
- urls_in = gr.Textbox(label="URLs (one per line)", placeholder="https://example.com")
 
 
 
 
 
195
  ingest_btn = gr.Button("Ingest")
196
  ingest_output = gr.Textbox(label="Ingest status")
197
-
198
  with gr.Column():
199
- prompt_in = gr.Textbox(label="Your question", lines=3)
200
  ask_btn = gr.Button("Ask")
201
- answer_out = gr.Textbox(label="Answer", lines=10)
202
 
203
- ingest_btn.click(lambda f, u: ingest_sources(f or [], (u or "").splitlines()), inputs=[file_in, urls_in], outputs=ingest_output)
 
 
 
 
204
  ask_btn.click(lambda p: ask_prompt(p, top_k=5), inputs=prompt_in, outputs=answer_out)
205
 
206
  if __name__ == "__main__":
 
1
+ # app.py
2
  import os
3
  import tempfile
4
  import gradio as gr
 
5
  import json
6
  import pandas as pd
7
  import requests
 
14
  import numpy as np
15
  from transformers import pipeline
16
 
17
# ==============================
# CONFIG
# ==============================
# Generation model can be overridden via the environment (e.g. to swap in a
# larger instruction model); the embedding model is fixed so the on-disk FAISS
# index stays compatible between runs.
HF_GENERATION_MODEL = os.environ.get("HF_GENERATION_MODEL", "google/flan-t5-large")
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
# Paths where the FAISS vector index and its parallel metadata list persist.
INDEX_PATH = "faiss_index.index"
METADATA_PATH = "metadata.json"

# ==============================
# Load embedding model
# ==============================
# Loaded once at import time and shared by ingestion and retrieval.
embed_model = SentenceTransformer(EMBEDDING_MODEL_NAME)
29
 
30
+ # ==============================
31
+ # Helper text extractors
32
+ # ==============================
33
def extract_text_from_pdf(file_path):
    """Return the text of every page of the PDF at *file_path*.

    Pages are separated by blank lines; pages with no extractable text
    contribute an empty string rather than None.
    """
    reader = PdfReader(file_path)
    extracted = []
    for page in reader.pages:
        extracted.append(page.extract_text() or "")
    return "\n\n".join(extracted)
37
 
38
def extract_text_from_docx(file_path):
    """Return all paragraph text from a .docx file, paragraphs separated by blank lines."""
    document = docx.Document(file_path)
    paragraph_texts = [para.text for para in document.paragraphs]
    return "\n\n".join(paragraph_texts)
41
 
42
def extract_text_from_excel(file_path):
    """Render every sheet of an Excel workbook as CSV text.

    Each sheet is emitted as a "Sheet: <name>" header followed by its CSV
    dump (NaNs blanked, no index column), with blank lines between parts.
    """
    # sheet_name=None loads the whole workbook as {sheet name: DataFrame}.
    workbook = pd.read_excel(file_path, sheet_name=None)
    parts = []
    for sheet_name, frame in workbook.items():
        parts.append(f"Sheet: {sheet_name}")
        parts.append(frame.fillna("").to_csv(index=False))
    return "\n\n".join(parts)
49
 
 
52
  soup = BeautifulSoup(r.text, "lxml")
53
  for s in soup(["script", "style", "aside", "nav", "footer"]):
54
  s.decompose()
55
+ text = soup.get_text(separator="\n")
56
+ return text
57
 
58
# ==============================
# Text chunking setup
# ==============================
# Shared splitter: ~1500-character chunks with a 200-character overlap so
# content straddling a chunk boundary is still retrievable from a neighbour.
splitter = RecursiveCharacterTextSplitter(
    chunk_size=1500,
    chunk_overlap=200,
)
62
 
63
# ==============================
# Ingestion function
# ==============================
def ingest_sources(files, urls):
    """Extract text from uploads and URLs, chunk it, embed it, and persist
    it to the on-disk FAISS index and metadata file.

    files -- iterable of uploads: file-like objects (with .read()) or plain
             text strings; other types are logged and skipped.
    urls  -- iterable of URL strings; blank entries are ignored.

    Returns a human-readable status string for the Gradio UI.
    """
    files = files or []
    urls = urls or []
    docs = []
    metadata = []

    # --- Uploaded files ---
    for f in files:
        # BUG FIX: the original read `f.name` unconditionally, which raises
        # AttributeError whenever Gradio hands us a plain string (the string
        # branch below could never be reached).
        if hasattr(f, "read"):
            name = getattr(f, "name", "uploaded_file")
            data = f.read()
            if isinstance(data, str):
                # BUG FIX: some file-likes return str; the binary temp file
                # needs bytes.
                data = data.encode("utf-8")
        elif isinstance(f, str):
            name = "uploaded_text.txt"
            data = f.encode("utf-8")
        else:
            print("Skipping unknown upload type:", type(f))
            continue

        tmp = tempfile.NamedTemporaryFile(delete=False)
        try:
            tmp.write(data)
            tmp.flush()
            tmp.close()

            low = name.lower()
            if low.endswith(".pdf"):
                text = extract_text_from_pdf(tmp.name)
            elif low.endswith(".docx"):
                text = extract_text_from_docx(tmp.name)
            elif low.endswith((".xls", ".xlsx")):
                text = extract_text_from_excel(tmp.name)
            else:
                # Unknown extension: fall back to plain text.
                with open(tmp.name, "r", encoding="utf-8", errors="ignore") as fh:
                    text = fh.read()
        except Exception as e:
            # ROBUSTNESS FIX: one unreadable file must not abort the whole
            # ingestion run (matches the pre-refactor behaviour).
            print("Extraction error for", name, ":", e)
            continue
        finally:
            os.unlink(tmp.name)  # always remove the temp copy

        for i, c in enumerate(splitter.split_text(text)):
            docs.append(c)
            # Keep the chunk text alongside its provenance so retrieval can
            # hand real context to the generator.
            metadata.append({"source": name, "chunk": i, "type": "file", "text": c})

    # --- URLs ---
    for u in urls:
        if not u.strip():
            continue
        try:
            text = extract_text_from_url(u)
            for i, c in enumerate(splitter.split_text(text)):
                docs.append(c)
                metadata.append({"source": u, "chunk": i, "type": "url", "text": c})
        except Exception as e:
            print("URL error:", u, e)

    if not docs:
        return "No text extracted from files or URLs."

    embeddings = embed_model.encode(docs, show_progress_bar=True, convert_to_numpy=True)
    dim = embeddings.shape[1]

    if os.path.exists(INDEX_PATH):
        # Append to the existing index; metadata order must stay aligned
        # with insertion order so vector idx -> metadata[idx] holds.
        index = faiss.read_index(INDEX_PATH)
        with open(METADATA_PATH, "r", encoding="utf-8") as fh:
            all_meta = json.load(fh)
        index.add(embeddings)
        all_meta.extend(metadata)
    else:
        index = faiss.IndexFlatL2(dim)
        index.add(embeddings)
        all_meta = metadata

    # BUG FIX: the original wrote via json.dump(..., open(...)) without ever
    # closing the handle; use a context manager so the file is flushed/closed.
    with open(METADATA_PATH, "w", encoding="utf-8") as fh:
        json.dump(all_meta, fh)
    faiss.write_index(index, INDEX_PATH)
    return f"Ingested {len(docs)} text chunks from {len(files)} files and {len(urls)} URLs."
131
 
132
# ==============================
# Retrieve top matching chunks
# ==============================
def retrieve_topk(query, k=5):
    """Embed *query* and return the metadata dicts of its k nearest chunks.

    Returns an empty list when no index has been built yet.
    """
    if not os.path.exists(INDEX_PATH):
        return []
    q_emb = embed_model.encode([query], convert_to_numpy=True)
    index = faiss.read_index(INDEX_PATH)
    D, I = index.search(q_emb, k)
    # BUG FIX: read metadata via a context manager (the original leaked the
    # handle from json.load(open(...))).
    with open(METADATA_PATH, "r", encoding="utf-8") as fh:
        metadata = json.load(fh)
    results = []
    for idx in I[0]:
        # BUG FIX: FAISS pads results with -1 when the index holds fewer
        # than k vectors; a bare `idx < len(metadata)` accepted -1 and
        # wrongly returned metadata[-1].
        if 0 <= idx < len(metadata):
            results.append(metadata[idx])
    return results
147
 
148
# ==============================
# Generation pipeline
# ==============================
# Use GPU 0 only when HF_DEVICE is set to something other than "cpu";
# device index -1 keeps the transformers pipeline on CPU.
_device = 0 if os.environ.get("HF_DEVICE", "cpu") != "cpu" else -1
gen_pipeline = pipeline(
    "text2text-generation",
    model=HF_GENERATION_MODEL,
    device=_device,
)
156
+
157
# ==============================
# Ask prompt
# ==============================
def ask_prompt(prompt, top_k=5):
    """Answer *prompt* using the generation model, grounded in the top-k
    retrieved chunks; the returned string ends with a list of sources."""
    if not os.path.exists(INDEX_PATH) or not os.path.exists(METADATA_PATH):
        return "No documents ingested yet."

    hits = retrieve_topk(prompt, k=top_k)
    if not hits:
        return "No relevant context found. Try ingesting more content."

    # Human-readable provenance lines plus the concatenated chunk text.
    sources = [f"{h['source']} (chunk {h['chunk']})" for h in hits]
    context = "\n\n".join(h["text"] for h in hits if "text" in h)
    if not context.strip():
        return "No readable text found in the ingested files."

    system_instruction = (
        "You are a helpful research assistant. Read the provided context carefully "
        "and answer the question accurately and concisely."
    )
    full_prompt = f"{system_instruction}\n\nCONTEXT:\n{context}\n\nQUESTION:\n{prompt}\n\nAnswer:"

    try:
        generated = gen_pipeline(full_prompt, max_length=400, do_sample=False)
        out = generated[0]["generated_text"]
    except Exception as e:
        return f"Model generation failed: {e}"

    return out + "\n\nSources:\n" + "\n".join(sources)
189
 
190
# ==============================
# Gradio UI
# ==============================
with gr.Blocks() as demo:
    gr.Markdown(
        "# 🧠 Research Assistant (Prototype)\nUpload files or paste URLs, click **Ingest**, then ask your question."
    )
    with gr.Row():
        with gr.Column():
            # Left column: ingestion inputs and status.
            file_in = gr.File(
                label="Upload files (pdf/docx/xlsx/txt)", file_count="multiple"
            )
            urls_in = gr.Textbox(
                label="URLs (one per line)",
                placeholder="https://example.com/article",
            )
            ingest_btn = gr.Button("Ingest")
            ingest_output = gr.Textbox(label="Ingest status")
        with gr.Column():
            # Right column: question and answer.
            prompt_in = gr.Textbox(label="Your question", lines=4)
            ask_btn = gr.Button("Ask")
            answer_out = gr.Textbox(label="Answer", lines=12)

    def _on_ingest(files, urls):
        # Normalise raw widget values: None -> empty, textbox -> line list.
        return ingest_sources(files or [], (urls or "").splitlines())

    def _on_ask(p):
        return ask_prompt(p, top_k=5)

    ingest_btn.click(_on_ingest, inputs=[file_in, urls_in], outputs=ingest_output)
    ask_btn.click(_on_ask, inputs=prompt_in, outputs=answer_out)
219
 
220
  if __name__ == "__main__":