nwamgbowo commited on
Commit
94f335b
·
verified ·
1 Parent(s): 26a7ff5

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +252 -33
src/streamlit_app.py CHANGED
@@ -1,40 +1,259 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
- import streamlit as st
5
 
6
  """
7
- # Welcome to Streamlit!
8
 
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
 
13
- In the meantime, below is an example of what you can do with just a few lines of code:
 
 
 
 
 
 
 
 
 
 
 
14
  """
15
 
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  ))
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
 
 
3
 
4
  """
5
+ build_and_deploy_nitda_rag.py
6
 
7
+ Creates a Space-ready NITDA RAG project (Gradio app) and optionally uploads it to Hugging Face Spaces.
 
 
8
 
9
+ Usage examples:
10
+ # 1) Just create the project locally
11
+ python build_and_deploy_nitda_rag.py --project nitda-rag
12
+
13
+ # 2) Create + Deploy (requires HF_TOKEN env var with write access)
14
+ export HF_TOKEN=hf_xxx_your_access_token
15
+ python build_and_deploy_nitda_rag.py --project nitda-rag --space-id nwamgbowo/nitda-rag --deploy
16
+
17
+ After deployment, open:
18
+ https://huggingface.co/spaces/nwamgbowo/nitda-rag
19
+
20
+ Then, in the app UI, click "Initialize (build index + load model)" and ask questions.
21
  """
22
 
