SiddhJagani committed on
Commit
9a7f6cc
Β·
verified Β·
1 Parent(s): a03ef00

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +322 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,324 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
 
 
4
  import streamlit as st
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import zipfile
2
+ import os
3
+ import hashlib
4
+ import requests
5
+ from pathlib import Path
6
  import streamlit as st
7
+ from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
8
+ from llama_index.core.node_parser import SentenceSplitter
9
+ from llama_index.core.llms import CustomLLM, LLMMetadata, CompletionResponse
10
+ import logging
11
+ from typing import List, Any, Generator, AsyncGenerator
12
 
13
+ # --------------------------------------------------------------
14
+ # 1. SILENCE NOISY LOGS
15
+ # --------------------------------------------------------------
16
+ logging.getLogger("llama_index").setLevel(logging.CRITICAL)
17
+
18
+ # --------------------------------------------------------------
19
+ # 2. IMPORT FAST EMBEDDER (ModernBERT / MLX)
20
+ # --------------------------------------------------------------
21
+ from embedder import embedder # <-- ModernBERT 4-bit local embedder
22
+
23
+ # --------------------------------------------------------------
24
+ # 3. LLAMA-INDEX EMBEDDING WRAPPER
25
+ # --------------------------------------------------------------
26
+ from llama_index.core.embeddings import BaseEmbedding
27
+
28
class LlamaIndexWrapper(BaseEmbedding):
    """Adapt the local ModernBERT ``embedder`` to LlamaIndex's BaseEmbedding API.

    Sync methods delegate to the module-level ``embedder``; the async variants
    simply call the sync ones (no real async I/O happens here).
    """

    def __init__(self, dim: int = 768):
        super().__init__()
        # NOTE(review): BaseEmbedding is a pydantic model — assigning an
        # underscore attribute directly may depend on private-attr support;
        # confirm against the installed llama_index version.
        self._dimension = dim

    def _get_query_embedding(self, query: str) -> List[float]:
        """Embed a single query string."""
        return embedder.embed_query(query)

    def _get_text_embedding(self, text: str) -> List[float]:
        # Documents reuse the query embedder; presumably the model is symmetric.
        return embedder.embed_query(text)

    def _get_text_embedding_batch(self, texts: List[str], **kwargs: Any) -> List[List[float]]:
        """Embed many texts in one call via the embedder's batch path."""
        return embedder.embed_documents(texts)

    async def _aget_query_embedding(self, query: str) -> List[float]:
        # Async facade over the blocking implementation.
        return self._get_query_embedding(query)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        # Async facade over the blocking implementation.
        return self._get_text_embedding(text)

    async def _aget_text_embedding_batch(self, texts: List[str], **kwargs: Any) -> List[List[float]]:
        # Async facade over the blocking batch implementation.
        return self._get_text_embedding_batch(texts, **kwargs)

    @property
    def dimension(self) -> int:
        """Embedding vector size (defaults to 768)."""
        return self._dimension
54
+
55
+
56
# Module-level singleton embedding model, shared by indexing and querying.
embed_model = LlamaIndexWrapper(dim=768)
57
+
58
+ # --------------------------------------------------------------
59
+ # 4. CONFIG
60
+ # --------------------------------------------------------------
61
TEMP_DIR = "temp_repo"   # working directory where the uploaded repo zip is extracted
OUTPUT_DIR = "output"    # destination for generated README / diagram files
# NOTE(review): hard-coded external endpoint — consider moving to an env var.
LLM_API = "http://116.72.105.227:1234/v1"
64
+
65
+ # --------------------------------------------------------------
66
+ # 5. HELPER β€” convert any Response to string
67
+ # --------------------------------------------------------------
68
def to_text(resp):
    """Coerce a LlamaIndex Response (or any object) into plain text.

    Preference order: the ``response`` attribute, then ``text``, then
    ``str(resp)``.  ``None`` maps to the empty string.
    """
    if resp is None:
        return ""
    for attr in ("response", "text"):
        if hasattr(resp, attr):
            return getattr(resp, attr)
    return str(resp)
