Your Name
made some changes in code
c080da9
Raw
History Blame Contribute Delete
13.6 kB
import os
import shutil
import time
from langchain_google_genai import ChatGoogleGenerativeAI
import git
import streamlit as st
from dotenv import load_dotenv
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_community.document_loaders import DirectoryLoader, TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_community.chat_message_histories import ChatMessageHistory
load_dotenv()
# ── Page config ───────────────────────────────────────────
st.set_page_config(
page_title="AI Codebase Explainer",
page_icon="πŸ”",
layout="wide"
)
# ── Initialize session state ──────────────────────────────
if "vectorstore" not in st.session_state:
st.session_state.vectorstore = None
st.session_state.history = ChatMessageHistory()
st.session_state.messages = []
st.session_state.repo_name = ""
st.session_state.indexed = False
st.session_state.stats = {}
# ── Load models ───────────────────────────────────────────
@st.cache_resource
@st.cache_resource
def load_models():
# Try Groq first β€” fastest
try:
from langchain_groq import ChatGroq
llm = ChatGroq(
model="llama-3.1-8b-instant",
temperature=0,
max_tokens=500
)
# Test if it works
llm.invoke("hi")
print("Using Groq")
except Exception:
# Fallback to Gemini
from langchain_google_genai import ChatGoogleGenerativeAI
llm = ChatGoogleGenerativeAI(
model="gemini-2.0-flash",
temperature=0,
max_output_tokens=500
)
print("Using Gemini fallback")
embeddings = HuggingFaceEmbeddings(
model_name="sentence-transformers/all-MiniLM-L6-v2"
)
return llm, embeddings
llm, embeddings = load_models()
parser = StrOutputParser()
# ── Core functions ────────────────────────────────────────
def clone_repo(github_url):
repo_name = github_url.rstrip("/").split("/")[-1]
clone_path = f"cloned_repos/{repo_name}"
if os.path.exists(clone_path):
shutil.rmtree(clone_path)
os.makedirs("cloned_repos", exist_ok=True)
git.Repo.clone_from(github_url, clone_path)
return clone_path, repo_name
def load_code_files(repo_path):
extensions = ["py", "js", "ts", "md", "txt", "json", "css", "html"]
all_docs = []
for ext in extensions:
try:
loader = DirectoryLoader(
repo_path,
glob=f"**/*.{ext}",
loader_cls=TextLoader,
loader_kwargs={"encoding": "utf-8"},
silent_errors=True
)
docs = loader.load()
for doc in docs:
doc.metadata["file_name"] = os.path.basename(
doc.metadata.get("source", "unknown")
)
doc.metadata["file_type"] = ext
all_docs.extend(docs)
except Exception:
continue
return all_docs
def split_and_index(all_docs):
from langchain_text_splitters import Language
EXTENSION_TO_LANGUAGE = {
"py": Language.PYTHON,
"js": Language.JS,
"ts": Language.TS,
"jsx": Language.JS,
"tsx": Language.TS,
"java": Language.JAVA,
"cpp": Language.CPP,
"c": Language.CPP,
"go": Language.GO,
"rb": Language.RUBY,
"rs": Language.RUST,
"md": Language.MARKDOWN,
}
all_chunks = []
for doc in all_docs:
ext = doc.metadata.get("file_type", "").lower()
language = EXTENSION_TO_LANGUAGE.get(ext)
if language:
splitter = RecursiveCharacterTextSplitter.from_language(
language=language,
chunk_size=2000,
chunk_overlap=300
)
else:
splitter = RecursiveCharacterTextSplitter(
chunk_size=1500,
chunk_overlap=200
)
all_chunks.extend(splitter.split_documents([doc]))
vectorstore = Chroma.from_documents(
documents=all_chunks,
embedding=embeddings
)
return vectorstore, len(all_docs), len(all_chunks)
def ask_question(question, vectorstore, history):
retriever = vectorstore.as_retriever(
search_type="mmr",
search_kwargs={"k": 8, "fetch_k": 20, "lambda_mult": 0.7}
)
docs = retriever.invoke(question)
context = "\n\n".join([
f"# File: {d.metadata['file_name']}\n{d.page_content}"
for d in docs
])
prompt = ChatPromptTemplate.from_messages([
("system",
"You are an expert code analyst for a GitHub repository.\n"
"Answer questions using the retrieved code chunks below.\n\n"
"Rules:\n"
"- Always name the exact file where you found the answer\n"
"- Prioritize source code files (.py, .js, .ts) over documentation (README, conf.py, setup.py)\n"
"- If implementation is spread across files, piece it together\n"
"- If you see a method name or partial logic, explain what it does\n"
"- NEVER say 'not in codebase' if you found related code or methods\n"
"- Give specific details: method names, parameters, logic flow\n"
"- If truly nothing relevant exists, say what you DID find instead\n\n"
"Code context:\n{context}"),
MessagesPlaceholder(variable_name="history"),
("human", "{question}")
])
chain = prompt | llm | parser
for attempt in range(3):
try:
time.sleep(0.5)
response = chain.invoke({
"context": context,
"history": history.messages,
"question": question
})
history.add_user_message(question)
history.add_ai_message(response)
return response
except Exception as e:
err = str(e).lower()
if "429" in err or "rate limit" in err:
if attempt < 2:
time.sleep(10 * (attempt + 1))
continue
return "⚠️ Rate limit hit. Resets midnight UTC."
elif "401" in err or "invalid api key" in err:
return "⚠️ Invalid API key. Update GROQ_API_KEY in .env"
elif "timeout" in err or "connection" in err:
if attempt < 2:
time.sleep(5)
continue
return "⚠️ Connection timed out. Try again."
else:
return f"⚠️ Error: {str(e)}"
# ── UI ────────────────────────────────────────────────────
st.title("πŸ” AI Codebase Explainer")
st.markdown(
"Paste any **public GitHub repo URL** β€” "
"ask questions about the code in plain English"
)
st.divider()
with st.sidebar:
st.header(" Load Repository")
# ── Quick fill buttons ────────────────────────────────
st.markdown("**Try these:**")
col1, col2 = st.columns(2)
with col1:
if st.button("Spoon-Knife", use_container_width=True):
st.session_state["prefill_url"] = "https://github.com/octocat/Spoon-Knife"
with col2:
if st.button("Flask", use_container_width=True):
st.session_state["prefill_url"] = "https://github.com/pallets/flask"
# ── URL input ─────────────────────────────────────────
default_url = st.session_state.get("prefill_url", "")
github_url = st.text_input(
"GitHub Repository URL",
value=default_url,
placeholder="https://github.com/username/repo"
)
# ── Load button ───────────────────────────────────────
if github_url:
if st.button(
"Load & Index",
use_container_width=True,
type="primary"
):
try:
st.session_state.messages = []
st.session_state.history = ChatMessageHistory()
st.session_state.indexed = False
with st.spinner("Step 1/3: Cloning repository..."):
clone_path, repo_name = clone_repo(github_url)
with st.spinner("Step 2/3: Loading files..."):
all_docs = load_code_files(clone_path)
if not all_docs:
st.error("No readable files found!")
st.stop()
with st.spinner(f"Step 3/3: Indexing {len(all_docs)} files..."):
vectorstore, n_files, n_chunks = split_and_index(all_docs)
st.session_state.vectorstore = vectorstore
st.session_state.repo_name = repo_name
st.session_state.indexed = True
st.session_state.stats = {
"files" : n_files,
"chunks": n_chunks
}
# Clear prefill after successful load
if "prefill_url" in st.session_state:
del st.session_state["prefill_url"]
st.success("βœ… Ready!")
except Exception as e:
st.error(f"Error: {str(e)}")
if st.session_state.indexed:
st.divider()
st.metric("Files", st.session_state.stats["files"])
st.metric("Chunks", st.session_state.stats["chunks"])
st.markdown(f"**Repo:** {st.session_state.repo_name}")
if st.button("πŸ”„ New Repo", use_container_width=True):
st.session_state.vectorstore = None
st.session_state.indexed = False
st.session_state.messages = []
st.session_state.history = ChatMessageHistory()
if "prefill_url" in st.session_state:
del st.session_state["prefill_url"]
st.rerun()
# ── Main area ─────────────────────────────────────────────
if not st.session_state.indexed:
col1, col2, col3 = st.columns(3)
with col1:
st.info("**Step 1**\nPaste GitHub URL")
with col2:
st.info("**Step 2**\nClick Load & Index")
with col3:
st.info("**Step 3**\nAsk questions")
st.divider()
st.markdown("### Example questions")
examples = [
"What does this project do?",
"What are the main files?",
"How does authentication work?",
"Where is the database code?",
"How do I add a new feature?",
"What dependencies does it use?",
]
col1, col2 = st.columns(2)
for i, q in enumerate(examples):
with col1 if i % 2 == 0 else col2:
st.markdown(f"πŸ’¬ *{q}*")
else:
st.subheader(f"πŸ’¬ Ask about `{st.session_state.repo_name}`")
# Quick question buttons
st.markdown("**Quick questions:**")
quick = [
"What does this project do?",
"What are the main files?",
"What dependencies does it use?",
"How is the code structured?",
]
cols = st.columns(4)
for i, q in enumerate(quick):
with cols[i]:
if st.button(q, use_container_width=True, key=f"quick{i}"):
st.session_state.messages.append({
"role": "user", "content": q
})
with st.spinner("Reading code..."):
response = ask_question(
q,
st.session_state.vectorstore,
st.session_state.history
)
st.session_state.messages.append({
"role": "assistant", "content": response
})
st.rerun()
st.divider()
# Chat history
for msg in st.session_state.messages:
with st.chat_message(msg["role"]):
st.markdown(msg["content"])
# Chat input
if question := st.chat_input("Ask anything about the code..."):
st.session_state.messages.append({
"role": "user", "content": question
})
with st.chat_message("user"):
st.markdown(question)
with st.chat_message("assistant"):
with st.spinner("Reading code..."):
response = ask_question(
question,
st.session_state.vectorstore,
st.session_state.history
)
st.markdown(response)
st.session_state.messages.append({
"role": "assistant", "content": response
})
st.rerun()