|
|
import html
import os
from typing import Dict, List, Optional

import pandas as pd
import requests
import streamlit as st
|
|
|
|
|
# Base URL of the prediction API; overridable through the API_BASE_URL
# environment variable, falling back to a local server on port 8000.
API_BASE_URL = os.environ.get("API_BASE_URL", "http://localhost:8000")
|
|
|
|
|
|
|
|
# Page-level Streamlit settings: browser tab title, wide layout, and a
# sidebar that starts expanded.
st.set_page_config(
    page_title="GitHub Skill Classifier",
    layout="wide",
    initial_sidebar_state="expanded",
)
|
|
|
|
|
# Stylesheet for the page header, the per-skill cards, and the three
# confidence colour classes referenced by the HTML rendered elsewhere in
# this file.
_CUSTOM_CSS = """
<style>
    .main-header {
        font-size: 2.5rem;
        color: #1f77b4;
        text-align: center;
        margin-bottom: 2rem;
    }
    .skill-card {
        padding: 1rem;
        border-radius: 0.5rem;
        border-left: 4px solid #1f77b4;
        background-color: #f0f2f6;
        margin-bottom: 0.5rem;
    }
    .confidence-high {
        color: #28a745;
        font-weight: bold;
    }
    .confidence-medium {
        color: #ffc107;
        font-weight: bold;
    }
    .confidence-low {
        color: #dc3545;
        font-weight: bold;
    }
</style>
"""

# Inject the stylesheet as raw HTML so the <style> tag is not escaped.
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
|
|
|
|
|
|
|
|
def check_api_health() -> bool:
    """Report whether the prediction API answers its health endpoint.

    Sends a GET request to ``{API_BASE_URL}/health`` with a 2-second timeout.

    Returns:
        True when the endpoint responds with HTTP 200; False for any other
        status code or when the request itself fails (connection error,
        timeout, malformed URL, ...).
    """
    try:
        response = requests.get(f"{API_BASE_URL}/health", timeout=2)
    except requests.exceptions.RequestException:
        # Only request-level failures mean "API unavailable"; unrelated
        # programming errors are allowed to surface instead of being hidden.
        return False
    return response.status_code == 200
|
|
|
|
|
|
|
|
def predict_skills(
    issue_text: str,
    issue_description: Optional[str] = None,
    repo_name: Optional[str] = None,
    pr_number: Optional[int] = None,
) -> Optional[Dict]:
    """Call the prediction API and return its decoded JSON response.

    Args:
        issue_text: Main issue/PR text; always sent as ``issue_text``.
        issue_description: Optional description; sent only when truthy.
        repo_name: Optional repository name; sent only when truthy.
        pr_number: Optional pull-request number; sent only when truthy
            (so ``0`` and ``None`` are both omitted).

    Returns:
        The parsed JSON body of ``POST {API_BASE_URL}/predict``, or ``None``
        when the request fails, the server answers with an error status, or
        the body is not valid JSON. In the failure case an error message is
        shown through ``st.error``.
    """
    payload: Dict = {"issue_text": issue_text}

    # Optional fields are included only when they carry a truthy value.
    optional_fields = {
        "issue_description": issue_description,
        "repo_name": repo_name,
        "pr_number": pr_number,
    }
    payload.update({key: value for key, value in optional_fields.items() if value})

    try:
        response = requests.post(f"{API_BASE_URL}/predict", json=payload, timeout=30)
        response.raise_for_status()
        return response.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers a non-JSON body on requests versions whose
        # JSON decode error is not a RequestException subclass.
        st.error(f"API Error: {str(e)}")
        return None
