aach456 committed on
Commit 7c22e3c · verified · 1 Parent(s): fc96ea6

Update app.py

Files changed (1):
  1. app.py +143 -238

app.py CHANGED
@@ -1,46 +1,12 @@
- # app.py
- # Chat-style RAG app with Streamlit chat UI, FAISS retrieval, SentenceTransformers embeddings,
- # and an open Mistral-7B pipeline. All caches redirected to /tmp to avoid PermissionError.
-
- # ---------- Writable dirs BEFORE third-party imports ----------
- import os, glob, tempfile
- # Streamlit internal runtime dir -> /tmp (fixes PermissionError: '/.streamlit')
- ST_RT = os.environ.get("STREAMLIT_RUNTIME_DIR", "/tmp/.streamlit_runtime")
- try:
-     os.makedirs(ST_RT, exist_ok=True)
- except Exception:
-     ST_RT = tempfile.mkdtemp(prefix="st_runtime_")
- os.environ["STREAMLIT_RUNTIME_DIR"] = ST_RT
-
- # Hugging Face caches -> /tmp
- HF_HOME = os.environ.get("HF_HOME", "/tmp/hf_cache")
- try:
-     os.makedirs(HF_HOME, exist_ok=True)
- except Exception:
-     HF_HOME = tempfile.mkdtemp(prefix="hf_cache_")
- os.environ["HF_HOME"] = HF_HOME
- os.environ["TRANSFORMERS_CACHE"] = HF_HOME  # backward-compat; deprecation warning is harmless
- os.environ["SENTENCE_TRANSFORMERS_HOME"] = HF_HOME
- os.environ["HF_DATASETS_CACHE"] = HF_HOME
- os.environ["XDG_CACHE_HOME"] = HF_HOME
- os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
-
- # Clean stale locks
- locks_dir = os.path.join(HF_HOME, "hub", ".locks")
- if os.path.isdir(locks_dir):
-     for p in glob.glob(os.path.join(locks_dir, "*.lock")):
-         try:
-             os.remove(p)
-         except Exception:
-             pass
-
- # ---------- Imports AFTER env is set ----------
  import io
  import time
- import pandas as pd
  import numpy as np
  import requests
- import streamlit as st
  from bs4 import BeautifulSoup
  from PyPDF2 import PdfReader
  from docx import Document
@@ -49,222 +15,161 @@ from sentence_transformers import SentenceTransformer
  from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, pipeline
  import faiss

- # ---------- Page ----------
- st.set_page_config(page_title="Chat RAG • Open Model + URLs", layout="wide")
- st.title("💬 Chat RAG with Open Model, FAISS, and Web URLs")
-
- # ---------- Session ----------
- for key, default in [
-     ("messages", []),
-     ("chunks", []),
-     ("embedder", None),
-     ("faiss_index", None),
- ]:
-     if key not in st.session_state:
-         st.session_state[key] = default
-
- # ---------- Loaders ----------
- def load_txt(file):
-     raw = file.read()
-     for enc in ("utf-8", "latin-1"):
-         try:
-             return [{"source": file.name, "text": raw.decode(enc, errors="ignore")}]
-         except Exception:
-             continue
-     return [{"source": file.name, "text": raw.decode("utf-8", errors="ignore")}]
-
- def load_pdf(file):
-     pdf = PdfReader(file)
-     text = ""
-     for page in pdf.pages:
-         text += page.extract_text() or ""
-     return [{"source": file.name, "text": text}]
-
- def load_docx(file):
-     data = file.read()
-     doc = Document(io.BytesIO(data))
-     text = " ".join(p.text for p in doc.paragraphs)
-     return [{"source": file.name, "text": text}]

- def load_csv(file):
-     data = file.read()
-     df = None
-     for enc in ("utf-8", "latin-1"):
-         try:
-             df = pd.read_csv(io.BytesIO(data), encoding=enc)
-             break
-         except Exception:
-             df = None
-     if df is None:
-         try:
-             df = pd.read_csv(io.BytesIO(data), engine="python")
-         except Exception:
-             df = pd.DataFrame()
-     text = " ".join(df.astype(str).values.flatten().tolist()) if not df.empty else ""
-     return [{"source": file.name, "text": text}]

- def load_documents(files):
-     docs = []
-     for file in files or []:
-         name = file.name.lower()
-         if name.endswith(".pdf"):
-             docs += load_pdf(file)
-         elif name.endswith(".docx"):
-             docs += load_docx(file)
-         elif name.endswith(".csv"):
-             docs += load_csv(file)
-         elif name.endswith(".txt"):
-             docs += load_txt(file)
-     return docs

- # ---------- Web fetch ----------
- def fetch_web_text(url, timeout=12, retries=2, backoff=1.5):
-     for attempt in range(retries + 1):
-         try:
-             headers = {
-                 "User-Agent": (
-                     "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
-                     "AppleWebKit/537.36 (KHTML, like Gecko) "
-                     "Chrome/124.0 Safari/537.36"
-                 )
-             }
-             resp = requests.get(url, headers=headers, timeout=timeout)
-             resp.raise_for_status()
-             soup = BeautifulSoup(resp.text, "html.parser")
-             for tag in soup(["script", "style", "noscript"]):
-                 tag.decompose()
-             text = " ".join(soup.get_text(separator=" ").split())
-             return [{"source": url, "text": text}]
-         except Exception:
-             if attempt < retries:
-                 time.sleep(backoff ** attempt)
-             else:
-                 return []

- # ---------- Chunking ----------
- def chunk_documents(docs, chunk_size=1000, chunk_overlap=120):
      splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
      chunks = []
      for doc in docs:
-         splits = splitter.split_text(doc.get("text", "") or "")
          for idx, chunk in enumerate(splits):
              chunks.append({"source": doc["source"], "chunk_id": f"{doc['source']}_chunk{idx}", "content": chunk})
      return chunks

- # ---------- Embeddings / Index ----------
- @st.cache_resource(show_spinner=False)
- def load_embedder():
-     return SentenceTransformer("all-MiniLM-L6-v2", cache_folder=os.environ.get("SENTENCE_TRANSFORMERS_HOME", HF_HOME))
-
- def build_embeddings_index(chunks):
-     embedder = load_embedder()
-     texts = [c["content"] for c in chunks]
-     if not texts:
-         return embedder, None
-     emb = embedder.encode(texts, show_progress_bar=True, convert_to_numpy=True)
-     emb = np.asarray(emb, dtype="float32")
-     idx = faiss.IndexFlatL2(emb.shape[1])
-     idx.add(emb)
-     return embedder, idx
-
- def retrieve(query, embedder, index, chunks, top_k=4):
-     if index is None or not chunks:
          return []
      q_emb = embedder.encode([query], convert_to_numpy=True)
-     q_emb = np.asarray(q_emb, dtype="float32")
      distances, indices = index.search(q_emb, top_k)
-     out = []
-     for pos, i in enumerate(indices):
-         if i >= 0 and i < len(chunks):
-             out.append({"chunk": chunks[i], "score": float(distances[pos])})
-     return out
-
- # ---------- LLM ----------
- MODEL_ID = "MehdiHosseiniMoghadam/AVA-Mistral-7B-V2"
-
- @st.cache_resource(show_spinner=False)
- def load_llm():
-     cache_dir = os.environ.get("HF_HOME", HF_HOME)
-     _ = AutoConfig.from_pretrained(MODEL_ID, cache_dir=cache_dir, trust_remote_code=True)
-     tok = AutoTokenizer.from_pretrained(MODEL_ID, cache_dir=cache_dir, trust_remote_code=True)
-     model = AutoModelForCausalLM.from_pretrained(MODEL_ID, cache_dir=cache_dir, trust_remote_code=True)
-     return pipeline("text-generation", model=model, tokenizer=tok, max_length=1024, do_sample=True, temperature=0.2, trust_remote_code=True, device_map="auto")
-
- def answer_with_llm(context_chunks, query, llm):
      context_text = "\n".join(f"[{c['chunk_id']}] {c['content']}" for c in context_chunks)
      prompt = (
          "Answer the following question using ONLY the provided context and cite the chunk ids used.\n"
-         f"Question: {query}\n"
-         "Context:\n"
-         f"{context_text}\n"
-         "Answer with citations:"
      )
-     out = llm(prompt, max_length=512, num_return_sequences=1)
-     return out["generated_text"]
-
- # ---------- Sidebar sources ----------
- st.sidebar.header("Data sources")
-
- uploaded_files = st.sidebar.file_uploader(
-     "Upload documents (PDF, DOCX, TXT, CSV)",
-     type=["pdf", "txt", "docx", "csv"],
-     accept_multiple_files=True,
-     help="Default per-file limit ~200MB; increase via .streamlit/config.toml if needed.",
- )
- with st.sidebar.expander("Upload debug"):
-     info = {
-         "type": type(uploaded_files).__name__,
-         "num_files": (len(uploaded_files) if isinstance(uploaded_files, list) else (1 if uploaded_files else 0)),
-         "names": ([f.name for f in uploaded_files] if isinstance(uploaded_files, list) else ([uploaded_files.name] if uploaded_files else [])),
-     }
-     st.write(info)
-
- url_input = st.sidebar.text_area("Web URLs (one per line)", value="", height=120)
-
- web_docs = []
- if url_input.strip():
-     urls = [u.strip() for u in url_input.splitlines() if u.strip()]
-     with st.sidebar.spinner("Fetching web content..."):
-         for u in urls:
-             web_docs += fetch_web_text(u)

- file_docs = load_documents(uploaded_files) if uploaded_files else []
- all_docs = file_docs + web_docs

- if all_docs:
-     st.success(f"{len(all_docs)} document(s) loaded from files and URLs.")
-     with st.spinner("Chunking and embedding..."):
-         st.session_state.chunks = chunk_documents(all_docs, chunk_size=1000, chunk_overlap=120)
-         st.session_state.embedder, st.session_state.faiss_index = build_embeddings_index(st.session_state.chunks)
-     st.write(f"{len(st.session_state.chunks)} chunks created and indexed.")
- else:
-     st.info("Add documents or URLs in the sidebar to start.")
-
- # ---------- Chat UI ----------
- for m in st.session_state.messages:
-     with st.chat_message(m["role"]):
-         st.markdown(m["content"])
-
- user_input = st.chat_input("Ask about the loaded documents...")
- if user_input:
-     st.session_state.messages.append({"role": "user", "content": user_input})
-     with st.chat_message("user"):
-         st.markdown(user_input)
-
-     with st.chat_message("assistant"):
-         with st.spinner("Thinking..."):
-             if st.session_state.chunks:
-                 llm = load_llm()
-                 results = retrieve(user_input, st.session_state.embedder, st.session_state.faiss_index, st.session_state.chunks, top_k=4)
-                 context_chunks = [r["chunk"] for r in results]
-                 answer = answer_with_llm(context_chunks, user_input, llm)
-                 st.markdown(answer)
-                 sources = "\n".join(f"[{r['chunk']['chunk_id']} from {r['chunk']['source']}]" for r in results) or "No sources (no matches)."
-                 with st.expander("Sources"):
-                     st.code(sources)
-             else:
-                 answer = "No documents indexed yet. Add files or URLs in the sidebar and try again."
-                 st.warning(answer)
-         st.session_state.messages.append({"role": "assistant", "content": answer})
-
- st.caption("Chat RAG • Mistral-7B (open), FAISS, SentenceTransformers, and Web URLs • Streamlit chat UI")
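End of the removed Streamlit version. The replacement file follows: the commit swaps the Streamlit chat UI for a Gradio Blocks app, replaces st.session_state with a module-level state dict, loads the embedder and LLM eagerly at import time instead of behind @st.cache_resource, and folds the per-format loaders into a single load_file_text.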
+ import os
  import io
+ import glob
+ import tempfile
  import time
  import numpy as np
+ import pandas as pd
  import requests
+ import gradio as gr
  from bs4 import BeautifulSoup
  from PyPDF2 import PdfReader
  from docx import Document

  from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, pipeline
  import faiss

+ # Setup HF cache paths before imports
+ HF_HOME = os.environ.get("HF_HOME", "/tmp/hf_cache")
+ os.makedirs(HF_HOME, exist_ok=True)
+
+ os.environ["HF_HOME"] = HF_HOME
+ os.environ["TRANSFORMERS_CACHE"] = HF_HOME
+ os.environ["SENTENCE_TRANSFORMERS_HOME"] = HF_HOME
+ os.environ["HF_DATASETS_CACHE"] = HF_HOME
+ os.environ["XDG_CACHE_HOME"] = HF_HOME
+ os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
+
+ locks_dir = os.path.join(HF_HOME, "hub", ".locks")
+ if os.path.isdir(locks_dir):
+     for p in glob.glob(os.path.join(locks_dir, "*.lock")):
+         try: os.remove(p)
+         except: pass
+
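Note: despite the comment, these assignments now run after the library imports, whereas the removed version set them first (its comment read "Writable dirs BEFORE third-party imports"), and huggingface_hub resolves HF_HOME when it is first imported. So some downloads may still land in the default cache. A minimal reordering sketch (an editorial suggestion, not part of the commit):

    import os
    os.environ.setdefault("HF_HOME", "/tmp/hf_cache")
    os.makedirs(os.environ["HF_HOME"], exist_ok=True)

    # heavy libraries are imported only after the cache location is fixed
    import transformers  # noqa: E402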
+ MODEL_ID = "MehdiHosseiniMoghadam/AVA-Mistral-7B-V2"
+
+ embedder = SentenceTransformer("all-MiniLM-L6-v2", cache_folder=HF_HOME)
+ config = AutoConfig.from_pretrained(MODEL_ID, cache_dir=HF_HOME, trust_remote_code=True)
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, cache_dir=HF_HOME, trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained(MODEL_ID, cache_dir=HF_HOME, trust_remote_code=True)
+ llm = pipeline("text-generation", model=model, tokenizer=tokenizer,
+                max_length=1024, do_sample=True, temperature=0.2,
+                trust_remote_code=True, device_map="auto")
+
+ def load_file_text(file):
+     name = file.name.lower()
+     if name.endswith(".pdf"):
+         reader = PdfReader(file)
+         text = "".join(page.extract_text() or "" for page in reader.pages)
+         return text
+     elif name.endswith(".docx"):
+         data = file.read()
+         doc = Document(io.BytesIO(data))
+         return " ".join(p.text for p in doc.paragraphs)
+     elif name.endswith(".csv"):
+         data = file.read()
+         for enc in ("utf-8", "latin-1"):
+             try:
+                 df = pd.read_csv(io.BytesIO(data), encoding=enc)
+                 return " ".join(df.astype(str).values.flatten().tolist())
+             except: pass
+         return ""
+     elif name.endswith(".txt"):
+         raw = file.read()
+         for enc in ("utf-8", "latin-1"):
+             try: return raw.decode(enc, errors="ignore")
+             except: continue
+         return raw.decode("utf-8", errors="ignore")
+     else:
+         return ""
+
+ def fetch_web_text(url):
+     try:
+         headers = {'User-Agent': 'Mozilla/5.0'}
+         resp = requests.get(url, headers=headers, timeout=10)
+         resp.raise_for_status()
+         soup = BeautifulSoup(resp.text, "html.parser")
+         for tag in soup(["script", "style", "noscript"]):
+             tag.decompose()
+         return " ".join(soup.get_text(separator=" ").split())
+     except Exception:
+         return ""
+
+ def chunk_docs(docs, chunk_size=1000, chunk_overlap=120):
      splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
      chunks = []
      for doc in docs:
+         splits = splitter.split_text(doc["text"])
          for idx, chunk in enumerate(splits):
              chunks.append({"source": doc["source"], "chunk_id": f"{doc['source']}_chunk{idx}", "content": chunk})
      return chunks

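For intuition: with the defaults above, chunk_docs cuts each document into pieces of at most 1000 characters whose edges overlap by about 120 characters, so text that straddles a boundary still appears whole in one chunk. A standalone check using the same splitter class app.py imports (illustrative values):

    from langchain.text_splitter import RecursiveCharacterTextSplitter

    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=120)
    text = "".join(str(i) for i in range(1000))   # 2890 chars, no natural separators
    pieces = splitter.split_text(text)
    print([len(p) for p in pieces])               # roughly [1000, 1000, 1000, 250]
    print(pieces[0][-120:] == pieces[1][:120])    # consecutive pieces share the overlap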
+ def build_index_and_chunks(docs):
+     chunks = chunk_docs(docs)
+     texts = [chunk["content"] for chunk in chunks]
+     if len(texts) == 0: return None, []
+     embeddings = embedder.encode(texts, show_progress_bar=True, convert_to_numpy=True)
+     embeddings = np.asarray(embeddings).astype("float32")
+     dim = embeddings.shape[1]
+     index = faiss.IndexFlatL2(dim)
+     index.add(embeddings)
+     return index, chunks
+
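IndexFlatL2 ranks chunks by squared Euclidean distance, so a smaller score means a closer match. MiniLM embeddings are not unit-normalized by default; a common cosine-similarity variant (a sketch of an alternative, not what this commit uses) normalizes the vectors and switches to an inner-product index:

    import faiss
    from sentence_transformers import SentenceTransformer

    embedder = SentenceTransformer("all-MiniLM-L6-v2")
    texts = ["a tiny corpus", "with three", "example chunks"]
    emb = embedder.encode(texts, convert_to_numpy=True).astype("float32")
    faiss.normalize_L2(emb)                  # unit-normalize rows in place
    index = faiss.IndexFlatIP(emb.shape[1])  # inner product == cosine on unit vectors
    index.add(emb)
    scores, ids = index.search(emb[:1], 2)   # here HIGHER score means more similar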
+ def retrieve(query, index, chunks, top_k=3):
+     if index is None or len(chunks) == 0:
          return []
      q_emb = embedder.encode([query], convert_to_numpy=True)
+     q_emb = np.asarray(q_emb).astype("float32")
      distances, indices = index.search(q_emb, top_k)
+     results = []
+     for dist, idx in zip(distances[0], indices[0]):
+         if idx >= 0 and idx < len(chunks):
+             results.append({"chunk": chunks[idx], "score": float(dist)})
+     return results
+
+ def answer_question(query, index, chunks):
+     results = retrieve(query, index, chunks)
+     context_chunks = [r["chunk"] for r in results]
      context_text = "\n".join(f"[{c['chunk_id']}] {c['content']}" for c in context_chunks)
      prompt = (
          "Answer the following question using ONLY the provided context and cite the chunk ids used.\n"
+         f"Question: {query}\nContext:\n{context_text}\nAnswer with citations:"
      )
+     generated = llm(prompt, max_length=512, num_return_sequences=1)
+     return generated[0]["generated_text"], "\n".join(f"[{c['chunk_id']} from {c['source']}]" for c in context_chunks)
+
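One caveat in answer_question: for a transformers text-generation pipeline, max_length counts prompt tokens plus generated tokens, so three 1000-character chunks can push the prompt past the 512-token budget and leave little or no room for the answer. A hedged alternative (an editorial sketch reusing the llm pipeline defined above, not what the commit does) bounds only the continuation:

    generated = llm(prompt, max_new_tokens=256, num_return_sequences=1)
    answer = generated[0]["generated_text"]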
+ state = {"index": None, "chunks": []}
+
+ def process(files, urls):
+     docs = []
+     if files:
+         for f in files:
+             text = load_file_text(f)
+             if text:
+                 docs.append({"source": f.name, "text": text})
+     if urls:
+         for url in urls.strip().splitlines():
+             text = fetch_web_text(url.strip())
+             if text:
+                 docs.append({"source": url.strip(), "text": text})
+     if len(docs) == 0:
+         return "No documents or URLs loaded."
+     index, chunks = build_index_and_chunks(docs)
+     state["index"], state["chunks"] = index, chunks
+     return f"Loaded {len(docs)} docs, created {len(chunks)} chunks."
+
+ def chat_response(user_message, history):
+     if state["index"] is None or len(state["chunks"]) == 0:
+         bot_message = "Please upload documents or enter URLs, then press 'Load & Process' first."
+     else:
+         answer, sources = answer_question(user_message, state["index"], state["chunks"])
+         bot_message = answer + "\n\nSources:\n" + sources
+     history = history or []
+     history.append(("User: " + user_message, "Assistant: " + bot_message))
+     return "", history
+
+ with gr.Blocks() as demo:
+     gr.Markdown("# 📚 RAG Chatbot with Mistral-7B and FAISS")
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             file_input = gr.File(label="Upload Files (PDF, DOCX, TXT, CSV)", file_types=[".pdf", ".docx", ".txt", ".csv"], file_count="multiple")
+             url_input = gr.Textbox(label="Enter URLs (one per line)", lines=4)
+             process_button = gr.Button("Load & Process Documents and URLs")
+             output_log = gr.Textbox(label="Status")
+
+         with gr.Column(scale=2):
+             chatbot = gr.Chatbot()
+             user_input = gr.Textbox(placeholder="Ask a question about the loaded documents...", show_label=False)
+             submit_btn = gr.Button("Send")
+
+     process_button.click(process, inputs=[file_input, url_input], outputs=output_log)
+     submit_btn.click(chat_response, inputs=[user_input, chatbot], outputs=[user_input, chatbot])
+
+ demo.launch()
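The retrieval path can be smoke-tested without the UI, for example in a REPL after running everything above except demo.launch() (hypothetical URL, reusing process, answer_question, and state as defined in app.py):

    print(process([], "https://example.com"))
    # e.g. "Loaded 1 docs, created 1 chunks."
    answer, sources = answer_question("What is this page about?",
                                      state["index"], state["chunks"])
    print(answer)
    print(sources)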