# Source: DaCrow13 — "Deploy to HF Spaces (Clean)" (commit 225af6a)
import os
from typing import Dict, List, Optional

import pandas as pd
import requests
import streamlit as st
API_BASE_URL = os.getenv("API_BASE_URL", "http://localhost:8000")
# Page config
st.set_page_config(
page_title="GitHub Skill Classifier", layout="wide", initial_sidebar_state="expanded"
)
st.markdown(
"""
<style>
.main-header {
font-size: 2.5rem;
color: #1f77b4;
text-align: center;
margin-bottom: 2rem;
}
.skill-card {
padding: 1rem;
border-radius: 0.5rem;
border-left: 4px solid #1f77b4;
background-color: #f0f2f6;
margin-bottom: 0.5rem;
}
.confidence-high {
color: #28a745;
font-weight: bold;
}
.confidence-medium {
color: #ffc107;
font-weight: bold;
}
.confidence-low {
color: #dc3545;
font-weight: bold;
}
</style>
""",
unsafe_allow_html=True,
)
def check_api_health() -> bool:
"""Check if the API is running and healthy."""
try:
response = requests.get(f"{API_BASE_URL}/health", timeout=2)
return response.status_code == 200
except Exception:
return False
def predict_skills(
issue_text: str, issue_description: str = None, repo_name: str = None, pr_number: int = None
) -> Dict:
"""Call the prediction API."""
payload = {"issue_text": issue_text}
if issue_description:
payload["issue_description"] = issue_description
if repo_name:
payload["repo_name"] = repo_name
if pr_number:
payload["pr_number"] = pr_number
try:
response = requests.post(f"{API_BASE_URL}/predict", json=payload, timeout=30)
response.raise_for_status()
return response.json()
except requests.exceptions.RequestException as e:
st.error(f"API Error: {str(e)}")
return None
def display_predictions(predictions: List[Dict], threshold: float = 0.5):
"""Display predictions with visual formatting."""
# Filter by threshold
filtered = [p for p in predictions if p["confidence"] >= threshold]
if not filtered:
st.warning(f"No predictions above confidence threshold {threshold:.2f}")
return
st.success(f"Found {len(filtered)} skills above threshold {threshold:.2f}")
# Create DataFrame for table view
df = pd.DataFrame(filtered)
df["confidence"] = df["confidence"].apply(lambda x: f"{x:.2%}")
col1, col2 = st.columns([2, 1])
with col1:
st.subheader("Predictions Table")
st.dataframe(
df,
use_container_width=True,
hide_index=True,
column_config={
"skill_name": st.column_config.TextColumn("Skill", width="large"),
"confidence": st.column_config.TextColumn("Confidence", width="medium"),
},
)
with col2:
st.subheader("Top 5 Skills")
for i, pred in enumerate(filtered[:5], 1):
confidence = pred["confidence"]
if confidence >= 0.8:
conf_class = "confidence-high"
elif confidence >= 0.5:
conf_class = "confidence-medium"
else:
conf_class = "confidence-low"
st.markdown(
f"""
<div class="skill-card">
<strong>#{i} {pred["skill_name"]}</strong><br>
<span class="{conf_class}">{confidence:.2%}</span>
</div>
""",
unsafe_allow_html=True,
)
def main():
"""Main Streamlit app."""
if "example_text" not in st.session_state:
st.session_state.example_text = ""
# Header
st.markdown('<h1 class="main-header"> GitHub Skill Classifier</h1>', unsafe_allow_html=True)
st.markdown("""
This tool uses machine learning to predict the skills required for GitHub issues and pull requests.
Enter the issue text below to get started!
""")
# Sidebar
with st.sidebar:
st.header("Settings")
# API Status
st.subheader("API Status")
if check_api_health():
st.success(" API is running")
else:
st.error(" API is not available")
st.info(f"Make sure FastAPI is running at {API_BASE_URL}")
st.code("fastapi dev hopcroft_skill_classification_tool_competition/main.py")
# Confidence threshold
threshold = st.slider(
"Confidence Threshold",
min_value=0.0,
max_value=1.0,
value=0.5,
step=0.05,
help="Only show predictions above this confidence level",
)
# Model info
st.subheader("Model Info")
try:
health = requests.get(f"{API_BASE_URL}/health", timeout=2).json()
st.metric("Version", health.get("version", "N/A"))
st.metric("Model Loaded", "" if health.get("model_loaded") else "")
except Exception:
st.info("API not available")
# Main
st.header("Input")
# Tabs for different input modes
tab1, tab2, tab3 = st.tabs(["Quick Input", "Detailed Input", "Examples"])
with tab1:
issue_text = st.text_area(
"Issue/PR Text",
height=150,
placeholder="Enter the issue or pull request text here...",
help="Required: The main text of the GitHub issue or PR",
value=st.session_state.example_text,
)
if st.button("Predict Skills", type="primary", use_container_width=True):
if not issue_text.strip():
st.error("Please enter some text!")
else:
st.session_state.example_text = ""
with st.spinner("Analyzing issue..."):
result = predict_skills(issue_text)
if result:
st.header("Results")
# Metadata
col1, col2, col3 = st.columns(3)
with col1:
st.metric("Total Predictions", result.get("num_predictions", 0))
with col2:
st.metric(
"Processing Time", f"{result.get('processing_time_ms', 0):.2f} ms"
)
with col3:
st.metric("Model Version", result.get("model_version", "N/A"))
# Predictions
st.divider()
display_predictions(result.get("predictions", []), threshold)
# Raw JSON
with st.expander("🔍 View Raw Response"):
st.json(result)
with tab2:
col1, col2 = st.columns(2)
with col1:
issue_text_detailed = st.text_area(
"Issue Title/Text*",
height=100,
placeholder="e.g., Fix authentication bug in login module",
key="issue_text_detailed",
)
issue_description = st.text_area(
"Issue Description",
height=100,
placeholder="Optional: Detailed description of the issue",
key="issue_description",
)
with col2:
repo_name = st.text_input(
"Repository Name",
placeholder="e.g., owner/repository",
help="Optional: GitHub repository name",
)
pr_number = st.number_input(
"PR Number",
min_value=0,
value=0,
help="Optional: Pull request number (0 = not a PR)",
)
if st.button("Predict Skills (Detailed)", type="primary", use_container_width=True):
if not issue_text_detailed.strip():
st.error("Issue text is required!")
else:
with st.spinner("Analyzing issue..."):
result = predict_skills(
issue_text_detailed,
issue_description if issue_description else None,
repo_name if repo_name else None,
pr_number if pr_number > 0 else None,
)
if result:
st.header("Results")
# Metadata
col1, col2, col3 = st.columns(3)
with col1:
st.metric("Total Predictions", result.get("num_predictions", 0))
with col2:
st.metric(
"Processing Time", f"{result.get('processing_time_ms', 0):.2f} ms"
)
with col3:
st.metric("Model Version", result.get("model_version", "N/A"))
st.divider()
display_predictions(result.get("predictions", []), threshold)
with st.expander("🔍 View Raw Response"):
st.json(result)
with tab3:
st.markdown("### Example Issues")
examples = [
{
"title": "Authentication Bug",
"text": "Fix authentication bug in login module. Users cannot login with OAuth providers.",
},
{
"title": "Machine Learning Feature",
"text": "Implement transfer learning with transformers for text classification using PyTorch and TensorFlow.",
},
{
"title": "Database Issue",
"text": "Fix database connection pooling issue causing memory leaks in production environment.",
},
{
"title": "UI Enhancement",
"text": "Add responsive design support for mobile devices with CSS media queries and flexbox layout.",
},
]
for i, example in enumerate(examples):
if st.button(example["title"], use_container_width=True, key=f"example_btn_{i}"):
st.session_state.example_text = example["text"]
st.rerun()
if st.session_state.example_text:
st.success(" Example loaded! Switch to 'Quick Input' tab to use it.")
with st.expander("Preview"):
st.code(st.session_state.example_text)
if __name__ == "__main__":
main()