# aneeb15: Prepare for Hugging Face deployment (66f174c)
__import__('pysqlite3')
import sys
sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
import streamlit as st
import plotly.graph_objects as go
import re
from langchain.agents import AgentExecutor, create_tool_calling_agent
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.tools import tool
from github import Github
# Page-level Streamlit settings. The icon literal was mis-encoded
# (UTF-8 bytes decoded with a single-byte codec); it is the microscope emoji.
st.set_page_config(
    page_title="AI Code Review Agent",
    page_icon="🔬",
    layout="wide",
    initial_sidebar_state="expanded",
)
# Inject the app-wide stylesheet as raw HTML (hence unsafe_allow_html=True).
# It defines the dark colour scheme plus the CSS classes referenced by the
# HTML snippets built elsewhere in this file: review cards and severity
# badges (.review-card, .high-card/.medium-card/.low-card, .badge-*), diff
# rows (.diff-add/.diff-remove/.diff-ctx), the title/subtitle, the metric
# boxes and the "saved" badge.
st.markdown("""
<style>
.stApp { background-color: #0d1117; color: #e6edf3; }
[data-testid="stSidebar"] { background-color: #161b22; border-right: 1px solid #30363d; }
.stTextInput input { background-color: #21262d !important; color: #e6edf3 !important; border: 1px solid #30363d !important; border-radius: 6px !important; }
.stButton button { background: linear-gradient(135deg, #238636, #2ea043) !important; color: white !important; border: none !important; border-radius: 6px !important; font-weight: 600 !important; width: 100%; }
.stButton button:hover { opacity: 0.85 !important; }
.review-card { background: #161b22; border: 1px solid #30363d; border-radius: 12px; padding: 1.2rem 1.5rem; margin-bottom: 1rem; }
.high-card { border-left: 4px solid #f85149; }
.medium-card { border-left: 4px solid #d29922; }
.low-card { border-left: 4px solid #3fb950; }
.badge { display: inline-block; padding: 2px 10px; border-radius: 20px; font-size: 0.75rem; font-weight: 700; margin-bottom: 6px; }
.badge-high { background: #3d1a1a; color: #f85149; border: 1px solid #f85149; }
.badge-medium { background: #2d2008; color: #d29922; border: 1px solid #d29922; }
.badge-low { background: #0d2818; color: #3fb950; border: 1px solid #3fb950; }
.diff-add { background: #0d2818; color: #3fb950; padding: 2px 8px; font-family: monospace; font-size: 0.85rem; }
.diff-remove { background: #3d1a1a; color: #f85149; padding: 2px 8px; font-family: monospace; font-size: 0.85rem; }
.diff-ctx { background: #161b22; color: #8b949e; padding: 2px 8px; font-family: monospace; font-size: 0.85rem; }
.main-title { font-size: 2rem; font-weight: 800; background: linear-gradient(90deg, #58a6ff, #3fb950); -webkit-background-clip: text; -webkit-text-fill-color: transparent; }
.subtitle { color: #8b949e; font-size: 0.95rem; }
.metric-box { background: #161b22; border: 1px solid #30363d; border-radius: 10px; padding: 1rem; text-align: center; }
.metric-value { font-size: 2rem; font-weight: 800; }
.metric-label { color: #8b949e; font-size: 0.8rem; }
.saved-badge { color: #3fb950; font-size: 0.8rem; font-weight: 600; }
</style>
""", unsafe_allow_html=True)
# ─── Session state init ──────────────────────────────────────────
# Seed each persisted sidebar setting with an empty string the first time the
# script runs in a session, so later attribute reads never hit a missing key.
_SAVED_CONFIG_KEYS = (
    "saved_provider",
    "saved_model",
    "saved_api_key",
    "saved_github_token",
)
for key in _SAVED_CONFIG_KEYS:
    if key in st.session_state:
        continue
    st.session_state[key] = ""
# ─── LLM factory ─────────────────────────────────────────────────
def get_llm(provider, api_key, model_name):
    """Build the chat model for the selected provider.

    The provider package is imported only when that provider is chosen, so
    the other integrations do not have to be importable.

    Args:
        provider: One of "Groq", "Google Gemini" or "OpenAI".
        api_key: API key passed to the provider's chat model class.
        model_name: Model identifier passed through unchanged.

    Returns:
        A chat model instance created with ``temperature=0``.

    Raises:
        ValueError: If ``provider`` is not one of the supported names.
    """
    shared = {"model": model_name, "temperature": 0}

    if provider == "Groq":
        from langchain_groq import ChatGroq

        return ChatGroq(api_key=api_key, **shared)

    if provider == "Google Gemini":
        from langchain_google_genai import ChatGoogleGenerativeAI

        return ChatGoogleGenerativeAI(google_api_key=api_key, **shared)

    if provider == "OpenAI":
        from langchain_openai import ChatOpenAI

        return ChatOpenAI(api_key=api_key, **shared)

    raise ValueError(f"Unknown provider: {provider}")
