import zipfile
import os
import hashlib
import requests
from pathlib import Path
import streamlit as st
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.llms import CustomLLM, LLMMetadata, CompletionResponse
import logging
from typing import List, Any, Generator, AsyncGenerator
import shutil
import json  # ✅ added for parsing streaming chunks

# --------------------------------------------------------------
# 1. SILENCE NOISY LOGS
# --------------------------------------------------------------
logging.getLogger("llama_index").setLevel(logging.CRITICAL)

# --------------------------------------------------------------
# 2. IMPORT FAST EMBEDDER (ModernBERT CPU-Compatible)
# --------------------------------------------------------------
from embedder import embedder  # <-- CPU ModernBERT embedder (not MLX)

# --------------------------------------------------------------
# 3. LLAMA-INDEX EMBEDDING WRAPPER
# --------------------------------------------------------------
from llama_index.core.embeddings import BaseEmbedding


class LlamaIndexWrapper(BaseEmbedding):
    """Adapts the local ModernBERT embedder to LlamaIndex's BaseEmbedding API."""

    def __init__(self, dim: int = 768, embed_batch_size: int = 16):
        # embed_batch_size is a native BaseEmbedding field; configure batching
        # on the embed model itself rather than on VectorStoreIndex.
        super().__init__(embed_batch_size=embed_batch_size)
        self._dimension = dim

    def _get_query_embedding(self, query: str) -> List[float]:
        return embedder.embed_query(query)

    def _get_text_embedding(self, text: str) -> List[float]:
        return embedder.embed_query(text)

    def _get_text_embedding_batch(self, texts: List[str], **kwargs: Any) -> List[List[float]]:
        return embedder.embed_documents(texts)

    async def _aget_query_embedding(self, query: str) -> List[float]:
        return self._get_query_embedding(query)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        return self._get_text_embedding(text)

    async def _aget_text_embedding_batch(self, texts: List[str], **kwargs: Any) -> List[List[float]]:
        return self._get_text_embedding_batch(texts, **kwargs)

    @property
    def dimension(self) -> int:
        return self._dimension


embed_model = LlamaIndexWrapper(dim=768)

# --------------------------------------------------------------
# 4. CONFIG
# --------------------------------------------------------------
TEMP_DIR = "temp_repo"
OUTPUT_DIR = "output"
LLM_API = os.getenv("BaseURL")  # <- optional LM Studio endpoint (currently unused; HFChatLLM hardcodes its URL)

# --------------------------------------------------------------
# 5. HELPER - convert any Response to string
# --------------------------------------------------------------
def to_text(resp):
    """Convert LlamaIndex Response or string-like objects safely to text."""
    if resp is None:
        return ""
    if hasattr(resp, "response"):
        return resp.response
    if hasattr(resp, "text"):
        return resp.text
    return str(resp)

# --------------------------------------------------------------
# 6. AUTO-DETECT MODEL (optional – HF Space only has one)
# --------------------------------------------------------------
def get_hf_model():
    return "openai/gpt-4o"  # Fixed for your endpoint

# --------------------------------------------------------------
# 7. CUSTOM LLM – HUGGING FACE SPACE (streaming + auth)
# --------------------------------------------------------------
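# The client below targets an OpenAI-compatible chat-completions endpoint and
# assumes the standard SSE stream format, i.e. lines such as:
#   data: {"choices": [{"delta": {"content": "Hel"}}]}
#   data: [DONE]
# If the Space proxies a different schema, adjust stream_complete to match.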
class HFChatLLM(CustomLLM):
    model_name: str
    temperature: float = 0.7
    context_window: int = 32768
    num_output: int = -1  # -1 → fall back to 2048 max_tokens in requests
    model_config = {"extra": "allow"}

    def __init__(self, model_name: str, temperature: float = 0.7):
        super().__init__(model_name=model_name, temperature=temperature)
        self.base_url = "https://siddhjagani-backend.hf.space/v1/chat/completions"
        self.headers = {
            # Sent as-is; some OpenAI-compatible endpoints expect a
            # "Bearer <key>" prefix instead of the raw token.
            "Authorization": os.getenv("API_KEY"),
            "Content-Type": "application/json",
        }

    @property
    def metadata(self) -> LLMMetadata:
        return LLMMetadata(context_window=self.context_window, num_output=self.num_output)

    # ------------------------------------------------------------------
    # SYNC COMPLETE
    # ------------------------------------------------------------------
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        payload = {
            "model": self.model_name,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": self.temperature,
            "max_tokens": self.num_output if self.num_output > 0 else 2048,
            "stream": False,
            **kwargs,
        }
        try:
            resp = requests.post(self.base_url, headers=self.headers, json=payload, timeout=300)
            resp.raise_for_status()
            text = resp.json()["choices"][0]["message"]["content"]
        except Exception as e:
            text = f"[HF LLM Error]: {e}"
        return CompletionResponse(text=text)

    # ------------------------------------------------------------------
    # STREAMING (token-by-token)
    # ------------------------------------------------------------------
    def stream_complete(self, prompt: str, **kwargs: Any) -> Generator[CompletionResponse, None, None]:
        payload = {
            "model": self.model_name,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": self.temperature,
            "max_tokens": self.num_output if self.num_output > 0 else 2048,
            "stream": True,
            **kwargs,
        }
        try:
            with requests.post(
                self.base_url, headers=self.headers, json=payload, stream=True, timeout=300
            ) as resp:
                resp.raise_for_status()
                for line in resp.iter_lines(decode_unicode=True):
                    if not line or not line.startswith("data: "):
                        continue
                    data = line.removeprefix("data: ").strip()
                    if data == "[DONE]":
                        break
                    try:
                        chunk = json.loads(data)
                        delta = chunk["choices"][0]["delta"].get("content", "")
                        if delta:
                            yield CompletionResponse(text=delta)
                    except Exception:
                        continue  # skip malformed SSE chunks
        except Exception as e:
            yield CompletionResponse(text=f"[Streaming Error]: {e}\n")

    async def acomplete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        return self.complete(prompt, **kwargs)

    async def astream_complete(self, prompt: str, **kwargs: Any) -> AsyncGenerator[CompletionResponse, None]:
        # Fallback: yields the full completion as one chunk (no true async streaming).
        yield self.complete(prompt, **kwargs)

# --------------------------------------------------------------
# 8. EXTRACT ZIP (clean old files first)
# --------------------------------------------------------------
def extract_repo(zip_path: str):
    os.makedirs(TEMP_DIR, exist_ok=True)
    # Wipe any previous extraction so stale files don't leak into the index.
    for item in os.listdir(TEMP_DIR):
        p = os.path.join(TEMP_DIR, item)
        if os.path.isdir(p):
            shutil.rmtree(p)
        else:
            os.unlink(p)
    with zipfile.ZipFile(zip_path, "r") as z:
        z.extractall(TEMP_DIR)
    st.success(f"✅ Extracted → `{TEMP_DIR}`")

# --------------------------------------------------------------
# 9. BUILD INDEX (NO CACHING)
# --------------------------------------------------------------
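# build_index below always rebuilds from scratch. If persistence is ever
# wanted, a minimal sketch using llama_index's storage API (assuming the
# default "storage" directory) would be:
#   index.storage_context.persist(persist_dir="storage")
#   ...then on a later run...
#   from llama_index.core import StorageContext, load_index_from_storage
#   storage_context = StorageContext.from_defaults(persist_dir="storage")
#   index = load_index_from_storage(storage_context, embed_model=embed_model)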
def build_index(_repo_hash: str):
    if not os.path.isdir(TEMP_DIR) or not os.listdir(TEMP_DIR):
        st.error("No files extracted!")
        return None

    # Remove stale on-disk storage; the index is always rebuilt from scratch.
    storage_dir = "storage"
    if os.path.exists(storage_dir):
        shutil.rmtree(storage_dir)

    docs = SimpleDirectoryReader(
        TEMP_DIR,
        recursive=True,
        exclude=[
            "*.test.py", "*__pycache__*", "*.pyc", "*.log",
            "*.mp3", "*.wav", "*.m4a", "*.mp4", "*.mov", "*.avi", "*.flac", "*.ogg",
            "node_modules", ".git", ".venv", "*.md",
        ],
    ).load_data()

    if not docs:
        st.error("No documents loaded!")
        return None

    splitter = SentenceSplitter(chunk_size=1024, chunk_overlap=50)
    nodes = splitter.get_nodes_from_documents(docs)
    # Embedding batch size is configured on the embed model (LlamaIndexWrapper).
    index = VectorStoreIndex(nodes, embed_model=embed_model)
    st.success("✅ Index built fresh (no caching)!")
    return index

# --------------------------------------------------------------
# 10. MAIN STREAMLIT APP
# --------------------------------------------------------------
def main():
    st.title("🤖 AI Codebase → Docs Agent (ModernBERT CPU + HF Space Streaming)")

    # Fixed model
    selected_model = get_hf_model()
    st.info(f"**Using OpenAI's model:** `{selected_model}`")

    # Create LLM
    llm = HFChatLLM(model_name=selected_model, temperature=0.7)

    uploaded = st.file_uploader("📦 Upload GitHub Repo (.zip)", type="zip")
    if not uploaded:
        st.info("Upload a .zip → Click **Start Analysis**")
        return

    zip_path = "repo.zip"
    with open(zip_path, "wb") as f:
        f.write(uploaded.getbuffer())

    if st.button("🚀 Start Analysis"):
        with st.spinner("Extracting repository..."):
            extract_repo(zip_path)
            repo_hash = hashlib.md5(Path(zip_path).read_bytes()).hexdigest()

        with st.spinner("Building knowledge base..."):
            index = build_index(repo_hash)
        if not index:
            return

        engine = index.as_query_engine(llm=llm)

        # === 1. Overview (Streamed) ===
        with st.expander("📘 1. Project Overview", expanded=True):
            placeholder = st.empty()
            streamed_text = ""
            # NOTE: this streams directly from the LLM; the prompt carries no
            # retrieved code context (only engine.query does retrieval).
            for token in llm.stream_complete(
                "Analyze the codebase and summarize:\n"
                "- Project name\n- Description\n- Tech stack\n- Entry point\n- Folder structure overview."
            ):
                streamed_text += token.text
                placeholder.markdown(streamed_text)
            st.session_state.overview = streamed_text

        # === 2. Generate README (Streamed) ===
        with st.expander("🧾 2. Generate README.md", expanded=True):
            placeholder = st.empty()
            streamed_text = ""
            for token in llm.stream_complete(
                f"Using this project overview:\n{st.session_state.overview}\n\n"
                "Generate a **professional and structured README.md** including:\n"
                "- # Title\n- ## Description\n- ## Features\n- ## Installation\n"
                "- ## Usage\n- ## API Reference\n- ## Folder Structure\n- ## Contributing\n- ## License\n"
                "Ensure Markdown syntax is perfect with spacing and headers."
            ):
                streamed_text += token.text
                placeholder.markdown(streamed_text)
            st.session_state.readme = streamed_text
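        # The self-review below hinges on the model answering with the literal
        # phrase "ALL CORRECT". A stricter, deterministic check (hypothetical
        # helper, not part of this app) could cross-reference the README
        # against files on disk, e.g.:
        #   def readme_mentions_entrypoint(readme: str) -> bool:
        #       return any(p.name in readme for p in Path(TEMP_DIR).rglob("*.py"))
        # Here we keep the simpler LLM self-review.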
        # === 3. Verification & Auto-Fix (normal) ===
        with st.expander("🔍 3. Self-Verification & Auto-Fix", expanded=True):
            check = to_text(engine.query(
                f"README:\n{st.session_state.readme}\n\n"
                "Review all code files and verify README accuracy.\n"
                "If issues are found, summarize them clearly. Otherwise say 'ALL CORRECT'."
            ))
            st.markdown(check)

            if "all correct" not in check.lower():
                st.warning("Fixing README automatically...")
                fixed = to_text(engine.query(
                    f"Fix and improve the README.md based on these verification results:\n{check}\n\n"
                    f"Original README:\n{st.session_state.readme}\n\n"
                    "Ensure the final version is perfectly formatted Markdown."
                ))
                st.success("✅ Fixed README generated!")
                st.markdown("**Final README.md:**")
                st.markdown(fixed)
                st.session_state.readme_fixed = fixed
            else:
                st.session_state.readme_fixed = st.session_state.readme
                st.success("✅ README verified as correct!")

        # === 4. Architecture Diagram (stream optional) ===
        with st.expander("🧩 4. Architecture Diagram", expanded=True):
            placeholder = st.empty()
            streamed_text = ""
            for token in llm.stream_complete(
                "Generate a **Mermaid** flowchart of the application's architecture:\n"
                "- Components and relationships\n- Data flow\n- APIs / Services / DB\n"
                "Return only valid Markdown with ```mermaid code block."
            ):
                streamed_text += token.text
                placeholder.code(streamed_text, language="mermaid")
            diag = streamed_text

        # === 5. Export ===
        os.makedirs(OUTPUT_DIR, exist_ok=True)
        readme_original_path = Path(f"{OUTPUT_DIR}/README_original.md")
        readme_fixed_path = Path(f"{OUTPUT_DIR}/README_final.md")
        diagram_path = Path(f"{OUTPUT_DIR}/ARCHITECTURE.mmd")

        readme_original_path.write_text(st.session_state.readme, encoding="utf-8")
        readme_fixed_path.write_text(st.session_state.readme_fixed, encoding="utf-8")
        diagram_path.write_text(diag, encoding="utf-8")

        st.success(f"📁 Exported all files → `{OUTPUT_DIR}/`")

        # Download buttons
        st.markdown("### 📥 Download Your Files")
        for label, path, fname, mime in [
            ("⬇️ Download Final README.md", readme_fixed_path, "README.md", "text/markdown"),
            ("⬇️ Download Original README.md", readme_original_path, "README_original.md", "text/markdown"),
            ("⬇️ Download Architecture Diagram (.mmd)", diagram_path, "ARCHITECTURE.mmd", "text/plain"),
        ]:
            with open(path, "rb") as f:
                st.download_button(label=label, data=f, file_name=fname, mime=mime)

        st.info("✅ You can also find these files saved in the `output/` folder locally.")

# --------------------------------------------------------------
if __name__ == "__main__":
    main()
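# Usage (assuming this file is saved as app.py):
#   export API_KEY=<token for the HF Space endpoint>
#   streamlit run app.py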