77
+
78
+ # --------------------------------------------------------------
79
+ # 6. AUTO-DETECT MODEL FROM LM-STUDIO
80
+ # --------------------------------------------------------------
81
def get_lmstudio_model():
    """Ask the LM-Studio server which model is loaded; fall back to a default.

    Queries the OpenAI-compatible ``/models`` endpoint and returns the id of
    the first model listed.  Any failure (network error, bad JSON) produces a
    Streamlit warning; a non-200 status or an empty model list silently falls
    through to the hard-coded default name.
    """
    fallback = "Qwen2.5-Coder-7B-Instruct"
    try:
        reply = requests.get(f"{LLM_API}/models", timeout=5)
        if reply.status_code == 200:
            available = reply.json().get("data", [])
            if available:
                return available[0]["id"]
    except Exception as exc:
        st.warning(f"Auto-detect failed: {exc}. Using default model.")
    return fallback
91
+
92
+ # --------------------------------------------------------------
93
+ # 7. CUSTOM LLM (LM-STUDIO API)
94
+ # --------------------------------------------------------------
95
class LMStudioLLM(CustomLLM):
    """LlamaIndex CustomLLM adapter that proxies completions to an LM-Studio
    server through its OpenAI-compatible chat API."""

    model_name: str
    temperature: float = 0.7
    context_window: int = 32768
    num_output: int = -1  # -1: let the server choose max_tokens
    # pydantic: permit extra instance attributes such as base_url below.
    model_config = {"extra": "allow"}

    def __init__(self, model_name: str, temperature: float = 0.7):
        super().__init__(model_name=model_name, temperature=temperature)
        # NOTE(review): duplicates the module-level LLM_API constant — keep in sync.
        self.base_url = "http://116.72.105.227:1234/v1"

    @property
    def metadata(self) -> LLMMetadata:
        """Advertise context window and output budget to LlamaIndex."""
        return LLMMetadata(context_window=self.context_window, num_output=self.num_output)

    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        """Send *prompt* as a single user message; return the reply text.

        Raises on a non-2xx HTTP response (``raise_for_status``).
        """
        payload = {
            "model": self.model_name,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": self.temperature,
            "max_tokens": self.num_output,
            "stream": False,
            **kwargs,  # caller-supplied options override/extend the payload
        }
        resp = requests.post(f"{self.base_url}/chat/completions", json=payload, timeout=300)
        resp.raise_for_status()
        text = resp.json()["choices"][0]["message"]["content"]
        return CompletionResponse(text=text)

    async def acomplete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        # NOTE(review): blocking HTTP call inside a coroutine — blocks the event loop.
        return self.complete(prompt, **kwargs)

    def stream_complete(self, prompt: str, **kwargs: Any) -> Generator[CompletionResponse, None, None]:
        # Pseudo-streaming: yields the full completion as a single chunk.
        yield self.complete(prompt, **kwargs)

    async def astream_complete(self, prompt: str, **kwargs: Any) -> AsyncGenerator[CompletionResponse, None]:
        # Pseudo-streaming async variant; single chunk, blocking under the hood.
        yield self.complete(prompt, **kwargs)
132
+
133
+
134
+ # --------------------------------------------------------------
135
+ # 8. EXTRACT ZIP (clean old files first)
136
+ # --------------------------------------------------------------
137
def extract_repo(zip_path: str):
    """Extract the uploaded repo zip into a clean TEMP_DIR.

    The previous contents of TEMP_DIR are removed first so stale files from an
    earlier upload cannot leak into the new index.  Archive members whose
    resolved path would escape TEMP_DIR (zip-slip via ".." or absolute paths)
    are rejected, since the uploaded archive is untrusted input.

    Raises:
        ValueError: if the archive contains an unsafe member path.
    """
    import shutil

    # Start from an empty directory: drop everything from a previous run.
    shutil.rmtree(TEMP_DIR, ignore_errors=True)
    os.makedirs(TEMP_DIR, exist_ok=True)

    dest_root = os.path.realpath(TEMP_DIR)
    with zipfile.ZipFile(zip_path, "r") as z:
        for member in z.namelist():
            # Security: refuse entries that would be written outside TEMP_DIR.
            target = os.path.realpath(os.path.join(dest_root, member))
            if target != dest_root and not target.startswith(dest_root + os.sep):
                raise ValueError(f"Unsafe path in archive: {member}")
            z.extract(member, TEMP_DIR)
    st.success(f"βœ… Extracted β†’ `{TEMP_DIR}`")