# ─── Diff helpers ─────────────────────────────────────────────────
def render_diff_html(patch_text, max_lines=60):
    """Render a unified-diff patch as colour-coded HTML rows.

    Each patch line becomes one ``<div>`` using the ``diff-add``,
    ``diff-remove`` or ``diff-ctx`` CSS class. File header lines
    (``+++`` / ``---``) are treated as context.

    The patch comes from an arbitrary pull request and the result is rendered
    with ``unsafe_allow_html=True``, so every line is HTML-escaped; otherwise
    source containing ``<``, ``>`` or ``&`` would corrupt the page or inject
    markup.

    Args:
        patch_text: The patch text; may be empty or ``None``.
        max_lines: Maximum number of patch lines to render (default 60).

    Returns:
        An HTML string, or a "No diff available" paragraph when there is no
        patch.
    """
    from html import escape  # local import keeps this block self-contained

    if not patch_text:
        return "<p style='color:#8b949e'>No diff available</p>"

    rows = []
    for line in patch_text.split("\n")[:max_lines]:
        if line.startswith("+") and not line.startswith("+++"):
            rows.append(f'<div class="diff-add">+ {escape(line[1:])}</div>')
        elif line.startswith("-") and not line.startswith("---"):
            rows.append(f'<div class="diff-remove">- {escape(line[1:])}</div>')
        else:
            rows.append(f'<div class="diff-ctx">&nbsp; {escape(line)}</div>')
    return "".join(rows)
