Spaces:
Sleeping
Sleeping
| __import__('pysqlite3') | |
| import sys | |
| sys.modules['sqlite3'] = sys.modules.pop('pysqlite3') | |
| import streamlit as st | |
| import plotly.graph_objects as go | |
| import re | |
| from langchain.agents import AgentExecutor, create_tool_calling_agent | |
| from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder | |
| from langchain.tools import tool | |
| from github import Github | |
# Page chrome; Streamlit expects this before any other st.* call.
st.set_page_config(
    page_title="AI Code Review Agent",
    page_icon="π¬",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Dark GitHub-style theme plus the CSS classes referenced by the hand-built
# HTML further down (review cards, severity badges, diff rows, metric boxes).
_APP_CSS = """
<style>
.stApp { background-color: #0d1117; color: #e6edf3; }
[data-testid="stSidebar"] { background-color: #161b22; border-right: 1px solid #30363d; }
.stTextInput input { background-color: #21262d !important; color: #e6edf3 !important; border: 1px solid #30363d !important; border-radius: 6px !important; }
.stButton button { background: linear-gradient(135deg, #238636, #2ea043) !important; color: white !important; border: none !important; border-radius: 6px !important; font-weight: 600 !important; width: 100%; }
.stButton button:hover { opacity: 0.85 !important; }
.review-card { background: #161b22; border: 1px solid #30363d; border-radius: 12px; padding: 1.2rem 1.5rem; margin-bottom: 1rem; }
.high-card { border-left: 4px solid #f85149; }
.medium-card { border-left: 4px solid #d29922; }
.low-card { border-left: 4px solid #3fb950; }
.badge { display: inline-block; padding: 2px 10px; border-radius: 20px; font-size: 0.75rem; font-weight: 700; margin-bottom: 6px; }
.badge-high { background: #3d1a1a; color: #f85149; border: 1px solid #f85149; }
.badge-medium { background: #2d2008; color: #d29922; border: 1px solid #d29922; }
.badge-low { background: #0d2818; color: #3fb950; border: 1px solid #3fb950; }
.diff-add { background: #0d2818; color: #3fb950; padding: 2px 8px; font-family: monospace; font-size: 0.85rem; }
.diff-remove { background: #3d1a1a; color: #f85149; padding: 2px 8px; font-family: monospace; font-size: 0.85rem; }
.diff-ctx { background: #161b22; color: #8b949e; padding: 2px 8px; font-family: monospace; font-size: 0.85rem; }
.main-title { font-size: 2rem; font-weight: 800; background: linear-gradient(90deg, #58a6ff, #3fb950); -webkit-background-clip: text; -webkit-text-fill-color: transparent; }
.subtitle { color: #8b949e; font-size: 0.95rem; }
.metric-box { background: #161b22; border: 1px solid #30363d; border-radius: 10px; padding: 1rem; text-align: center; }
.metric-value { font-size: 2rem; font-weight: 800; }
.metric-label { color: #8b949e; font-size: 0.8rem; }
.saved-badge { color: #3fb950; font-size: 0.8rem; font-weight: 600; }
</style>
"""
st.markdown(_APP_CSS, unsafe_allow_html=True)
# ─── Session state init ───────────────────────────────────────────
# Seed every persisted sidebar setting with "" on the first run so later code
# can read st.session_state.<name> without checking that it exists.
_missing_state = [
    _state_name
    for _state_name in ("saved_provider", "saved_model", "saved_api_key", "saved_github_token")
    if _state_name not in st.session_state
]
for _state_name in _missing_state:
    st.session_state[_state_name] = ""
# ─── LLM factory ──────────────────────────────────────────────────
def get_llm(provider, api_key, model_name):
    """Create a chat model for *provider* with temperature fixed at 0.

    The provider SDK is imported only inside the matching branch, so only the
    selected backend's package is needed at call time.

    Args:
        provider: One of "Groq", "Google Gemini" or "OpenAI".
        api_key: Credential passed to the provider client.
        model_name: Model identifier passed through unchanged.

    Raises:
        ValueError: if *provider* is not one of the supported names.
    """
    if provider not in ("Groq", "Google Gemini", "OpenAI"):
        raise ValueError(f"Unknown provider: {provider}")

    if provider == "Google Gemini":
        from langchain_google_genai import ChatGoogleGenerativeAI
        # Gemini's client names the credential argument differently.
        return ChatGoogleGenerativeAI(model=model_name, google_api_key=api_key, temperature=0)

    if provider == "OpenAI":
        from langchain_openai import ChatOpenAI
        return ChatOpenAI(model=model_name, api_key=api_key, temperature=0)

    from langchain_groq import ChatGroq
    return ChatGroq(model=model_name, api_key=api_key, temperature=0)
# ─── Diff helpers ─────────────────────────────────────────────────
def render_diff_html(patch_text, max_lines=60):
    """Render unified-diff text as colour-coded HTML rows.

    Args:
        patch_text: Unified-diff text; ``None`` or ``""`` means no diff.
        max_lines: Maximum number of patch lines to render (default 60);
            ``None`` renders every line.

    Returns:
        Concatenated ``<div>`` rows styled with the ``diff-add`` /
        ``diff-remove`` / ``diff-ctx`` CSS classes, or a placeholder
        paragraph when there is no patch.
    """
    from html import escape

    if not patch_text:
        return "<p style='color:#8b949e'>No diff available</p>"

    rows = []
    for line in patch_text.split("\n")[:max_lines]:
        # Diff content is arbitrary source code and is later rendered with
        # unsafe_allow_html, so "<", ">" and "&" must be escaped to be shown
        # as text instead of being parsed as markup.
        if line.startswith("+") and not line.startswith("+++"):
            rows.append(f'<div class="diff-add">+ {escape(line[1:])}</div>')
        elif line.startswith("-") and not line.startswith("---"):
            rows.append(f'<div class="diff-remove">- {escape(line[1:])}</div>')
        else:
            # Hunk headers, "+++"/"---" file headers and context lines.
            rows.append(f'<div class="diff-ctx"> {escape(line)}</div>')
    return "".join(rows)
def build_diff_graph(files_data):
    """Build a diverging bar chart of added and removed lines per file.

    Args:
        files_data: Sequence of dicts with ``name``, ``additions`` and
            ``deletions`` keys, one per changed file.

    Returns:
        A plotly ``Figure`` with additions drawn upward and deletions drawn
        downward on a shared category axis.
    """
    names = [f["name"] for f in files_data]
    additions = [f["additions"] for f in files_data]
    deletions = [f["deletions"] for f in files_data]

    fig = go.Figure()
    fig.add_trace(go.Bar(
        name="Additions",
        x=names,
        y=additions,
        marker_color="#3fb950",
        hovertemplate="<b>%{x}</b><br>+%{y} lines<extra></extra>",
    ))
    # Deletions are plotted as negative bars. The hover reads the positive
    # count from customdata; using %{y} here would print "--5 lines".
    fig.add_trace(go.Bar(
        name="Deletions",
        x=names,
        y=[-d for d in deletions],
        customdata=deletions,
        marker_color="#f85149",
        hovertemplate="<b>%{x}</b><br>-%{customdata} lines<extra></extra>",
    ))
    fig.update_layout(
        barmode="relative",
        paper_bgcolor="#0d1117",
        plot_bgcolor="#161b22",
        font=dict(color="#e6edf3"),
        xaxis=dict(gridcolor="#30363d", tickangle=-30),
        yaxis=dict(gridcolor="#30363d", title="Lines Changed"),
        legend=dict(bgcolor="#161b22"),
        height=350,
        margin=dict(t=50, b=80, l=40, r=20),
        title=dict(text="π Changes Per File", font=dict(color="#58a6ff", size=16)),
    )
    return fig
def build_severity_chart(high, medium, low):
    """Build a donut chart of review issues by severity.

    Negative counts are clamped to zero before plotting.

    Args:
        high, medium, low: Issue counts for each severity level.

    Returns:
        A plotly ``Figure`` with one slice per severity level.
    """
    slice_values = [max(count, 0) for count in (high, medium, low)]
    donut = go.Pie(
        labels=["High", "Medium", "Low"],
        values=slice_values,
        hole=0.6,
        marker=dict(
            colors=["#f85149", "#d29922", "#3fb950"],
            line=dict(color="#0d1117", width=2),
        ),
        hovertemplate="<b>%{label}</b>: %{value} issues<extra></extra>",
    )
    chart = go.Figure(donut)
    chart.update_layout(
        paper_bgcolor="#0d1117",
        font=dict(color="#e6edf3"),
        legend=dict(bgcolor="#161b22"),
        height=300,
        margin=dict(t=50, b=20, l=20, r=20),
        title=dict(text="π΄ Issue Severity", font=dict(color="#58a6ff", size=16)),
    )
    return chart
# ─── Sidebar ──────────────────────────────────────────────────────
# Provider names understood by get_llm(); also used to restore the saved choice.
_PROVIDER_CHOICES = ["Groq", "Google Gemini", "OpenAI"]

with st.sidebar:
    st.markdown("## βοΈ Configuration")

    _saved_choice = st.session_state.saved_provider
    _choice_index = _PROVIDER_CHOICES.index(_saved_choice) if _saved_choice in _PROVIDER_CHOICES else 0
    provider = st.selectbox("π€ LLM Provider", _PROVIDER_CHOICES, index=_choice_index)

    model_name = st.text_input(
        "π¦ Model Name (type any model)",
        value=st.session_state.saved_model,
        placeholder="e.g. llama-3.3-70b-versatile",
    )
    api_key = st.text_input(
        f"π {provider} API Key",
        value=st.session_state.saved_api_key,
        type="password",
        placeholder="Paste your API key...",
    )
    github_token = st.text_input(
        "π GitHub Token",
        value=st.session_state.saved_github_token,
        type="password",
        placeholder="ghp_...",
    )

    # Copy the current inputs into session state so later reruns can fall back to them.
    if st.button("πΎ Save Configuration"):
        for _state_key, _state_value in (
            ("saved_provider", provider),
            ("saved_model", model_name),
            ("saved_api_key", api_key),
            ("saved_github_token", github_token),
        ):
            st.session_state[_state_key] = _state_value
        st.success("β Saved for this session!")

    # Show a badge for each credential currently held in session state.
    for _state_key, _badge_html in (
        ("saved_api_key", '<p class="saved-badge">β API key active from session</p>'),
        ("saved_github_token", '<p class="saved-badge">β GitHub token active from session</p>'),
    ):
        if st.session_state[_state_key]:
            st.markdown(_badge_html, unsafe_allow_html=True)

    st.markdown("---")
    st.markdown("**π Agent Flow**")
    st.markdown("1. Fetch PR metadata\n2. Fetch code diff\n3. Analyze changes\n4. Generate review")
    st.caption("Built with LangChain + Streamlit")
# ─── Main UI ──────────────────────────────────────────────────────
for _header_html in (
    '<p class="main-title">π¬ AI Code Review Agent</p>',
    '<p class="subtitle">Autonomous PR analysis β paste any GitHub PR URL below</p>',
):
    st.markdown(_header_html, unsafe_allow_html=True)
st.markdown("---")

pr_url = st.text_input(
    "GitHub PR URL",
    placeholder="https://github.com/owner/repo/pull/123",
    label_visibility="collapsed",
)

col_btn, col_info = st.columns([1, 3])
analyze_btn = col_btn.button("π Analyze PR")
col_info.caption("The agent autonomously decides which tools to call and in what order.")

# Settings used for this run: whatever is typed in the sidebar right now,
# falling back to the values stored with "Save Configuration".
active_provider = provider
active_model = model_name or st.session_state.saved_model
active_api_key = api_key or st.session_state.saved_api_key
active_github_token = github_token or st.session_state.saved_github_token
# ─── Agent execution ──────────────────────────────────────────────
import gc
from html import escape as _html_escape

# Finds "<owner>/<repo>/pull/<number>" anywhere in the URL, so trailing parts
# such as "/files", "/commits", "?diff=split" or "#discussion_r1" are ignored.
_PR_URL_RE = re.compile(r"([^/\s]+)/([^/\s]+)/pull/(\d+)")

# SEVERITY markers, tolerant of markdown decoration the model may add,
# e.g. "SEVERITY: High", "**SEVERITY:** High" or "SEVERITY: [High]".
_SEVERITY_RES = {
    level: re.compile(rf"SEVERITY[\s*_:\[]*{level}\b", re.IGNORECASE)
    for level in ("high", "medium", "low")
}

# A line opening the summary section: "SUMMARY", "## Summary", "**SUMMARY:**", ...
_SUMMARY_RE = re.compile(r"^[\s#*_]*SUMMARY\b", re.IGNORECASE | re.MULTILINE)

# Maximum number of diff characters returned to the LLM by get_pr_diff.
_MAX_DIFF_CHARS = 4000

_PRE_STYLE = (
    "white-space:pre-wrap;color:#e6edf3;background:transparent;"
    "border:none;padding:0;margin:0"
)

_SYSTEM_PROMPT = """You are an expert code reviewer. When given a GitHub PR URL:
1. Fetch PR metadata to understand its purpose
2. Fetch the PR diff to see what changed
3. Analyze for: bugs, security issues, code quality, naming problems
4. Return structured review with this exact format for each issue:
ISSUE: [short title]
FILE: [filename]
SEVERITY: [High/Medium/Low]
PROBLEM: [what is wrong]
FIX: [how to fix it]
---
IMPORTANT: You MUST put --- on its own line between every single issue. No exceptions.
End with a SUMMARY section."""


def _parse_pr_url(pr_url):
    """Return ``(owner, repo_name, pr_number)`` parsed from a GitHub PR URL.

    Raises:
        ValueError: if the URL has no ``<owner>/<repo>/pull/<number>`` part.
    """
    match = _PR_URL_RE.search(pr_url or "")
    if match is None:
        raise ValueError(f"Not a GitHub pull request URL: {pr_url!r}")
    owner, repo_name, number = match.groups()
    return owner, repo_name, int(number)


def _make_review_tools(github_client, files_data):
    """Create the LangChain tools the review agent may call.

    Every tool returns a string and reports failures to the agent as
    ``"Error: ..."`` text instead of raising. ``get_pr_diff`` also replaces the
    contents of *files_data* with one dict per changed file (``name``,
    ``additions``, ``deletions``, ``patch``) for the charts and diff viewer.
    """

    def fetch_pull(pr_url):
        owner, repo_name, number = _parse_pr_url(pr_url)
        return github_client.get_repo(f"{owner}/{repo_name}").get_pull(number)

    # @tool wraps each function in a LangChain tool object; AgentExecutor
    # needs tool objects, not plain functions. The docstring becomes the tool
    # description the LLM sees, so it is kept short and unchanged.
    @tool
    def get_pr_metadata(pr_url: str) -> str:
        """Get the title, description and author of a GitHub PR"""
        try:
            pr = fetch_pull(pr_url)
            return (
                f"Title: {pr.title}\nAuthor: {pr.user.login}\n"
                f"Description: {pr.body or ''}\n"
                f"Base: {pr.base.ref} -> Head: {pr.head.ref}"
            )
        except Exception as e:
            return f"Error: {str(e)}"

    @tool
    def get_pr_diff(pr_url: str) -> str:
        """Get the code changes (diff) from a GitHub PR"""
        try:
            pr = fetch_pull(pr_url)
            collected = []
            chunks = []
            for file in pr.get_files():
                collected.append({
                    "name": file.filename.split("/")[-1],
                    "additions": file.additions,
                    "deletions": file.deletions,
                    "patch": file.patch or "",
                })
                chunks.append(f"\nFile: {file.filename}\n+{file.additions} -{file.deletions}\n")
                if file.patch:
                    chunks.append(file.patch + "\n")
            # Replace rather than extend: the agent may call this tool more
            # than once, which would otherwise list every file repeatedly.
            files_data[:] = collected
            diff_text = "".join(chunks)
            if len(diff_text) > _MAX_DIFF_CHARS:
                # Tell the model the diff is incomplete so it can fetch files itself.
                return diff_text[:_MAX_DIFF_CHARS] + "\n... [diff truncated]"
            return diff_text
        except Exception as e:
            return f"Error: {str(e)}"

    @tool
    def get_file_content(repo_full_name: str, file_path: str) -> str:
        """Get the full content of a file from a GitHub repository"""
        try:
            content = github_client.get_repo(repo_full_name).get_contents(file_path)
            return content.decoded_content.decode("utf-8")[:2000]
        except Exception as e:
            return f"Error: {str(e)}"

    return [get_pr_metadata, get_pr_diff, get_file_content]


def _output_text(output):
    """Coerce the agent's final output to ``str``.

    LangChain message content may be a list of parts (strings or dicts with a
    ``"text"`` key) instead of a plain string; ``.upper()`` would fail on that.
    """
    if isinstance(output, str):
        return output
    if isinstance(output, list):
        parts = []
        for part in output:
            if isinstance(part, dict):
                parts.append(str(part.get("text", "")))
            else:
                parts.append(str(part))
        return "".join(parts)
    return str(output)


def _block_severity(block):
    """Severity class for one issue block: "high", else "medium", else "low"."""
    if _SEVERITY_RES["high"].search(block):
        return "high"
    if _SEVERITY_RES["medium"].search(block):
        return "medium"
    return "low"


def _text_card(text, card_class="", badge_html="", style=""):
    """Wrap plain *text* in a review-card ``<div>`` for ``st.markdown(unsafe_allow_html=True)``.

    The text is LLM output that may quote code, so it is HTML-escaped.
    Newlines become ``<br>`` because in CommonMark a blank line ends a raw-HTML
    block, which would spill the rest of the card out as plain markdown.
    """
    body = _html_escape(text).replace("\n", "<br>")
    class_attr = f"review-card {card_class}".strip()
    style_attr = f' style="{style}"' if style else ""
    return (
        f'<div class="{class_attr}"{style_attr}>{badge_html}'
        f'<pre style="{_PRE_STYLE}">{body}</pre></div>'
    )


def _render_metrics(high, medium, low):
    """Show the total and per-severity issue counts as four metric boxes."""
    st.markdown("### π Review Summary")
    boxes = [
        (high + medium + low, "#58a6ff", "Total Issues"),
        (high, "#f85149", "High Severity"),
        (medium, "#d29922", "Medium Severity"),
        (low, "#3fb950", "Low Severity"),
    ]
    for column, (value, color, label) in zip(st.columns(4), boxes):
        with column:
            st.markdown(
                f'<div class="metric-box"><div class="metric-value" style="color:{color}">{value}</div>'
                f'<div class="metric-label">{label}</div></div>',
                unsafe_allow_html=True,
            )
    st.markdown("<br>", unsafe_allow_html=True)


def _render_charts_and_files(files_data, high, medium, low):
    """Show the per-file and severity charts, then one expander per changed file."""
    c1, c2 = st.columns([3, 2])
    with c1:
        if files_data:
            st.plotly_chart(build_diff_graph(files_data), use_container_width=True)
    with c2:
        if high + medium + low > 0:
            st.plotly_chart(build_severity_chart(high, medium, low), use_container_width=True)
    if files_data:
        st.markdown("### ποΈ File Changes")
        for f in files_data:
            with st.expander(f"π {f['name']} (+{f['additions']} / -{f['deletions']})"):
                st.markdown(render_diff_html(f["patch"]), unsafe_allow_html=True)


def _render_review(output):
    """Render the review as one card per ISSUE block plus a summary card."""
    summary_match = _SUMMARY_RE.search(output)
    issues_text = output[:summary_match.start()] if summary_match else output

    st.markdown("### π Detailed Review")
    cards_rendered = 0
    for block in re.split(r"---|\n(?=ISSUE:)", issues_text):
        block = block.strip()
        if "ISSUE:" not in block.upper():
            continue
        sev = _block_severity(block)
        badge = f'<span class="badge badge-{sev}">{sev.upper()}</span>'
        st.markdown(_text_card(block, f"{sev}-card", badge), unsafe_allow_html=True)
        cards_rendered += 1

    # The model did not follow the ISSUE format: show its text as-is.
    if cards_rendered == 0 and issues_text.strip():
        st.markdown(_text_card(issues_text.strip(), "low-card"), unsafe_allow_html=True)

    if summary_match:
        st.markdown("### π Summary")
        st.markdown(
            _text_card(output[summary_match.start():].strip(), style="border-left:4px solid #58a6ff"),
            unsafe_allow_html=True,
        )


if analyze_btn:
    if not active_api_key or not active_github_token:
        st.error("β οΈ Enter API keys in the sidebar and click πΎ Save Configuration.")
    elif not active_model:
        st.error("β οΈ Enter a model name in the sidebar.")
    elif not pr_url:
        st.error("β οΈ Enter a PR URL.")
    elif _PR_URL_RE.search(pr_url) is None:
        st.error("β οΈ Enter a valid GitHub PR URL, e.g. https://github.com/owner/repo/pull/123.")
    else:
        github_client = Github(active_github_token)
        # Filled by the get_pr_diff tool while the agent runs.
        files_data = []
        tools = _make_review_tools(github_client, files_data)

        try:
            llm = get_llm(active_provider, active_api_key, active_model)
        except Exception as e:
            st.error(f"Failed to initialize LLM: {e}")
            st.stop()

        prompt = ChatPromptTemplate.from_messages([
            ("system", _SYSTEM_PROMPT),
            ("human", "{input}"),
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ])
        agent = create_tool_calling_agent(llm, tools, prompt)
        agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=False, max_iterations=10)

        status_box = st.empty()
        status_box.info("π€ Agent is analyzing the PR... this may take 15β30 seconds.")
        try:
            result = agent_executor.invoke({"input": f"Review this PR: {pr_url}"})
            status_box.empty()
            output = _output_text(result["output"])
            # NOTE(review): presumably here to lower memory use before rendering; kept as-is.
            gc.collect()

            high = len(_SEVERITY_RES["high"].findall(output))
            medium = len(_SEVERITY_RES["medium"].findall(output))
            low = len(_SEVERITY_RES["low"].findall(output))

            _render_metrics(high, medium, low)
            _render_charts_and_files(files_data, high, medium, low)
            _render_review(output)
        except Exception as e:
            status_box.empty()
            st.error(f"Agent error: {str(e)}")