149
+
150
+ # --------------------------------------------------------------
151
+ # 9. BUILD INDEX (NO CACHING)
152
+ # --------------------------------------------------------------
153
def build_index(_repo_hash: str):
    """Build a fresh in-memory vector index over the extracted repository.

    *_repo_hash* is currently unused (it presumably existed to key a cache);
    the index is rebuilt from scratch on every call.  Returns the
    VectorStoreIndex, or None when there is nothing to index.
    """
    import shutil
    if not os.path.isdir(TEMP_DIR) or not os.listdir(TEMP_DIR):
        st.error("No files extracted!")
        return None

    # Clear any old persisted index
    storage_dir = "storage"
    if os.path.exists(storage_dir):
        shutil.rmtree(storage_dir)

    # NOTE(review): exclude entries are glob patterns; bare names like
    # "node_modules" / ".git" may not match those directories recursively —
    # confirm against SimpleDirectoryReader's exclude semantics.
    docs = SimpleDirectoryReader(
        TEMP_DIR,
        recursive=True,
        exclude=[
            "*.test.py", "*__pycache__*", "*.pyc", "*.log",
            "*.mp3", "*.wav", "*.m4a", "*.mp4", "*.mov", "*.avi", "*.flac", "*.ogg",
            "node_modules", ".git", ".venv", "*.md"
        ],
    ).load_data()

    if not docs:
        st.error("No documents loaded!")
        return None

    # Chunk documents into ~1024-token nodes with a small overlap for context.
    splitter = SentenceSplitter(chunk_size=1024, chunk_overlap=50)
    nodes = splitter.get_nodes_from_documents(docs)

    index = VectorStoreIndex(nodes, embed_model=embed_model, embed_batch_size=32)
    st.success("βœ… Index built fresh (no caching)!")
    return index
