__import__('pysqlite3') import sys sys.modules['sqlite3'] = sys.modules.pop('pysqlite3') import streamlit as st import plotly.graph_objects as go import re from langchain.agents import AgentExecutor, create_tool_calling_agent from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain.tools import tool from github import Github st.set_page_config(page_title="AI Code Review Agent", page_icon="🔬", layout="wide", initial_sidebar_state="expanded") st.markdown(""" """, unsafe_allow_html=True) # ─── Session state init ────────────────────────────────────────── for key in ["saved_provider", "saved_model", "saved_api_key", "saved_github_token"]: if key not in st.session_state: st.session_state[key] = "" # ─── LLM factory ───────────────────────────────────────────────── def get_llm(provider, api_key, model_name): if provider == "Groq": from langchain_groq import ChatGroq return ChatGroq(model=model_name, api_key=api_key, temperature=0) elif provider == "Google Gemini": from langchain_google_genai import ChatGoogleGenerativeAI return ChatGoogleGenerativeAI(model=model_name, google_api_key=api_key, temperature=0) elif provider == "OpenAI": from langchain_openai import ChatOpenAI return ChatOpenAI(model=model_name, api_key=api_key, temperature=0) else: raise ValueError(f"Unknown provider: {provider}") # ─── Diff helpers ───────────────────────────────────────────────── def render_diff_html(patch_text): if not patch_text: return "
No diff available
" html = "" for line in patch_text.split("\n")[:60]: if line.startswith("+") and not line.startswith("+++"): html += f'✓ API key active from session
', unsafe_allow_html=True) if st.session_state.saved_github_token: st.markdown('✓ GitHub token active from session
', unsafe_allow_html=True) st.markdown("---") st.markdown("**🔄 Agent Flow**") st.markdown("1. Fetch PR metadata\n2. Fetch code diff\n3. Analyze changes\n4. Generate review") st.caption("Built with LangChain + Streamlit") # ─── Main UI ────────────────────────────────────────────────────── st.markdown('🔬 AI Code Review Agent
', unsafe_allow_html=True) st.markdown('Autonomous PR analysis — paste any GitHub PR URL below
', unsafe_allow_html=True) st.markdown("---") pr_url = st.text_input("GitHub PR URL", placeholder="https://github.com/owner/repo/pull/123", label_visibility="collapsed") col_btn, col_info = st.columns([1, 3]) with col_btn: analyze_btn = st.button("🚀 Analyze PR") with col_info: st.caption("The agent autonomously decides which tools to call and in what order.") # Use saved values as fallback active_api_key = api_key or st.session_state.saved_api_key active_github_token = github_token or st.session_state.saved_github_token active_model = model_name or st.session_state.saved_model active_provider = provider # ─── Agent execution ────────────────────────────────────────────── if analyze_btn: if not active_api_key or not active_github_token: st.error("⚠️ Enter API keys in the sidebar and click 💾 Save Configuration.") elif not active_model: st.error("⚠️ Enter a model name in the sidebar.") elif not pr_url: st.error("⚠️ Enter a PR URL.") else: @st.cache_resource def load_github_client(token): return Github(token) github_client = load_github_client(active_github_token) files_data = [] @tool def get_pr_metadata(pr_url: str) -> str: """Get the title, description and author of a GitHub PR""" try: parts = pr_url.strip("/").split("/") owner, repo_name, pr_number = parts[-4], parts[-3], int(parts[-1]) repo = github_client.get_repo(f"{owner}/{repo_name}") pr = repo.get_pull(pr_number) return f"Title: {pr.title}\nAuthor: {pr.user.login}\nDescription: {pr.body}\nBase: {pr.base.ref} -> Head: {pr.head.ref}" except Exception as e: return f"Error: {str(e)}" @tool def get_pr_diff(pr_url: str) -> str: """Get the code changes (diff) from a GitHub PR""" try: parts = pr_url.strip("/").split("/") owner, repo_name, pr_number = parts[-4], parts[-3], int(parts[-1]) repo = github_client.get_repo(f"{owner}/{repo_name}") pr = repo.get_pull(pr_number) diff_text = "" for file in pr.get_files(): files_data.append({ "name": file.filename.split("/")[-1], "additions": file.additions, "deletions": file.deletions, "patch": file.patch or "" }) diff_text += f"\nFile: {file.filename}\n+{file.additions} -{file.deletions}\n" if file.patch: diff_text += file.patch + "\n" return diff_text[:4000] except Exception as e: return f"Error: {str(e)}" @tool def get_file_content(repo_full_name: str, file_path: str) -> str: """Get the full content of a file from a GitHub repository""" try: repo = github_client.get_repo(repo_full_name) content = repo.get_contents(file_path) return content.decoded_content.decode("utf-8")[:2000] except Exception as e: return f"Error: {str(e)}" tools = [get_pr_metadata, get_pr_diff, get_file_content] try: llm = get_llm(active_provider, active_api_key, active_model) except Exception as e: st.error(f"Failed to initialize LLM: {e}") st.stop() prompt = ChatPromptTemplate.from_messages([ ("system", """You are an expert code reviewer. When given a GitHub PR URL: 1. Fetch PR metadata to understand its purpose 2. Fetch the PR diff to see what changed 3. Analyze for: bugs, security issues, code quality, naming problems 4. Return structured review with this exact format for each issue: ISSUE: [short title] FILE: [filename] SEVERITY: [High/Medium/Low] PROBLEM: [what is wrong] FIX: [how to fix it] --- IMPORTANT: You MUST put --- on its own line between every single issue. No exceptions. End with a SUMMARY section."""), ("human", "{input}"), MessagesPlaceholder(variable_name="agent_scratchpad") ]) agent = create_tool_calling_agent(llm, tools, prompt) agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=False, max_iterations=10) status_box = st.empty() with status_box.container(): st.info("🤖 Agent is analyzing the PR... this may take 15–30 seconds.") try: result = agent_executor.invoke({"input": f"Review this PR: {pr_url}"}) status_box.empty() output = result["output"] import gc gc.collect() high = output.upper().count("SEVERITY: HIGH") medium = output.upper().count("SEVERITY: MEDIUM") low = output.upper().count("SEVERITY: LOW") total = high + medium + low st.markdown("### 📊 Review Summary") m1, m2, m3, m4 = st.columns(4) with m1: st.markdown(f'{block}'
f'{output}'
f'