__import__("pysqlite3")
import sys

# Replace the stdlib sqlite3 module with pysqlite3 before any library imports it.
# NOTE(review): presumably needed because a dependency requires a newer SQLite
# than the host ships — confirm before removing.
sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")

import gc
import html
import re

import plotly.graph_objects as go
import streamlit as st
from github import Github
from langchain.agents import AgentExecutor, create_tool_calling_agent
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.tools import tool

st.set_page_config(
    page_title="AI Code Review Agent",
    page_icon="🔬",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Styles for the class names used by the HTML snippets rendered below.
st.markdown(
    """
<style>
.main-title { font-size: 2.2rem; font-weight: 700; color: #58a6ff; }
.subtitle { color: #8b949e; margin-bottom: 0.5rem; }
.status-ok { color: #3fb950; font-size: 0.85rem; margin: 0.25rem 0; }
.metric-card { background: #161b22; border: 1px solid #30363d; border-radius: 8px;
               padding: 1rem; text-align: center; }
.metric-value { font-size: 2rem; font-weight: 700; }
.metric-label { color: #8b949e; font-size: 0.85rem; }
.metric-total .metric-value { color: #58a6ff; }
.metric-high .metric-value { color: #f85149; }
.metric-medium .metric-value { color: #d29922; }
.metric-low .metric-value { color: #3fb950; }
.diff-box { font-family: monospace; font-size: 0.8rem; background: #0d1117;
            border-radius: 6px; padding: 0.5rem; overflow-x: auto; }
.diff-add, .diff-del, .diff-ctx { white-space: pre; }
.diff-add { background: rgba(63, 185, 80, 0.15); color: #3fb950; }
.diff-del { background: rgba(248, 81, 73, 0.15); color: #f85149; }
.diff-ctx { color: #8b949e; }
.diff-empty, .diff-more { color: #8b949e; font-style: italic; }
.issue-card { background: #161b22; border: 1px solid #30363d; border-left: 4px solid #30363d;
              border-radius: 6px; padding: 0.75rem 1rem; margin-bottom: 0.75rem; }
.issue-card.high { border-left-color: #f85149; }
.issue-card.medium { border-left-color: #d29922; }
.issue-card.low { border-left-color: #3fb950; }
.badge { font-size: 0.7rem; font-weight: 700; padding: 2px 8px; border-radius: 10px;
         color: #0d1117; }
.badge.high { background: #f85149; }
.badge.medium { background: #d29922; }
.badge.low { background: #3fb950; }
.issue-body { white-space: pre-wrap; font-family: monospace; font-size: 0.85rem;
              margin-top: 0.5rem; color: #e6edf3; }
</style>
""",
    unsafe_allow_html=True,
)

# ─── Constants ────────────────────────────────────────────────────
PROVIDERS = ["Groq", "Google Gemini", "OpenAI"]
DIFF_PREVIEW_LINES = 60  # patch lines shown per file in the "File Changes" expanders
MAX_DIFF_CHARS = 4000  # diff characters handed to the LLM
MAX_FILE_CHARS = 2000  # file-content characters handed to the LLM

# ".../<owner>/<repo>/pull/<number>" anywhere in the URL; tolerates trailing
# segments such as "/files", trailing slashes and query strings.
PR_URL_RE = re.compile(r"/([^/\s?#]+)/([^/\s?#]+)/pull/(\d+)")
# "SEVERITY: High", "**SEVERITY:** High", "SEVERITY: [High]", any case.
SEVERITY_RE = re.compile(
    r"SEVERITY[*_]*:\s*[*_`\[]*\s*(HIGH|MEDIUM|LOW)\b", re.IGNORECASE
)
# Start of the trailing SUMMARY section (optionally a markdown heading / bold).
SUMMARY_RE = re.compile(r"^[ \t#*_]*SUMMARY\b", re.IGNORECASE | re.MULTILINE)
# Issue separators: a line consisting only of dashes, or the newline before an
# "ISSUE:" line. Anchoring to whole lines avoids splitting on "--- a/file".
ISSUE_SPLIT_RE = re.compile(
    r"^[ \t]*-{3,}[ \t]*$|\n(?=[ \t#*_]*ISSUE:)", re.IGNORECASE | re.MULTILINE
)

# ─── Session state init ──────────────────────────────────────────
for key in ["saved_provider", "saved_model", "saved_api_key", "saved_github_token"]:
    if key not in st.session_state:
        st.session_state[key] = ""


# ─── LLM / GitHub factories ──────────────────────────────────────
def get_llm(provider, api_key, model_name):
    """Return a deterministic (temperature 0) LangChain chat model.

    Provider packages are imported lazily so only the selected one must be
    installed.

    Raises:
        ValueError: if *provider* is not one of ``PROVIDERS``.
    """
    if provider == "Groq":
        from langchain_groq import ChatGroq

        return ChatGroq(model=model_name, api_key=api_key, temperature=0)
    if provider == "Google Gemini":
        from langchain_google_genai import ChatGoogleGenerativeAI

        return ChatGoogleGenerativeAI(model=model_name, google_api_key=api_key, temperature=0)
    if provider == "OpenAI":
        from langchain_openai import ChatOpenAI

        return ChatOpenAI(model=model_name, api_key=api_key, temperature=0)
    raise ValueError(f"Unknown provider: {provider}")


@st.cache_resource
def load_github_client(token):
    """Return a PyGithub client for *token*, cached (keyed on the token) across reruns."""
    return Github(token)


# ─── Small helpers ───────────────────────────────────────────────
def parse_pr_url(pr_url):
    """Return ``(repo_full_name, pr_number)`` parsed from a pull-request URL.

    Raises:
        ValueError: if *pr_url* contains no ``/<owner>/<repo>/pull/<n>`` part.
    """
    match = PR_URL_RE.search(pr_url or "")
    if match is None:
        raise ValueError(f"Not a GitHub pull request URL: {pr_url!r}")
    owner, repo_name, number = match.groups()
    return f"{owner}/{repo_name}", int(number)


def _truncate(text, limit):
    """Return *text* cut to *limit* characters, with a marker when something was dropped."""
    if len(text) <= limit:
        return text
    return text[:limit] + f"\n…[truncated {len(text) - limit} characters]"


def _as_text(output):
    """Coerce an agent's ``output`` value to a plain string.

    NOTE(review): some chat models return message content as a list of parts
    (dicts with a "text" key); the rest of the UI needs a single string.
    """
    if isinstance(output, str):
        return output
    if isinstance(output, list):
        return "".join(
            part.get("text", "") if isinstance(part, dict) else str(part) for part in output
        )
    return str(output)


# ─── Diff helpers ────────────────────────────────────────────────
def render_diff_html(patch_text, max_lines=DIFF_PREVIEW_LINES):
    """Return HTML showing the first *max_lines* lines of a unified-diff patch.

    Added/removed lines get the ``diff-add``/``diff-del`` classes; everything
    else (hunk headers, context) gets ``diff-ctx``. Patch text comes from the
    PR author, so every line is HTML-escaped before embedding.
    """
    if not patch_text:
        return '<div class="diff-empty">No diff available</div>'
    lines = patch_text.split("\n")
    rows = []
    for line in lines[:max_lines]:
        if line.startswith("+") and not line.startswith("+++"):
            rows.append(f'<div class="diff-add">+ {html.escape(line[1:])}</div>')
        elif line.startswith("-") and not line.startswith("---"):
            rows.append(f'<div class="diff-del">- {html.escape(line[1:])}</div>')
        else:
            rows.append(f'<div class="diff-ctx">  {html.escape(line)}</div>')
    if len(lines) > max_lines:
        rows.append(f'<div class="diff-more">… {len(lines) - max_lines} more lines</div>')
    return '<div class="diff-box">' + "".join(rows) + "</div>"


def build_diff_graph(files_data):
    """Return a bar chart of additions (up) and deletions (down) per changed file.

    Each entry needs ``name``, ``additions`` and ``deletions``. An optional
    ``path`` is used as the category key so files sharing a basename get
    separate bars; the basename is still what the axis shows.
    """
    keys = [f.get("path", f["name"]) for f in files_data]
    labels = [f["name"] for f in files_data]
    additions = [f["additions"] for f in files_data]
    deletions = [f["deletions"] for f in files_data]

    fig = go.Figure()
    fig.add_trace(go.Bar(
        name="Additions",
        x=keys,
        y=additions,
        marker_color="#3fb950",
        hovertemplate="%{x}<br>+%{y} lines<extra></extra>",
    ))
    fig.add_trace(go.Bar(
        name="Deletions",
        x=keys,
        y=[-d for d in deletions],  # plotted downward
        customdata=deletions,  # positive count for the hover text (avoids "--5")
        marker_color="#f85149",
        hovertemplate="%{x}<br>-%{customdata} lines<extra></extra>",
    ))
    fig.update_layout(
        barmode="relative",
        paper_bgcolor="#0d1117",
        plot_bgcolor="#161b22",
        font=dict(color="#e6edf3"),
        xaxis=dict(gridcolor="#30363d", tickangle=-30, tickmode="array",
                   tickvals=keys, ticktext=labels),
        yaxis=dict(gridcolor="#30363d", title="Lines Changed"),
        legend=dict(bgcolor="#161b22"),
        height=350,
        margin=dict(t=50, b=80, l=40, r=20),
        title=dict(text="📊 Changes Per File", font=dict(color="#58a6ff", size=16)),
    )
    return fig


def build_severity_chart(high, medium, low):
    """Return a donut chart of issue counts by severity (negative counts clamp to 0)."""
    fig = go.Figure(go.Pie(
        labels=["High", "Medium", "Low"],
        values=[max(high, 0), max(medium, 0), max(low, 0)],
        hole=0.6,
        marker=dict(colors=["#f85149", "#d29922", "#3fb950"],
                    line=dict(color="#0d1117", width=2)),
        hovertemplate="%{label}: %{value} issues<extra></extra>",
    ))
    fig.update_layout(
        paper_bgcolor="#0d1117",
        font=dict(color="#e6edf3"),
        legend=dict(bgcolor="#161b22"),
        height=300,
        margin=dict(t=50, b=20, l=20, r=20),
        title=dict(text="🔴 Issue Severity", font=dict(color="#58a6ff", size=16)),
    )
    return fig


# ─── Review parsing ──────────────────────────────────────────────
def split_review(output):
    """Split the agent's review into ``(issues_text, summary_text)``.

    The summary starts at the first line beginning with "SUMMARY". An issue
    body that merely mentions the word is not treated as the summary.
    """
    match = SUMMARY_RE.search(output)
    if match is None:
        return output, ""
    return output[:match.start()], output[match.start():].strip()


def split_issue_blocks(issues_text):
    """Return the stripped issue blocks (those containing "ISSUE:") in *issues_text*."""
    blocks = (block.strip() for block in ISSUE_SPLIT_RE.split(issues_text))
    return [block for block in blocks if "ISSUE:" in block.upper()]


def count_severities(issues_text):
    """Return ``(high, medium, low)`` counts of SEVERITY labels in *issues_text*."""
    counts = {"HIGH": 0, "MEDIUM": 0, "LOW": 0}
    for match in SEVERITY_RE.finditer(issues_text):
        counts[match.group(1).upper()] += 1
    return counts["HIGH"], counts["MEDIUM"], counts["LOW"]


def block_severity(block):
    """Return "high", "medium" or "low" for an issue block (default "low")."""
    match = SEVERITY_RE.search(block)
    return match.group(1).lower() if match else "low"


# ─── Rendering ───────────────────────────────────────────────────
def metric_card(value, label, css_class):
    """Return the HTML for one summary metric tile."""
    return (
        f'<div class="metric-card {css_class}">'
        f'<div class="metric-value">{value}</div>'
        f'<div class="metric-label">{html.escape(label)}</div>'
        "</div>"
    )


def issue_card_html(block):
    """Return an issue card; the LLM-produced *block* text is HTML-escaped."""
    sev = block_severity(block)
    return (
        f'<div class="issue-card {sev}">'
        f'<span class="badge {sev}">{sev.upper()}</span>'
        f'<div class="issue-body">{html.escape(block)}</div>'
        "</div>"
    )


def render_results(output, files_data):
    """Render metrics, charts, per-file diffs, issue cards and the summary."""
    issues_text, summary_text = split_review(output)
    high, medium, low = count_severities(issues_text)
    total = high + medium + low

    st.markdown("### 📊 Review Summary")
    tiles = [
        (total, "Total Issues", "metric-total"),
        (high, "High Severity", "metric-high"),
        (medium, "Medium Severity", "metric-medium"),
        (low, "Low Severity", "metric-low"),
    ]
    for column, (value, label, css_class) in zip(st.columns(len(tiles)), tiles):
        with column:
            st.markdown(metric_card(value, label, css_class), unsafe_allow_html=True)
    st.markdown("<br>", unsafe_allow_html=True)

    chart_col, severity_col = st.columns([3, 2])
    with chart_col:
        if files_data:
            st.plotly_chart(build_diff_graph(files_data), use_container_width=True)
    with severity_col:
        if total > 0:
            st.plotly_chart(build_severity_chart(high, medium, low), use_container_width=True)

    if files_data:
        st.markdown("### 🗂️ File Changes")
        for f in files_data:
            label = f.get("path", f["name"])
            with st.expander(f"📄 {label} (+{f['additions']} / -{f['deletions']})"):
                st.markdown(render_diff_html(f["patch"]), unsafe_allow_html=True)

    st.markdown("### 🔍 Detailed Review")
    issue_blocks = split_issue_blocks(issues_text)
    for block in issue_blocks:
        st.markdown(issue_card_html(block), unsafe_allow_html=True)
    if not issue_blocks and issues_text.strip():
        # The model ignored the requested format: show its free-form review as
        # plain markdown (Streamlit escapes raw HTML when unsafe_allow_html is off).
        st.markdown(issues_text)

    if summary_text:
        st.markdown("### 📝 Summary")
        st.markdown(summary_text)


# ─── Agent ───────────────────────────────────────────────────────
SYSTEM_PROMPT = """You are an expert code reviewer. When given a GitHub PR URL:
1. Fetch PR metadata to understand its purpose
2. Fetch the PR diff to see what changed
3. Analyze for: bugs, security issues, code quality, naming problems
4. Return structured review with this exact format for each issue:

ISSUE: [short title]
FILE: [filename]
SEVERITY: [High/Medium/Low]
PROBLEM: [what is wrong]
FIX: [how to fix it]
---

IMPORTANT: You MUST put --- on its own line between every single issue. No exceptions.
End with a SUMMARY section."""


def run_review(pr_url, provider, api_key, model_name, github_token):
    """Run the tool-calling review agent on *pr_url* and render its report.

    Failures to build the LLM or run the agent are shown with ``st.error``
    rather than raised; tool failures are returned to the agent as text.
    """
    github_client = load_github_client(github_token)
    # Filled by get_pr_diff as a side effect; feeds the charts and diff viewers.
    files_data = []

    @tool
    def get_pr_metadata(pr_url: str) -> str:
        """Get the title, description and author of a GitHub PR"""
        try:
            repo_full_name, pr_number = parse_pr_url(pr_url)
            pr = github_client.get_repo(repo_full_name).get_pull(pr_number)
            return (
                f"Title: {pr.title}\nAuthor: {pr.user.login}\nDescription: {pr.body}\n"
                f"Base: {pr.base.ref} -> Head: {pr.head.ref}"
            )
        except Exception as e:  # surfaced to the agent as tool output
            return f"Error: {e}"

    @tool
    def get_pr_diff(pr_url: str) -> str:
        """Get the code changes (diff) from a GitHub PR"""
        try:
            repo_full_name, pr_number = parse_pr_url(pr_url)
            pr = github_client.get_repo(repo_full_name).get_pull(pr_number)
            # The agent may call this tool more than once; don't duplicate files.
            files_data.clear()
            chunks = []
            for changed in pr.get_files():
                files_data.append({
                    "name": changed.filename.split("/")[-1],
                    "path": changed.filename,
                    "additions": changed.additions,
                    "deletions": changed.deletions,
                    "patch": changed.patch or "",
                })
                chunks.append(
                    f"\nFile: {changed.filename}\n+{changed.additions} -{changed.deletions}\n"
                )
                if changed.patch:
                    chunks.append(changed.patch + "\n")
            return _truncate("".join(chunks), MAX_DIFF_CHARS)
        except Exception as e:
            return f"Error: {e}"

    @tool
    def get_file_content(repo_full_name: str, file_path: str) -> str:
        """Get the full content of a file from a GitHub repository"""
        try:
            content = github_client.get_repo(repo_full_name).get_contents(file_path)
            if isinstance(content, list):
                # file_path is a directory: list its entries instead of failing.
                return "\n".join(item.path for item in content)
            # errors="replace" keeps binary / non-UTF-8 files from raising.
            text = content.decoded_content.decode("utf-8", errors="replace")
            return _truncate(text, MAX_FILE_CHARS)
        except Exception as e:
            return f"Error: {e}"

    tools = [get_pr_metadata, get_pr_diff, get_file_content]

    try:
        llm = get_llm(provider, api_key, model_name)
    except Exception as e:
        st.error(f"Failed to initialize LLM: {e}")
        return

    prompt = ChatPromptTemplate.from_messages([
        ("system", SYSTEM_PROMPT),
        ("human", "{input}"),
        MessagesPlaceholder(variable_name="agent_scratchpad"),
    ])
    agent = create_tool_calling_agent(llm, tools, prompt)
    agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=False, max_iterations=10)

    status_box = st.empty()
    status_box.info("🤖 Agent is analyzing the PR... this may take 15–30 seconds.")
    try:
        result = agent_executor.invoke({"input": f"Review this PR: {pr_url}"})
    except Exception as e:
        status_box.empty()
        st.error(f"Agent error: {e}")
        return
    status_box.empty()

    output = _as_text(result.get("output", ""))
    # NOTE(review): presumably frees agent intermediates on memory-constrained
    # hosts before rendering charts — confirm it is still needed.
    gc.collect()
    render_results(output, files_data)