|
|
|
|
|
|
|
|
def display_predictions(predictions: List[Dict], threshold: float = 0.5) -> None:
    """Display predictions with visual formatting.

    Keeps the predictions whose ``confidence`` is at least ``threshold``,
    orders them from most to least confident, and renders them as a table
    (left column) plus a "Top 5 Skills" card list (right column).

    Args:
        predictions: Prediction dicts; each is read for the keys
            ``"confidence"`` (numeric) and ``"skill_name"``.
        threshold: Minimum confidence a prediction needs to be shown.
    """
    # Sort explicitly so "Top 5" really is the five most confident skills,
    # regardless of the order in which the API returned them.
    filtered = sorted(
        (p for p in predictions if p["confidence"] >= threshold),
        key=lambda p: p["confidence"],
        reverse=True,
    )

    if not filtered:
        st.warning(f"No predictions above confidence threshold {threshold:.2f}")
        return

    st.success(f"Found {len(filtered)} skills above threshold {threshold:.2f}")

    # The table shows confidence as a percentage string; the raw numeric
    # values in `filtered` are kept for the colour thresholds below.
    df = pd.DataFrame(filtered)
    df["confidence"] = df["confidence"].apply(lambda x: f"{x:.2%}")

    col1, col2 = st.columns([2, 1])

    with col1:
        st.subheader("Predictions Table")
        st.dataframe(
            df,
            use_container_width=True,
            hide_index=True,
            column_config={
                "skill_name": st.column_config.TextColumn("Skill", width="large"),
                "confidence": st.column_config.TextColumn("Confidence", width="medium"),
            },
        )

    with col2:
        st.subheader("Top 5 Skills")
        for i, pred in enumerate(filtered[:5], 1):
            confidence = pred["confidence"]

            if confidence >= 0.8:
                conf_class = "confidence-high"
            elif confidence >= 0.5:
                conf_class = "confidence-medium"
            else:
                conf_class = "confidence-low"

            # The card is rendered as raw HTML, so the skill name is escaped
            # to keep API-provided text from being interpreted as markup.
            skill_name = html.escape(str(pred["skill_name"]))

            st.markdown(
                f"""
                <div class="skill-card">
                    <strong>#{i} {skill_name}</strong><br>
                    <span class="{conf_class}">{confidence:.2%}</span>
                </div>
                """,
                unsafe_allow_html=True,
            )
|
|
|
|
|
|
|
|
def _render_results(result: Dict, threshold: float) -> None:
    """Render the summary metrics, filtered predictions and raw JSON of one API response."""
    st.header("Results")

    col1, col2, col3 = st.columns(3)
    with col1:
        st.metric("Total Predictions", result.get("num_predictions", 0))
    with col2:
        st.metric("Processing Time", f"{result.get('processing_time_ms', 0):.2f} ms")
    with col3:
        st.metric("Model Version", result.get("model_version", "N/A"))

    st.divider()
    display_predictions(result.get("predictions", []), threshold)

    with st.expander("🔍 View Raw Response"):
        st.json(result)


