Morinash committed on
Commit
baeb9d2
·
verified ·
1 Parent(s): a708a4d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -44
app.py CHANGED
@@ -14,16 +14,20 @@ import faiss
14
  import numpy as np
15
  from transformers import pipeline
16
 
 
17
  # CONFIG
18
- HF_GENERATION_MODEL = os.environ.get("HF_GENERATION_MODEL", "google/flan-t5-large") # or another HF model
19
- EMBEDDING_MODEL_NAME = "sentence-transformers/paraphrase-MiniLM-L3-v2" # smaller + faster
 
20
  INDEX_PATH = "faiss_index.index"
21
  METADATA_PATH = "metadata.json"
22
 
23
- # Load embedding model (small + CPU efficient)
24
  embed_model = SentenceTransformer(EMBEDDING_MODEL_NAME)
25
 
26
- # --- Helpers ---
 
 
27
def extract_text_from_pdf(file_path):
    """Extract the text of every page in a PDF, pages separated by blank lines."""
    pages = []
    for page in PdfReader(file_path).pages:
        # extract_text() can return None for image-only pages; keep "" instead.
        pages.append(page.extract_text() or "")
    return "\n\n".join(pages)
@@ -47,65 +51,103 @@ def extract_text_from_url(url):
47
  s.decompose()
48
  return soup.get_text(separator="\n")
49
 
50
- # --- Chunker (larger chunks = fewer embeddings) ---
 
 
51
  splitter = RecursiveCharacterTextSplitter(chunk_size=3000, chunk_overlap=100)
52
 
53
- # --- Ingest sources ---
 
 
54
  def ingest_sources(files, urls):
55
  docs, metadata = [], []
56
 
57
- # Skip if already indexed
58
  if os.path.exists(INDEX_PATH) and os.path.exists(METADATA_PATH):
59
- return "Already have an index. Delete existing files to re-ingest."
60
 
61
  for f in files:
62
  tmp = tempfile.NamedTemporaryFile(delete=False)
63
- if hasattr(f, "read"):
64
- tmp.write(f.read())
65
- else:
66
- tmp.write(f.encode("utf-8"))
67
- tmp.flush()
68
- tmp.close()
69
-
70
- name = getattr(f, "name", "uploaded_file")
71
-
72
- if name.lower().endswith(".pdf"):
73
- text = extract_text_from_pdf(tmp.name)
74
- elif name.lower().endswith(".docx"):
75
- text = extract_text_from_docx(tmp.name)
76
- elif name.lower().endswith((".xls", ".xlsx")):
77
- text = extract_text_from_excel(tmp.name)
78
- else:
79
- with open(tmp.name, "r", encoding="utf-8", errors="ignore") as fh:
80
- text = fh.read()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  os.unlink(tmp.name)
82
 
83
  for i, c in enumerate(splitter.split_text(text)):
84
  docs.append(c)
85
  metadata.append({"source": name, "chunk": i, "type": "file"})
86
 
87
- for u in urls:
 
 
 
88
  try:
89
  text = extract_text_from_url(u)
90
  for i, c in enumerate(splitter.split_text(text)):
91
  docs.append(c)
92
  metadata.append({"source": u, "chunk": i, "type": "url"})
93
  except Exception as e:
94
- print(f"URL error for {u}: {e}")
95
 
96
  if not docs:
97
- return "No valid content found."
98
-
99
- embeddings = embed_model.encode(docs, show_progress_bar=True, convert_to_numpy=True)
100
- index = faiss.IndexFlatL2(embeddings.shape[1])
101
- index.add(embeddings)
102
-
103
- faiss.write_index(index, INDEX_PATH)
104
- json.dump(metadata, open(METADATA_PATH, "w"))
105
-
106
- return f"Ingested {len(docs)} text chunks."
107
-
108
- # --- Retrieval ---
 
 
 
 
 
 
 
 
 
 
109
  def retrieve_topk(query, k=5):
110
  if not os.path.exists(INDEX_PATH):
111
  return []
@@ -119,8 +161,10 @@ def retrieve_topk(query, k=5):
119
  results.append(metadata[idx])
120
  return results
121
 
122
- # --- QA pipeline ---
123
- gen_pipeline = pipeline("text2text-generation", model=HF_GENERATION_MODEL)
 
 
124
 
125
  def ask_prompt(prompt, top_k=5):
126
  hits = retrieve_topk(prompt, k=top_k)
@@ -138,9 +182,11 @@ def ask_prompt(prompt, top_k=5):
138
  out = gen_pipeline(full_prompt, max_length=400, do_sample=False)[0]["generated_text"]
139
  return out + "\n\nSources:\n" + "\n".join(sources)
140
 
141
- # --- Gradio UI ---
 
 
142
  with gr.Blocks() as demo:
143
- gr.Markdown("# 🧠 Research Assistant (light version)\nUpload PDFs, Docs, Excel, or URLs. Click **Ingest**, then ask your question.")
144
 
145
  with gr.Row():
146
  with gr.Column():
 
14
  import numpy as np
15
  from transformers import pipeline
16
 
17
+ # -----------------------------
18
  # CONFIG
19
+ # -----------------------------
20
+ HF_GENERATION_MODEL = os.environ.get("HF_GENERATION_MODEL", "google/flan-t5-large") # You can switch later to DeepSeek
21
+ EMBEDDING_MODEL_NAME = "sentence-transformers/paraphrase-MiniLM-L3-v2" # Faster, smaller
22
  INDEX_PATH = "faiss_index.index"
23
  METADATA_PATH = "metadata.json"
24
 
25
+ # Load embedding model
26
  embed_model = SentenceTransformer(EMBEDDING_MODEL_NAME)
27
 