# ─── Sidebar ─────────────────────────────────────────────────────
with st.sidebar:
    st.markdown("## ⚙️ Configuration")
    saved_provider = st.session_state.saved_provider
    provider = st.selectbox(
        "🤖 LLM Provider",
        PROVIDERS,
        index=PROVIDERS.index(saved_provider) if saved_provider in PROVIDERS else 0,
    )
    model_name = st.text_input(
        "📦 Model Name (type any model)",
        value=st.session_state.saved_model,
        placeholder="e.g. llama-3.3-70b-versatile",
    )
    api_key = st.text_input(
        f"🔑 {provider} API Key",
        value=st.session_state.saved_api_key,
        type="password",
        placeholder="Paste your API key...",
    )
    github_token = st.text_input(
        "🐙 GitHub Token",
        value=st.session_state.saved_github_token,
        type="password",
        placeholder="ghp_...",
    )
    if st.button("💾 Save Configuration"):
        st.session_state.saved_provider = provider
        st.session_state.saved_model = model_name
        st.session_state.saved_api_key = api_key
        st.session_state.saved_github_token = github_token
        st.success("✅ Saved for this session!")
    if st.session_state.saved_api_key:
        st.markdown('<div class="status-ok">✓ API key active from session</div>',
                    unsafe_allow_html=True)
    if st.session_state.saved_github_token:
        st.markdown('<div class="status-ok">✓ GitHub token active from session</div>',
                    unsafe_allow_html=True)
    st.markdown("---")
    st.markdown("**🔄 Agent Flow**")
    st.markdown("1. Fetch PR metadata\n2. Fetch code diff\n3. Analyze changes\n4. Generate review")
    st.caption("Built with LangChain + Streamlit")