def build_diff_graph(files_data):
    """Build a bar chart of added and deleted lines per file.

    Additions are plotted as positive bars and deletions as negative bars
    (``barmode="relative"``), so they extend in opposite directions from zero.

    Args:
        files_data: Iterable of dicts with ``"name"``, ``"additions"`` and
            ``"deletions"`` keys.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    names = [f["name"] for f in files_data]
    additions = [f["additions"] for f in files_data]
    deletions = [f["deletions"] for f in files_data]

    fig = go.Figure()
    fig.add_trace(go.Bar(
        name="Additions",
        x=names,
        y=additions,
        marker_color="#3fb950",
        hovertemplate="<b>%{x}</b><br>+%{y} lines<extra></extra>",
    ))
    fig.add_trace(go.Bar(
        name="Deletions",
        x=names,
        y=[-count for count in deletions],
        # y is already negative, so "-%{y}" rendered as "--N lines"; the
        # hover text reads the unsigned count from customdata instead.
        customdata=deletions,
        marker_color="#f85149",
        hovertemplate="<b>%{x}</b><br>-%{customdata} lines<extra></extra>",
    ))
    fig.update_layout(
        barmode="relative",
        paper_bgcolor="#0d1117",
        plot_bgcolor="#161b22",
        font=dict(color="#e6edf3"),
        xaxis=dict(gridcolor="#30363d", tickangle=-30),
        yaxis=dict(gridcolor="#30363d", title="Lines Changed"),
        legend=dict(bgcolor="#161b22"),
        height=350,
        margin=dict(t=50, b=80, l=40, r=20),
        title=dict(text="📊 Changes Per File", font=dict(color="#58a6ff", size=16)),
    )
    return fig
def build_severity_chart(high, medium, low):
    """Build a donut chart of issue counts by severity.

    Args:
        high: Number of high-severity issues.
        medium: Number of medium-severity issues.
        low: Number of low-severity issues.

    Returns:
        A ``plotly.graph_objects.Figure`` containing one pie trace with a
        hole. Negative counts are clamped to zero.
    """
    counts = [max(high, 0), max(medium, 0), max(low, 0)]
    fig = go.Figure(go.Pie(
        labels=["High", "Medium", "Low"],
        values=counts,
        hole=0.6,
        marker=dict(
            colors=["#f85149", "#d29922", "#3fb950"],
            line=dict(color="#0d1117", width=2),
        ),
        hovertemplate="<b>%{label}</b>: %{value} issues<extra></extra>",
    ))
    fig.update_layout(
        paper_bgcolor="#0d1117",
        font=dict(color="#e6edf3"),
        legend=dict(bgcolor="#161b22"),
        height=300,
        margin=dict(t=50, b=20, l=20, r=20),
        # The title emoji was mis-encoded in the original literal.
        title=dict(text="🔴 Issue Severity", font=dict(color="#58a6ff", size=16)),
    )
    return fig
# ─── Sidebar ──────────────────────────────────────────────────────
# Supported providers, in the order shown in the selectbox. Kept in one place
# so the option list and the saved-value lookup cannot drift apart.
_PROVIDERS = ["Groq", "Google Gemini", "OpenAI"]

with st.sidebar:
    st.markdown("## ⚙️ Configuration")

    # Pre-select the saved provider when there is one; otherwise the first.
    _saved_provider = st.session_state.saved_provider
    provider = st.selectbox(
        "🤖 LLM Provider",
        _PROVIDERS,
        index=_PROVIDERS.index(_saved_provider) if _saved_provider in _PROVIDERS else 0,
    )
    model_name = st.text_input(
        "📦 Model Name (type any model)",
        value=st.session_state.saved_model,
        placeholder="e.g. llama-3.3-70b-versatile",
    )
    api_key = st.text_input(
        f"🔑 {provider} API Key",
        value=st.session_state.saved_api_key,
        type="password",
        placeholder="Paste your API key...",
    )
    github_token = st.text_input(
        "🐙 GitHub Token",
        value=st.session_state.saved_github_token,
        type="password",
        placeholder="ghp_...",
    )

    # Copy the current widget values into session state.
    if st.button("💾 Save Configuration"):
        st.session_state.saved_provider = provider
        st.session_state.saved_model = model_name
        st.session_state.saved_api_key = api_key
        st.session_state.saved_github_token = github_token
        st.success("✅ Saved for this session!")

    # Show which credentials are currently held in session state.
    if st.session_state.saved_api_key:
        st.markdown('<p class="saved-badge">✓ API key active from session</p>', unsafe_allow_html=True)
    if st.session_state.saved_github_token:
        st.markdown('<p class="saved-badge">✓ GitHub token active from session</p>', unsafe_allow_html=True)

    st.markdown("---")
    st.markdown("**🔄 Agent Flow**")
    st.markdown("1. Fetch PR metadata\n2. Fetch code diff\n3. Analyze changes\n4. Generate review")
    st.caption("Built with LangChain + Streamlit")
# ─── Main UI ──────────────────────────────────────────────────────
# Page header. The emoji and the em dash were mis-encoded in the original
# literals and are restored here.
st.markdown('<p class="main-title">🔬 AI Code Review Agent</p>', unsafe_allow_html=True)
st.markdown('<p class="subtitle">Autonomous PR analysis — paste any GitHub PR URL below</p>', unsafe_allow_html=True)
st.markdown("---")

# PR URL input; the label is collapsed so only the placeholder is shown.
pr_url = st.text_input(
    "GitHub PR URL",
    placeholder="https://github.com/owner/repo/pull/123",
    label_visibility="collapsed",
)

col_btn, col_info = st.columns([1, 3])
with col_btn:
    analyze_btn = st.button("🚀 Analyze PR")
with col_info:
    st.caption("The agent autonomously decides which tools to call and in what order.")
# Use saved values as fallback
# Resolve the values used for this run: the current widget value when it is
# non-empty, otherwise whatever was saved earlier in the session.
active_provider = provider
active_model = model_name if model_name else st.session_state.saved_model
active_api_key = api_key if api_key else st.session_state.saved_api_key
active_github_token = (
    github_token if github_token else st.session_state.saved_github_token
)
# ─── Agent execution ──────────────────────────────────────────────
from html import escape  # escapes LLM / PR text before it is embedded in HTML

# Matches "<owner>/<repo>/pull/<number>" anywhere in the URL, so trailing
# "/files" or "/commits" segments, query strings and fragments are tolerated.
_PR_URL_RE = re.compile(r"([^/\s]+)/([^/\s]+)/pull/(\d+)")
# Accepts "SEVERITY: High", "SEVERITY: [High]" and "**SEVERITY:** High".
_SEVERITY_RE = re.compile(r"SEVERITY:[\s*\[]*(HIGH|MEDIUM|LOW)", re.IGNORECASE)
_ISSUE_RE = re.compile(r"ISSUE:", re.IGNORECASE)
# A SUMMARY heading at the start of a line, optionally behind markdown markup.
_SUMMARY_RE = re.compile(r"^[ \t#*_>]*(SUMMARY\b)", re.IGNORECASE | re.MULTILINE)
# Upper bounds on the text handed back to the LLM by the tools.
_MAX_DIFF_CHARS = 4000
_MAX_FILE_CHARS = 2000
_CARD_PRE_STYLE = "white-space:pre-wrap;color:#e6edf3;background:transparent;border:none;padding:0;margin:0"

_SYSTEM_PROMPT = """You are an expert code reviewer. When given a GitHub PR URL:
1. Fetch PR metadata to understand its purpose
2. Fetch the PR diff to see what changed
3. Analyze for: bugs, security issues, code quality, naming problems
4. Return structured review with this exact format for each issue:
ISSUE: [short title]
FILE: [filename]
SEVERITY: [High/Medium/Low]
PROBLEM: [what is wrong]
FIX: [how to fix it]
---
IMPORTANT: You MUST put --- on its own line between every single issue. No exceptions.
End with a SUMMARY section."""


@st.cache_resource
def load_github_client(token):
    """Return a PyGithub client for ``token`` (wrapped in ``st.cache_resource``)."""
    return Github(token)


def _parse_pr_url(url):
    """Split a pull-request URL into ``("owner/repo", number)``.

    Raises:
        ValueError: If ``url`` contains no ``owner/repo/pull/<number>`` part.
    """
    match = _PR_URL_RE.search(url or "")
    if match is None:
        raise ValueError(f"Not a GitHub pull request URL: {url!r}")
    owner, repo_name, number = match.groups()
    return f"{owner}/{repo_name}", int(number)


def _build_tools(github_client, files_data):
    """Create the agent's tools, bound to ``github_client``.

    ``files_data`` is a caller-owned list that ``get_pr_diff`` fills with one
    dict per changed file (name, additions, deletions, patch) for the UI.
    Each tool returns an ``"Error: ..."`` string instead of raising, so the
    agent sees the failure as tool output.
    """

    def _load_pull(pr_url):
        repo_full_name, pr_number = _parse_pr_url(pr_url)
        return github_client.get_repo(repo_full_name).get_pull(pr_number)

    @tool
    def get_pr_metadata(pr_url: str) -> str:
        """Get the title, description and author of a GitHub PR"""
        try:
            pr = _load_pull(pr_url)
            return (
                f"Title: {pr.title}\nAuthor: {pr.user.login}\n"
                f"Description: {pr.body}\nBase: {pr.base.ref} -> Head: {pr.head.ref}"
            )
        except Exception as e:
            return f"Error: {str(e)}"

    @tool
    def get_pr_diff(pr_url: str) -> str:
        """Get the code changes (diff) from a GitHub PR"""
        try:
            pr = _load_pull(pr_url)
            collected = []
            chunks = []
            for file in pr.get_files():
                collected.append({
                    "name": file.filename.split("/")[-1],
                    "additions": file.additions,
                    "deletions": file.deletions,
                    "patch": file.patch or "",
                })
                chunks.append(f"\nFile: {file.filename}\n+{file.additions} -{file.deletions}\n")
                if file.patch:
                    chunks.append(file.patch + "\n")
            # Replace rather than extend: the agent may call this tool more
            # than once, which previously duplicated every file in the UI.
            files_data[:] = collected
            return "".join(chunks)[:_MAX_DIFF_CHARS]
        except Exception as e:
            return f"Error: {str(e)}"

    @tool
    def get_file_content(repo_full_name: str, file_path: str) -> str:
        """Get the full content of a file from a GitHub repository"""
        try:
            repo = github_client.get_repo(repo_full_name)
            content = repo.get_contents(file_path)
            return content.decoded_content.decode("utf-8")[:_MAX_FILE_CHARS]
        except Exception as e:
            return f"Error: {str(e)}"

    return [get_pr_metadata, get_pr_diff, get_file_content]


def _render_review(output, files_data):
    """Render metrics, charts, file diffs, issue cards and the summary.

    ``output`` is the agent's final text. It is derived from untrusted PR
    content, so it is HTML-escaped wherever it is placed inside markup that
    is rendered with ``unsafe_allow_html=True``.
    """
    severities = [level.upper() for level in _SEVERITY_RE.findall(output)]
    high = severities.count("HIGH")
    medium = severities.count("MEDIUM")
    low = severities.count("LOW")
    total = high + medium + low

    st.markdown("### 📊 Review Summary")
    metrics = (
        (total, "#58a6ff", "Total Issues"),
        (high, "#f85149", "High Severity"),
        (medium, "#d29922", "Medium Severity"),
        (low, "#3fb950", "Low Severity"),
    )
    for column, (value, color, label) in zip(st.columns(4), metrics):
        with column:
            st.markdown(
                f'<div class="metric-box"><div class="metric-value" style="color:{color}">{value}</div>'
                f'<div class="metric-label">{label}</div></div>',
                unsafe_allow_html=True,
            )
    st.markdown("<br>", unsafe_allow_html=True)

    c1, c2 = st.columns([3, 2])
    with c1:
        if files_data:
            st.plotly_chart(build_diff_graph(files_data), use_container_width=True)
    with c2:
        if total > 0:
            st.plotly_chart(build_severity_chart(high, medium, low), use_container_width=True)

    if files_data:
        st.markdown("### 🗂️ File Changes")
        for f in files_data:
            with st.expander(f"📄 {f['name']} (+{f['additions']} / -{f['deletions']})"):
                st.markdown(render_diff_html(f["patch"]), unsafe_allow_html=True)

    # Find the SUMMARY heading only after the last ISSUE marker, and cut the
    # text there. Previously any issue block that merely contained the word
    # "summary" was dropped, and the summary offset came from upper(), which
    # can differ in length from the original text.
    issue_ends = [m.end() for m in _ISSUE_RE.finditer(output)]
    summary_match = _SUMMARY_RE.search(output, issue_ends[-1] if issue_ends else 0)
    if summary_match:
        review_text = output[:summary_match.start()]
        summary_text = output[summary_match.start(1):].strip()
    else:
        review_text, summary_text = output, ""

    st.markdown("### 🔍 Detailed Review")
    cards_rendered = 0
    for block in re.split(r'---|\n(?=ISSUE:)', review_text):
        block = block.strip()
        if not _ISSUE_RE.search(block):
            continue
        sev_match = _SEVERITY_RE.search(block)
        sev = sev_match.group(1).lower() if sev_match else "low"
        st.markdown(
            f'<div class="review-card {sev}-card">'
            f'<span class="badge badge-{sev}">{sev.upper()}</span>'
            f'<pre style="{_CARD_PRE_STYLE}">{escape(block)}</pre>'
            f'</div>',
            unsafe_allow_html=True,
        )
        cards_rendered += 1

    if cards_rendered == 0:
        # The model ignored the requested format: show its answer verbatim.
        st.markdown(
            f'<div class="review-card low-card">'
            f'<pre style="{_CARD_PRE_STYLE}">{escape(output)}</pre>'
            f'</div>',
            unsafe_allow_html=True,
        )

    if summary_text:
        st.markdown("### 📝 Summary")
        st.markdown(
            f'<div class="review-card" style="border-left:4px solid #58a6ff;white-space:pre-wrap">'
            f'{escape(summary_text)}</div>',
            unsafe_allow_html=True,
        )


if analyze_btn:
    if not active_api_key or not active_github_token:
        st.error("⚠️ Enter API keys in the sidebar and click 💾 Save Configuration.")
    elif not active_model:
        st.error("⚠️ Enter a model name in the sidebar.")
    elif not pr_url:
        st.error("⚠️ Enter a PR URL.")
    elif not _PR_URL_RE.search(pr_url):
        # Fail fast instead of spending an LLM call on a URL no tool can parse.
        st.error("⚠️ That does not look like a PR URL (expected .../owner/repo/pull/123).")
    else:
        github_client = load_github_client(active_github_token)
        files_data = []
        tools = _build_tools(github_client, files_data)

        try:
            llm = get_llm(active_provider, active_api_key, active_model)
        except Exception as e:
            st.error(f"Failed to initialize LLM: {e}")
            st.stop()

        prompt = ChatPromptTemplate.from_messages([
            ("system", _SYSTEM_PROMPT),
            ("human", "{input}"),
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ])
        agent = create_tool_calling_agent(llm, tools, prompt)
        agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=False, max_iterations=10)

        status_box = st.empty()
        with status_box.container():
            st.info("🤖 Agent is analyzing the PR... this may take 15–30 seconds.")

        try:
            result = agent_executor.invoke({"input": f"Review this PR: {pr_url}"})
            status_box.empty()
            output = result["output"]
            import gc
            gc.collect()
            _render_review(output, files_data)
        except Exception as e:
            status_box.empty()
            st.error(f"Agent error: {str(e)}")