28
+ # -----------------------------
29
+ # FILE HELPERS
30
+ # -----------------------------
31
def extract_text_from_pdf(file_path):
    """Return all page text from the PDF at *file_path*, joined by blank lines."""
    reader = PdfReader(file_path)
    # Fall back to "" when a page yields no extractable text (e.g. scanned pages).
    page_texts = [page.extract_text() or "" for page in reader.pages]
    return "\n\n".join(page_texts)
 
51
  s.decompose()
52
  return soup.get_text(separator="\n")
53
 
54
+ # -----------------------------
55
+ # CHUNKER (larger = faster)
56
+ # -----------------------------
57
  splitter = RecursiveCharacterTextSplitter(chunk_size=3000, chunk_overlap=100)
58
 
59
+ # -----------------------------
60
+ # INGESTION
61
+ # -----------------------------
62
def ingest_sources(files, urls):
    """Ingest uploaded files and URLs into the persisted FAISS index.

    Each source is extracted to text, split into chunks, embedded, and the
    resulting index plus per-chunk metadata are written to INDEX_PATH /
    METADATA_PATH. Always returns a human-readable status string so the UI
    never sees a raised exception.
    """
    # Guard against None from the UI before the final summary does len(...).
    files = files or []
    urls = [u.strip() for u in (urls or []) if u and u.strip()]
    docs, metadata = [], []

    # Re-ingestion requires deleting the persisted artifacts first.
    if os.path.exists(INDEX_PATH) and os.path.exists(METADATA_PATH):
        return "Index already exists. Delete the files to re-ingest."

    for f in files:
        tmp = tempfile.NamedTemporaryFile(delete=False)
        unknown_type_msg = None
        try:
            if hasattr(f, "read"):
                data = f.read()
                if isinstance(data, str):
                    data = data.encode("utf-8")
                tmp.write(data)
                name = getattr(f, "name", "uploaded_file")
            elif isinstance(f, dict) and "data" in f:
                data = f["data"]
                if isinstance(data, str):
                    data = data.encode("utf-8")
                tmp.write(data)
                name = f.get("name", "uploaded_file")
            elif isinstance(f, str):
                # NOTE(review): a bare str is treated as raw text content, not
                # as a filesystem path — confirm against the Gradio caller.
                tmp.write(f.encode("utf-8"))
                name = "uploaded_text.txt"
            else:
                # BUG FIX: the original closed+unlinked tmp here and then hit
                # tmp.flush() in finally on a closed file, raising ValueError
                # instead of returning this message. Defer cleanup instead.
                unknown_type_msg = f"Unknown upload type: {type(f)}"
        finally:
            # tmp is still open on every path, so flush/close is always safe.
            tmp.flush()
            tmp.close()

        if unknown_type_msg is not None:
            os.unlink(tmp.name)
            return unknown_type_msg

        try:
            low = name.lower()
            if low.endswith(".pdf"):
                text = extract_text_from_pdf(tmp.name)
            elif low.endswith(".docx"):
                text = extract_text_from_docx(tmp.name)
            elif low.endswith((".xls", ".xlsx")):
                text = extract_text_from_excel(tmp.name)
            else:
                # Anything else is read as best-effort UTF-8 text.
                with open(tmp.name, "r", encoding="utf-8", errors="ignore") as fh:
                    text = fh.read()
        except Exception as e:
            # Best-effort ingestion: log and move on to the next file.
            print(f"Extraction error for {name}: {e}")
            continue
        finally:
            # Single cleanup point for the temp file on both paths.
            os.unlink(tmp.name)

        for i, c in enumerate(splitter.split_text(text)):
            docs.append(c)
            metadata.append({"source": name, "chunk": i, "type": "file"})

    for u in urls:
        try:
            text = extract_text_from_url(u)
            for i, c in enumerate(splitter.split_text(text)):
                docs.append(c)
                metadata.append({"source": u, "chunk": i, "type": "url"})
        except Exception as e:
            print(f"URL fetch error for {u}: {e}")

    if not docs:
        return "No content ingested (empty or failed files)."

    try:
        embeddings = embed_model.encode(docs, show_progress_bar=True, convert_to_numpy=True)
    except Exception as e:
        return f"Embedding error: {e}"

    try:
        index = faiss.IndexFlatL2(embeddings.shape[1])
        index.add(embeddings)
        faiss.write_index(index, INDEX_PATH)
        with open(METADATA_PATH, "w", encoding="utf-8") as fh:
            json.dump(metadata, fh)
    except Exception as e:
        return f"Indexing error: {e}"

    return f"Ingested {len(docs)} chunks from {len(files)} files and {len(urls)} URLs."
147
+
148
+ # -----------------------------
149
+ # RETRIEVAL
150
+ # -----------------------------
151
  def retrieve_topk(query, k=5):
152
  if not os.path.exists(INDEX_PATH):
153
  return []
 
161
  results.append(metadata[idx])
162
  return results
163
 
164
+ # -----------------------------
165
+ # GENERATION PIPELINE
166
+ # -----------------------------
167
+ gen_pipeline = pipeline("text2text-generation", model=HF_GENERATION_MODEL, device=-1)
168
 
169
  def ask_prompt(prompt, top_k=5):
170
  hits = retrieve_topk(prompt, k=top_k)
 
182
  out = gen_pipeline(full_prompt, max_length=400, do_sample=False)[0]["generated_text"]
183
  return out + "\n\nSources:\n" + "\n".join(sources)
184
 
185
+ # -----------------------------
186
+ # GRADIO UI
187
+ # -----------------------------
188
  with gr.Blocks() as demo:
189
+ gr.Markdown("# 🧠 Research Assistant (light version)\nUpload PDFs, Word, Excel, or URLs. Click **Ingest**, then ask your question.")
190
 
191
  with gr.Row():
192
  with gr.Column():