def main():
    """Main Streamlit app.

    Builds the page: a sidebar (API status, confidence threshold slider,
    model info) and three input tabs (quick input, detailed input, and
    example issues that pre-fill the quick-input text area through
    ``st.session_state.example_text``).
    """
    # Holds the example text chosen in the "Examples" tab across reruns.
    if "example_text" not in st.session_state:
        st.session_state.example_text = ""

    st.markdown('<h1 class="main-header"> GitHub Skill Classifier</h1>', unsafe_allow_html=True)

    st.markdown("""
    This tool uses machine learning to predict the skills required for GitHub issues and pull requests.
    Enter the issue text below to get started!
    """)

    with st.sidebar:
        st.header("Settings")

        st.subheader("API Status")
        if check_api_health():
            st.success(" API is running")
        else:
            st.error(" API is not available")
            st.info(f"Make sure FastAPI is running at {API_BASE_URL}")
            st.code("fastapi dev hopcroft_skill_classification_tool_competition/main.py")

        threshold = st.slider(
            "Confidence Threshold",
            min_value=0.0,
            max_value=1.0,
            value=0.5,
            step=0.05,
            help="Only show predictions above this confidence level",
        )

        st.subheader("Model Info")
        # Best-effort fetch: a failed request or a non-JSON body degrades to
        # the "API not available" notice instead of raising.
        try:
            health = requests.get(f"{API_BASE_URL}/health", timeout=2).json()
        except (requests.exceptions.RequestException, ValueError):
            health = None

        if isinstance(health, dict):
            st.metric("Version", health.get("version", "N/A"))
            # Explicit labels so the loaded / not-loaded states are distinguishable.
            st.metric("Model Loaded", "Yes" if health.get("model_loaded") else "No")
        else:
            st.info("API not available")

    st.header("Input")

    tab1, tab2, tab3 = st.tabs(["Quick Input", "Detailed Input", "Examples"])

    with tab1:
        issue_text = st.text_area(
            "Issue/PR Text",
            height=150,
            placeholder="Enter the issue or pull request text here...",
            help="Required: The main text of the GitHub issue or PR",
            value=st.session_state.example_text,
        )

        if st.button("Predict Skills", type="primary", use_container_width=True):
            if not issue_text.strip():
                st.error("Please enter some text!")
            else:
                # The loaded example has been consumed; clear it.
                st.session_state.example_text = ""
                with st.spinner("Analyzing issue..."):
                    result = predict_skills(issue_text)

                if result:
                    _render_results(result, threshold)

    with tab2:
        col1, col2 = st.columns(2)

        with col1:
            issue_text_detailed = st.text_area(
                "Issue Title/Text*",
                height=100,
                placeholder="e.g., Fix authentication bug in login module",
                key="issue_text_detailed",
            )

            issue_description = st.text_area(
                "Issue Description",
                height=100,
                placeholder="Optional: Detailed description of the issue",
                key="issue_description",
            )

        with col2:
            repo_name = st.text_input(
                "Repository Name",
                placeholder="e.g., owner/repository",
                help="Optional: GitHub repository name",
            )

            pr_number = st.number_input(
                "PR Number",
                min_value=0,
                value=0,
                help="Optional: Pull request number (0 = not a PR)",
            )

        if st.button("Predict Skills (Detailed)", type="primary", use_container_width=True):
            if not issue_text_detailed.strip():
                st.error("Issue text is required!")
            else:
                with st.spinner("Analyzing issue..."):
                    # Empty optional inputs (and PR number 0) are passed as None.
                    result = predict_skills(
                        issue_text_detailed,
                        issue_description if issue_description else None,
                        repo_name if repo_name else None,
                        pr_number if pr_number > 0 else None,
                    )

                if result:
                    _render_results(result, threshold)

    with tab3:
        st.markdown("### Example Issues")

        examples = [
            {
                "title": "Authentication Bug",
                "text": "Fix authentication bug in login module. Users cannot login with OAuth providers.",
            },
            {
                "title": "Machine Learning Feature",
                "text": "Implement transfer learning with transformers for text classification using PyTorch and TensorFlow.",
            },
            {
                "title": "Database Issue",
                "text": "Fix database connection pooling issue causing memory leaks in production environment.",
            },
            {
                "title": "UI Enhancement",
                "text": "Add responsive design support for mobile devices with CSS media queries and flexbox layout.",
            },
        ]

        for i, example in enumerate(examples):
            if st.button(example["title"], use_container_width=True, key=f"example_btn_{i}"):
                # Store the example and rerun so the quick-input text area
                # picks it up as its value.
                st.session_state.example_text = example["text"]
                st.rerun()

        if st.session_state.example_text:
            st.success(" Example loaded! Switch to 'Quick Input' tab to use it.")
            with st.expander("Preview"):
                st.code(st.session_state.example_text)
|
|
|
|
|
|
|
|
# Build the page only when this file is executed as the main module,
# not when it is imported.
if __name__ == "__main__":
    main()
|
|
|