184
+
185
+ # --------------------------------------------------------------
186
+ # 10. MAIN STREAMLIT APP
187
+ # --------------------------------------------------------------
188
def main():
    """Streamlit entry point: upload a repo zip, index it, and generate docs.

    Pipeline: auto-detect the LM-Studio model -> upload zip -> extract ->
    build vector index -> query the LLM for an overview, README, self-check,
    and architecture diagram -> export files and offer download buttons.
    """
    st.title("πŸ€– AI Codebase β†’ Docs Agent (ModernBERT + LM-Studio)")

    auto_detected = get_lmstudio_model()
    st.info(f"**Auto-detected LM-Studio model:** `{auto_detected}`")

    available_models = [
        auto_detected,
    ]
    selected_model = st.selectbox("Select LLM (loaded in LM-Studio)", available_models, index=0)
    llm = LMStudioLLM(model_name=selected_model, temperature=0.7)

    uploaded = st.file_uploader("πŸ“¦ Upload GitHub Repo (.zip)", type="zip")
    if not uploaded:
        st.info("Upload a .zip β†’ Click **Start Analysis**")
        return

    # Persist the upload to disk so zipfile/hashing can re-read it.
    zip_path = "repo.zip"
    with open(zip_path, "wb") as f:
        f.write(uploaded.getbuffer())

    if st.button("πŸš€ Start Analysis"):
        with st.spinner("Extracting repository..."):
            extract_repo(zip_path)

        # Fingerprint the upload.  read_bytes() closes the file; the previous
        # open(...).read() leaked the file handle.
        repo_hash = hashlib.md5(Path(zip_path).read_bytes()).hexdigest()

        with st.spinner("Building knowledge base..."):
            index = build_index(repo_hash)
            if not index:
                return
            engine = index.as_query_engine(llm=llm)

        # === 1. Overview ===
        with st.expander("πŸ“˜ 1. Project Overview", expanded=True):
            overview = to_text(engine.query(
                "Analyze the codebase and summarize:\n"
                "- Project name\n- Description\n- Tech stack\n- Entry point\n- Folder structure overview."
            ))
            st.markdown(overview)
            st.session_state.overview = overview

        # === 2. Generate README ===
        with st.expander("🧾 2. Generate README.md", expanded=True):
            readme = to_text(engine.query(
                f"Using this project overview:\n{st.session_state.overview}\n\n"
                "Generate a **professional and structured README.md** including:\n"
                "- # Title\n- ## Description\n- ## Features\n- ## Installation\n"
                "- ## Usage\n- ## API Reference\n- ## Folder Structure (in ##Folder Structure section/block)\n- ## Contributing\n"
                "- ## License\nEnsure Markdown syntax is perfect with spacing and headers."
            ))
            st.markdown(readme)
            st.session_state.readme = readme

        # === 3. Verification & Auto-Fix ===
        with st.expander("πŸ” 3. Self-Verification & Auto-Fix", expanded=True):
            check = to_text(engine.query(
                f"README:\n{st.session_state.readme}\n\n"
                "Review ALL code files and verify the README accuracy.\n"
                "Check for incorrect function/class names, wrong dependencies, or invalid setup steps.\n"
                "If issues are found, summarize them clearly. Otherwise say 'ALL CORRECT'."
            ))
            st.markdown(check)

            # The LLM signals success with the literal phrase 'ALL CORRECT';
            # anything else triggers an automatic rewrite pass.
            if "all correct" not in check.lower():
                st.warning("Fixing README automatically...")
                fixed = to_text(engine.query(
                    f"Fix and improve the README.md based on these verification results:\n{check}\n\n"
                    f"Here is the original README:\n{st.session_state.readme}\n\n"
                    "Ensure the final version is perfectly formatted Markdown, with consistent headings and spacing."
                ))
                st.success("βœ… Fixed README generated!")
                st.markdown("**Final README.md:**")
                st.markdown(fixed)
                st.session_state.readme_fixed = fixed
            else:
                st.session_state.readme_fixed = st.session_state.readme
                st.success("βœ… README verified as correct!")

        # === 4. Architecture Diagram ===
        with st.expander("🧩 4. Architecture Diagram", expanded=True):
            diag = to_text(engine.query(
                "Generate a **Mermaid** flowchart of the application architecture:\n"
                "- Components and relationships\n- Data flow\n- APIs / Services / DB\n"
                "Return only valid Markdown with ```mermaid code block."
            ))
            st.code(diag, language="mermaid")

        # === 5. Export ===
        os.makedirs(OUTPUT_DIR, exist_ok=True)
        readme_original_path = Path(f"{OUTPUT_DIR}/README_original.md")
        readme_fixed_path = Path(f"{OUTPUT_DIR}/README_final.md")
        diagram_path = Path(f"{OUTPUT_DIR}/ARCHITECTURE.mmd")

        # Write UTF-8 explicitly: content contains emoji/unicode and the
        # platform default encoding is not guaranteed to handle it.
        readme_original_path.write_text(st.session_state.readme, encoding="utf-8")
        readme_fixed_path.write_text(st.session_state.readme_fixed, encoding="utf-8")
        diagram_path.write_text(diag, encoding="utf-8")

        st.success(f"πŸ“ Exported all files β†’ `{OUTPUT_DIR}/`")

        # --- πŸͺ„ Download Buttons ---
        st.markdown("### πŸ“₯ Download Your Files")

        with open(readme_fixed_path, "rb") as f:
            st.download_button(
                label="⬇️ Download Final README.md",
                data=f,
                file_name="README.md",
                mime="text/markdown",
            )

        with open(readme_original_path, "rb") as f:
            st.download_button(
                label="⬇️ Download Original README.md",
                data=f,
                file_name="README_original.md",
                mime="text/markdown",
            )

        with open(diagram_path, "rb") as f:
            st.download_button(
                label="⬇️ Download Architecture Diagram (.mmd)",
                data=f,
                file_name="ARCHITECTURE.mmd",
                mime="text/plain",
            )

        st.info("βœ… You can also find these files saved in the `output/` folder locally.")
320
+
321
+
322
# --------------------------------------------------------------
if __name__ == "__main__":
    # Run the Streamlit app when executed as a script.
    main()