import zipfile
import os
import hashlib
import requests
from pathlib import Path
import streamlit as st
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.llms import CustomLLM, LLMMetadata, CompletionResponse
import logging
from typing import List, Any, Generator, AsyncGenerator
import shutil
import json  # ✅ added for parsing streaming chunks
# --------------------------------------------------------------
# 1. SILENCE NOISY LOGS
# --------------------------------------------------------------
logging.getLogger("llama_index").setLevel(logging.CRITICAL)
# --------------------------------------------------------------
# 2. IMPORT FAST EMBEDDER (ModernBERT CPU-Compatible)
# --------------------------------------------------------------
from embedder import embedder # <-- CPU ModernBERT embedder (not MLX)
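# The local `embedder` module is assumed to expose embed_query(str) -> List[float]
# and embed_documents(List[str]) -> List[List[float]], both returning 768-dim vectors.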
# --------------------------------------------------------------
# 3. LLAMA-INDEX EMBEDDING WRAPPER
# --------------------------------------------------------------
from llama_index.core.embeddings import BaseEmbedding
class LlamaIndexWrapper(BaseEmbedding):
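    """Adapter exposing the local CPU embedder through LlamaIndex's BaseEmbedding interface."""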
def __init__(self, dim: int = 768):
super().__init__()
self._dimension = dim
def _get_query_embedding(self, query: str) -> List[float]:
return embedder.embed_query(query)
    def _get_text_embedding(self, text: str) -> List[float]:
        # Embed via the document path so single-text and batch embeddings stay
        # consistent (embed_query may apply a query-specific prefix).
        return embedder.embed_documents([text])[0]
def _get_text_embedding_batch(
self, texts: List[str], **kwargs: Any
) -> List[List[float]]:
return embedder.embed_documents(texts)
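    # The async variants simply delegate to the sync implementations; the
    # embedder is CPU-bound, so there is no real async I/O to await.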
async def _aget_query_embedding(self, query: str) -> List[float]:
return self._get_query_embedding(query)
async def _aget_text_embedding(self, text: str) -> List[float]:
return self._get_text_embedding(text)
async def _aget_text_embedding_batch(
self, texts: List[str], **kwargs: Any
) -> List[List[float]]:
return self._get_text_embedding_batch(texts, **kwargs)
@property
def dimension(self) -> int:
return self._dimension
embed_model = LlamaIndexWrapper(dim=768)
# --------------------------------------------------------------
# 4. CONFIG
# --------------------------------------------------------------
TEMP_DIR = "temp_repo"
OUTPUT_DIR = "output"
LLM_API = os.getenv("BaseURL")  # legacy LM Studio endpoint; currently unused (HFChatLLM hardcodes its URL below)
# --------------------------------------------------------------
# 5. HELPER - convert any Response to string
# --------------------------------------------------------------
def to_text(resp):
"""Convert LlamaIndex Response or string-like objects safely to text."""
if resp is None:
return ""
if hasattr(resp, "response"):
return resp.response
if hasattr(resp, "text"):
return resp.text
return str(resp)
# --------------------------------------------------------------
# 6. AUTO-DETECT MODEL (optional - the HF Space only has one)
# --------------------------------------------------------------
def get_hf_model():
return "openai/gpt-4o" # Fixed for your endpoint
# --------------------------------------------------------------
# 7. CUSTOM LLM - HUGGING FACE SPACE (streaming + auth)
# --------------------------------------------------------------
class HFChatLLM(CustomLLM):
model_name: str
temperature: float = 0.7
context_window: int = 32768
num_output: int = -1
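    # num_output == -1 is a sentinel for "no explicit cap": complete() and
    # stream_complete() then fall back to max_tokens=2048.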
model_config = {"extra": "allow"}
def __init__(self, model_name: str, temperature: float = 0.7):
super().__init__(model_name=model_name, temperature=temperature)
self.base_url = "https://siddhjagani-backend.hf.space/v1/chat/completions"
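        # Assumption: API_KEY already includes any scheme prefix the backend
        # expects (e.g. "Bearer <token>"); it is sent verbatim below.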
self.headers = {
"Authorization": os.getenv("API_KEY"),
"Content-Type": "application/json",
}
@property
def metadata(self) -> LLMMetadata:
return LLMMetadata(context_window=self.context_window, num_output=self.num_output)
# ------------------------------------------------------------------
# SYNC COMPLETE
# ------------------------------------------------------------------
def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
payload = {
"model": self.model_name,
"messages": [{"role": "user", "content": prompt}],
"temperature": self.temperature,
"max_tokens": self.num_output if self.num_output > 0 else 2048,
"stream": False,
**kwargs,
}
try:
resp = requests.post(self.base_url, headers=self.headers, json=payload, timeout=300)
resp.raise_for_status()
text = resp.json()["choices"][0]["message"]["content"]
except Exception as e:
text = f"[HF LLM Error]: {e}"
return CompletionResponse(text=text)
# ------------------------------------------------------------------
# STREAMING (token-by-token)
# ------------------------------------------------------------------
def stream_complete(self, prompt: str, **kwargs: Any) -> Generator[CompletionResponse, None, None]:
payload = {
"model": self.model_name,
"messages": [{"role": "user", "content": prompt}],
"temperature": self.temperature,
"max_tokens": self.num_output if self.num_output > 0 else 2048,
"stream": True,
**kwargs,
}
try:
with requests.post(
self.base_url, headers=self.headers, json=payload, stream=True, timeout=300
) as resp:
resp.raise_for_status()
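                # Assumes the OpenAI-compatible SSE format: each event arrives
                # as a "data: {json}" line and the stream ends with "data: [DONE]".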
for line in resp.iter_lines(decode_unicode=True):
if not line or not line.startswith("data: "):
continue
data = line.removeprefix("data: ").strip()
if data == "[DONE]":
break
try:
chunk = json.loads(data)
delta = chunk["choices"][0]["delta"].get("content", "")
if delta:
yield CompletionResponse(text=delta)
except Exception:
continue
except Exception as e:
yield CompletionResponse(text=f"[Streaming Error]: {e}\n")
async def acomplete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
return self.complete(prompt, **kwargs)
async def astream_complete(self, prompt: str, **kwargs: Any) -> AsyncGenerator[CompletionResponse, None]:
yield self.complete(prompt, **kwargs)
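# Minimal standalone check of the wrapper (a sketch; assumes API_KEY is set in
# the environment and the Space is reachable):
#     llm = HFChatLLM(model_name="openai/gpt-4o")
#     print(llm.complete("Say hi").text)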
# --------------------------------------------------------------
# 8. EXTRACT ZIP (clean old files first)
# --------------------------------------------------------------
def extract_repo(zip_path: str):
os.makedirs(TEMP_DIR, exist_ok=True)
for item in os.listdir(TEMP_DIR):
p = os.path.join(TEMP_DIR, item)
if os.path.isdir(p):
shutil.rmtree(p)
else:
os.unlink(p)
with zipfile.ZipFile(zip_path, "r") as z:
z.extractall(TEMP_DIR)
st.success(f"βœ… Extracted β†’ `{TEMP_DIR}`")
# --------------------------------------------------------------
# 9. BUILD INDEX (NO CACHING)
# --------------------------------------------------------------
def build_index(_repo_hash: str):
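    # _repo_hash is accepted for a future cache key (st.cache-style underscore
    # naming) but is currently unused: the index is rebuilt from scratch each run.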
if not os.path.isdir(TEMP_DIR) or not os.listdir(TEMP_DIR):
st.error("No files extracted!")
return None
storage_dir = "storage"
if os.path.exists(storage_dir):
shutil.rmtree(storage_dir)
docs = SimpleDirectoryReader(
TEMP_DIR,
recursive=True,
exclude=[
"*.test.py", "*__pycache__*", "*.pyc", "*.log",
"*.mp3", "*.wav", "*.m4a", "*.mp4", "*.mov", "*.avi", "*.flac", "*.ogg",
"node_modules", ".git", ".venv", "*.md"
],
).load_data()
if not docs:
st.error("No documents loaded!")
return None
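    # Chunking: ~1024-token windows with a 50-token overlap preserve context
    # across chunk boundaries.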
splitter = SentenceSplitter(chunk_size=1024, chunk_overlap=50)
nodes = splitter.get_nodes_from_documents(docs)
index = VectorStoreIndex(nodes, embed_model=embed_model, embed_batch_size=16)
st.success("βœ… Index built fresh (no caching)!")
return index
# --------------------------------------------------------------
# 10. MAIN STREAMLIT APP
# --------------------------------------------------------------
def main():
st.title("πŸ€– AI Codebase β†’ Docs Agent (ModernBERT CPU + LM-Studio Streaming)")
# Fixed model
selected_model = get_hf_model()
st.info(f"**Using Openai's model:** `{selected_model}`")
# Create LLM
llm = HFChatLLM(model_name=selected_model, temperature=0.7)
uploaded = st.file_uploader("πŸ“¦ Upload GitHub Repo (.zip)", type="zip")
if not uploaded:
st.info("Upload a .zip β†’ Click **Start Analysis**")
return
zip_path = "repo.zip"
with open(zip_path, "wb") as f:
f.write(uploaded.getbuffer())
if st.button("πŸš€ Start Analysis"):
with st.spinner("Extracting repository..."):
extract_repo(zip_path)
        with open(zip_path, "rb") as zf:
            repo_hash = hashlib.md5(zf.read()).hexdigest()
with st.spinner("Building knowledge base..."):
index = build_index(repo_hash)
if not index:
return
engine = index.as_query_engine(llm=llm)
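        # NOTE: engine.query() retrieves indexed code context before calling the
        # LLM, while the llm.stream_complete() calls below go straight to the
        # model with no retrieved context.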
# === 1. Overview (Streamed) ===
with st.expander("πŸ“˜ 1. Project Overview", expanded=True):
placeholder = st.empty()
streamed_text = ""
for token in llm.stream_complete(
"Analyze the codebase and summarize:\n"
"- Project name\n- Description\n- Tech stack\n- Entry point\n- Folder structure overview."
):
streamed_text += token.text
placeholder.markdown(streamed_text)
st.session_state.overview = streamed_text
# === 2. Generate README (Streamed) ===
with st.expander("🧾 2. Generate README.md", expanded=True):
placeholder = st.empty()
streamed_text = ""
for token in llm.stream_complete(
f"Using this project overview:\n{st.session_state.overview}\n\n"
"Generate a **professional and structured README.md** including:\n"
"- # Title\n- ## Description\n- ## Features\n- ## Installation\n"
"- ## Usage\n- ## API Reference\n- ## Folder Structure\n- ## Contributing\n- ## License\n"
"Ensure Markdown syntax is perfect with spacing and headers."
):
streamed_text += token.text
placeholder.markdown(streamed_text)
st.session_state.readme = streamed_text
# === 3. Verification & Auto-Fix (normal) ===
with st.expander("πŸ” 3. Self-Verification & Auto-Fix", expanded=True):
check = to_text(engine.query(
f"README:\n{st.session_state.readme}\n\n"
"Review all code files and verify README accuracy.\n"
"If issues are found, summarize them clearly. Otherwise say 'ALL CORRECT'."
))
st.markdown(check)
if "all correct" not in check.lower():
st.warning("Fixing README automatically...")
fixed = to_text(engine.query(
f"Fix and improve the README.md based on these verification results:\n{check}\n\n"
f"Original README:\n{st.session_state.readme}\n\n"
"Ensure the final version is perfectly formatted Markdown."
))
st.success("βœ… Fixed README generated!")
st.markdown("**Final README.md:**")
st.markdown(fixed)
st.session_state.readme_fixed = fixed
else:
st.session_state.readme_fixed = st.session_state.readme
st.success("βœ… README verified as correct!")
# === 4. Architecture Diagram (stream optional) ===
with st.expander("🧩 4. Architecture Diagram", expanded=True):
placeholder = st.empty()
streamed_text = ""
for token in llm.stream_complete(
"Generate a **Mermaid** flowchart of the application's architecture:\n"
"- Components and relationships\n- Data flow\n- APIs / Services / DB\n"
"Return only valid Markdown with ```mermaid code block."
):
streamed_text += token.text
placeholder.code(streamed_text, language="mermaid")
diag = streamed_text
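            # diag may still contain the ```mermaid fences requested in the
            # prompt; strip them if the .mmd file is fed directly to Mermaid tooling.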
# === 5. Export ===
os.makedirs(OUTPUT_DIR, exist_ok=True)
readme_original_path = Path(f"{OUTPUT_DIR}/README_original.md")
readme_fixed_path = Path(f"{OUTPUT_DIR}/README_final.md")
diagram_path = Path(f"{OUTPUT_DIR}/ARCHITECTURE.mmd")
        readme_original_path.write_text(st.session_state.readme, encoding="utf-8")
        readme_fixed_path.write_text(st.session_state.readme_fixed, encoding="utf-8")
        diagram_path.write_text(diag, encoding="utf-8")
st.success(f"πŸ“ Exported all files β†’ `{OUTPUT_DIR}/`")
# Download buttons
st.markdown("### πŸ“₯ Download Your Files")
for label, path, fname, mime in [
("⬇️ Download Final README.md", readme_fixed_path, "README.md", "text/markdown"),
("⬇️ Download Original README.md", readme_original_path, "README_original.md", "text/markdown"),
("⬇️ Download Architecture Diagram (.mmd)", diagram_path, "ARCHITECTURE.mmd", "text/plain"),
]:
with open(path, "rb") as f:
st.download_button(label=label, data=f, file_name=fname, mime=mime)
st.info("βœ… You can also find these files saved in the `output/` folder locally.")
# --------------------------------------------------------------
if __name__ == "__main__":
main()