23
+ import os
24
+ import sys
25
+ import argparse
26
+ from pathlib import Path
27
+ from textwrap import dedent
28
+
29
+ # ----------------------------
30
+ # File contents
31
+ # ----------------------------
32
+ APP_PY = dedent(r'''
33
+ import os
34
+ import time
35
+ import traceback
36
+ from typing import List
37
+
38
+ import gradio as gr
39
+
40
+ # Use LangChain community packages to avoid import drift
41
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
42
+ from langchain_community.document_loaders import PyMuPDFLoader
43
+ from langchain_community.embeddings import SentenceTransformerEmbeddings
44
+ from langchain_community.vectorstores import Chroma
45
+
46
+ from huggingface_hub import hf_hub_download
47
+ from llama_cpp import Llama
48
+
49
+ # -----------------------------
50
+ # Config
51
+ # -----------------------------
52
+ DOCS_DIR = "data" # where PDFs live inside the Space
53
+ DB_DIR = "nitda_db" # Chroma persistence directory
54
+
55
+ TOP_K = 3
56
+ CHUNK_SIZE = 1000
57
+ CHUNK_OVERLAP = 150
58
+ CTX_LEN = 2048
59
+
60
+ # Primary model: Mistral-7B (GPU recommended; CPU Spaces may OOM)
61
+ PRIMARY_REPO = "TheBloke/Mistral-7B-Instruct-v0.2-GGUF"
62
+ PRIMARY_FILE = "mistral-7b-instruct-v0.2.Q6_K.gguf"
63
+ PRIMARY_PARAMS = dict(
64
+ n_ctx=CTX_LEN,
65
+ n_threads=os.cpu_count() or 4,
66
+ n_batch=256,
67
+ n_gpu_layers=int(os.getenv("LLM_N_GPU_LAYERS", "0")), # set >0 on GPU Space
68
+ verbose=False
69
+ )
70
+
71
+ # Fallback: TinyLlama (CPU-friendly, reliable on CPU Spaces)
72
+ FALLBACK_REPO = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"
73
+ FALLBACK_FILE = "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf"
74
+ FALLBACK_PARAMS = dict(
75
+ n_ctx=CTX_LEN,
76
+ n_threads=os.cpu_count() or 4,
77
+ n_batch=128,
78
+ n_gpu_layers=0,
79
+ verbose=False
80
+ )
81
+
82
+ SYSTEM_MESSAGE = (
83
+ "You are an AI assistant specialized in NITDA information retrieval. "
84
+ "Answer strictly from the provided context (official NITDA documents). "
85
+ "If the answer is not in the context, say you don't know."
86
+ )
87
+
88
+ QNA_TEMPLATE = """[SYSTEM]
89
+ {system}
90
+
91
+ [CONTEXT]
92
+ {context}
93
+
94
+ [USER QUESTION]
95
+ {question}
96
+
97
+ [ASSISTANT]
98
+ """
99
+
100
+ # -----------------------------
101
+ # Helpers
102
+ # -----------------------------
103
def list_pdfs(folder: str):
    """Return paths of all PDF files directly inside *folder*.

    The folder is created first if missing, so a fresh checkout yields an
    empty list instead of an OSError. Matching is case-insensitive on the
    ".pdf" suffix.
    """
    os.makedirs(folder, exist_ok=True)
    pdf_paths = []
    for entry in os.listdir(folder):
        if entry.lower().endswith(".pdf"):
            pdf_paths.append(os.path.join(folder, entry))
    return pdf_paths
106
+
107
def build_or_load_vectorstore():
    """Return a Chroma vector store, reusing a persisted index when possible.

    If DB_DIR already holds index files, open it directly. Otherwise split
    every PDF found in DOCS_DIR into overlapping chunks, embed them with
    MiniLM, and persist a new index into DB_DIR.

    Raises:
        FileNotFoundError: when no index exists and DOCS_DIR contains no PDFs.
    """
    if os.path.isdir(DB_DIR) and os.listdir(DB_DIR):
        # Fast path: an index persisted by a previous run is on disk.
        return Chroma(
            persist_directory=DB_DIR,
            embedding_function=SentenceTransformerEmbeddings(
                model_name="sentence-transformers/all-MiniLM-L6-v2"
            ),
        )

    pdf_paths = list_pdfs(DOCS_DIR)
    if not pdf_paths:
        raise FileNotFoundError(f"No PDFs found in '{DOCS_DIR}'. Upload your PDFs to the 'data/' folder.")

    # Load every page of every PDF before splitting.
    pages = []
    for path in pdf_paths:
        pages.extend(PyMuPDFLoader(path).load())

    splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
    chunks = splitter.split_documents(pages)

    store = Chroma.from_documents(
        documents=chunks,
        embedding=SentenceTransformerEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        ),
        persist_directory=DB_DIR,
    )
    store.persist()
    return store
129
+
130
def load_llm():
    """Load the GGUF chat model via llama-cpp.

    Tries the primary Mistral-7B build first and falls back to TinyLlama
    when loading fails (e.g. OOM on a CPU Space). Setting the Space
    variable USE_TINYLLAMA=1 forces the fallback from the start.
    """
    def _fallback():
        # TinyLlama is small enough to load reliably on CPU-only Spaces.
        path = hf_hub_download(repo_id=FALLBACK_REPO, filename=FALLBACK_FILE)
        return Llama(model_path=path, **FALLBACK_PARAMS)

    if os.getenv("USE_TINYLLAMA", "0") == "1":
        return _fallback()

    try:
        path = hf_hub_download(repo_id=PRIMARY_REPO, filename=PRIMARY_FILE)
        return Llama(model_path=path, **PRIMARY_PARAMS)
    except Exception as e:
        print(f"[WARN] Primary model load failed: {e}. Falling back to TinyLlama.")
        return _fallback()
147
+
148
def render_context(docs):
    """Format retrieved documents into a numbered, source-tagged context block.

    Each entry renders as "[n] source (page p)\\n<content>"; the page suffix
    is omitted when the loader supplied no page metadata, and a missing
    source falls back to the literal "document".
    """
    def _entry(index, doc):
        metadata = doc.metadata or {}
        source = metadata.get("source", "document")
        page = metadata.get("page")
        label = f"{source} (page {page})" if page is not None else f"{source}"
        return f"[{index}] {label}\n{doc.page_content}"

    return "\n\n".join(_entry(i, d) for i, d in enumerate(docs, 1))
157
+
158
def generate_answer(question, retriever, llm):
    """Answer *question* from context retrieved out of the vector store.

    Args:
        question: User question; None or blank input is rejected politely
            (fix: the original called ``question.strip()`` before the
            try-block, so a None question raised an uncaught AttributeError).
        retriever: LangChain retriever exposing ``get_relevant_documents``.
        llm: llama-cpp ``Llama`` instance used for completion.

    Returns:
        The model's answer text, or a human-readable status/error string —
        this function never raises, so the UI always gets something to show.
    """
    # Guard before the try-block: an empty question is user input to reject,
    # not an internal failure to report with a traceback.
    if not question or not question.strip():
        return "Please enter a question."
    try:
        hits = retriever.get_relevant_documents(question)
        if not hits:
            return "I couldn't find relevant context in the documents."
        context = render_context(hits)
        prompt = QNA_TEMPLATE.format(system=SYSTEM_MESSAGE, context=context, question=question.strip())

        out = llm(
            prompt=prompt,
            max_tokens=512,
            temperature=0.2,
            top_p=0.95,
            repeat_penalty=1.1,
            # Stop on the section markers so the model can't hallucinate a new turn.
            stop=["</s>", "[USER QUESTION]", "[SYSTEM]"]
        )
        text = out.get("choices", [{}])[0].get("text", "").strip()
        return text or "The model returned no text."
    except Exception as e:
        # Surface the full traceback in the UI instead of crashing the app.
        return f"Error generating answer:\n{e}\n\n{traceback.format_exc()}"
179
+
180
+ # -----------------------------
181
+ # Gradio App (lazy init)
182
+ # -----------------------------
183
with gr.Blocks(title="NITDA RAG Assistant") as demo:
    gr.Markdown("## NITDA RAG Assistant\nAsk questions based on official NITDA documents in the `data/` folder.")

    # Heavy resources are created lazily on button click and kept in
    # per-session state so the Space boots fast.
    retriever_state = gr.State(None)
    llm_state = gr.State(None)

    status = gr.Markdown("**Status:** Not initialized.")
    init_btn = gr.Button("Initialize (build index + load model)")

    def init_resources():
        """Build/load the vector index and the LLM; report elapsed time."""
        started = time.time()
        store = build_or_load_vectorstore()
        retriever = store.as_retriever(search_type="similarity", search_kwargs={"k": TOP_K})
        llm = load_llm()
        elapsed = time.time() - started
        return retriever, llm, f"**Status:** Ready in {elapsed:.1f}s."

    init_btn.click(fn=init_resources, inputs=None, outputs=[retriever_state, llm_state, status])

    q = gr.Textbox(label="Your question", placeholder="Ask about NITDA...", lines=2)
    a = gr.Markdown()
    ask_btn = gr.Button("Ask")

    def on_ask(question, retriever, llm):
        """Refuse to answer until both retriever and model are initialized."""
        if retriever is None or llm is None:
            return "Please click **Initialize (build index + load model)** first."
        return generate_answer(question, retriever, llm)

    ask_btn.click(on_ask, inputs=[q, retriever_state, llm_state], outputs=[a])

if __name__ == "__main__":
    # 0.0.0.0:7860 is the host/port Hugging Face Spaces routes traffic to.
    demo.launch(server_name="0.0.0.0", server_port=7860)
215
+ ''').strip() + "\n"
216
+
217
+ REQUIREMENTS_TXT = dedent(r'''
218
+ # UI
219
+ gradio==4.37.2
220
+
221
+ # LLM runtime
222
+ llama-cpp-python==0.2.60
223
+ huggingface_hub==0.23.5
224
+
225
+ # LangChain stable community integrations
226
+ langchain==0.1.16
227
+ langchain-community==0.0.34
228
+ langchain-text-splitters==0.0.1
229
+
230
+ # Vector DB + embeddings
231
+ chromadb==0.4.24
232
+ sentence-transformers==2.7.0
233
+
234
+ # PDF loader
235
+ pymupdf==1.23.26
236
+
237
+ # Utils
238
+ numpy==1.26.4
239
+ pandas==2.1.4
240
+ ''').strip() + "\n"
241
+
242
+ RUNTIME_TXT = "python-3.10\n"
243
+
244
+ DATA_README = dedent(r'''
245
+ # Data folder
246
+
247
+
248
+ Place your NITDA PDFs here, or pass them to the deploy script with `--pdf` flags, for example:
249
+
250
+ python build_and_deploy_nitda_rag.py \
251
+ --space-id nwamgbowo/nitda-rag \
252
+ --pdf "/path/to/NITDA-ACT-2007-2019-Edition1.pdf" \
253
+ --pdf "/path/to/Digital-Literacy-Framework.pdf" \
254
+ --pdf "/path/to/FrameworkAndGuidelinesForPublicInternetAccessPIA1.pdf" \
255
+ --pdf "/path/to/NATIONAL-REGULATORY-GUIDELINE-FOR-ELECTRONIC-INVOICING-IN-NIGERIA-2025.pdf"
256
+
257
+
258
+ ''').strip() + "\n"
259
  ))