# ─── Main UI ─────────────────────────────────────────────────────
st.markdown('<div class="main-title">🔬 AI Code Review Agent</div>', unsafe_allow_html=True)
st.markdown(
    '<div class="subtitle">Autonomous PR analysis — paste any GitHub PR URL below</div>',
    unsafe_allow_html=True,
)
st.markdown("---")

pr_url = st.text_input(
    "GitHub PR URL",
    placeholder="https://github.com/owner/repo/pull/123",
    label_visibility="collapsed",
).strip()
col_btn, col_info = st.columns([1, 3])
with col_btn:
    analyze_btn = st.button("🚀 Analyze PR")
with col_info:
    st.caption("The agent autonomously decides which tools to call and in what order.")

# Values typed in the sidebar win; saved session values are the fallback.
active_api_key = api_key or st.session_state.saved_api_key
active_github_token = github_token or st.session_state.saved_github_token
active_model = model_name or st.session_state.saved_model
active_provider = provider

# ─── Agent execution ─────────────────────────────────────────────
if analyze_btn:
    if not active_api_key or not active_github_token:
        st.error("⚠️ Enter API keys in the sidebar and click 💾 Save Configuration.")
    elif not active_model:
        st.error("⚠️ Enter a model name in the sidebar.")
    elif not pr_url:
        st.error("⚠️ Enter a PR URL.")
    elif PR_URL_RE.search(pr_url) is None:
        st.error("⚠️ That doesn't look like a GitHub PR URL (…/owner/repo/pull/123).")
    else:
        run_review(pr_url, active_provider, active_api_key, active_model, active_github_token)