"""Streamlit front-end for the GitHub Skill Classifier.

Talks to a FastAPI backend (``/health`` and ``/predict``) and renders
predicted skills for GitHub issues and pull requests.
"""

import os
from typing import Dict, List, Optional

import pandas as pd
import requests
import streamlit as st

# Base URL of the FastAPI backend; override via the API_BASE_URL env var.
API_BASE_URL = os.getenv("API_BASE_URL", "http://localhost:8000")

# Page config
st.set_page_config(
    page_title="GitHub Skill Classifier",
    layout="wide",
    initial_sidebar_state="expanded",
)

# NOTE(review): this markdown payload is empty as seen here — custom CSS
# (e.g. the confidence-* classes referenced below) may have been lost in
# transit; confirm against the original file.
st.markdown(
    """ """,
    unsafe_allow_html=True,
)


def check_api_health() -> bool:
    """Check if the API is running and healthy.

    Returns:
        True when ``GET /health`` answers 200 within 2 seconds,
        False on any other status or on any network/timeout error.
    """
    try:
        response = requests.get(f"{API_BASE_URL}/health", timeout=2)
        return response.status_code == 200
    except Exception:
        # Unreachable host, DNS failure, timeout, etc. all mean "not healthy".
        return False


def predict_skills(
    issue_text: str,
    issue_description: Optional[str] = None,
    repo_name: Optional[str] = None,
    pr_number: Optional[int] = None,
) -> Optional[Dict]:
    """Call the prediction API.

    Args:
        issue_text: Main text of the GitHub issue or PR (required).
        issue_description: Optional detailed description.
        repo_name: Optional ``owner/repository`` name.
        pr_number: Optional pull-request number.

    Returns:
        The parsed JSON response dict, or ``None`` when the request fails
        (an error banner is shown in the UI in that case).
    """
    payload: Dict = {"issue_text": issue_text}
    # Only include optional fields the caller actually supplied.
    if issue_description:
        payload["issue_description"] = issue_description
    if repo_name:
        payload["repo_name"] = repo_name
    if pr_number:
        payload["pr_number"] = pr_number

    try:
        response = requests.post(f"{API_BASE_URL}/predict", json=payload, timeout=30)
        response.raise_for_status()
        return response.json()
    except requests.exceptions.RequestException as e:
        st.error(f"API Error: {str(e)}")
        return None


def display_predictions(predictions: List[Dict], threshold: float = 0.5):
    """Display predictions with visual formatting.

    Args:
        predictions: List of ``{"skill_name": ..., "confidence": ...}`` dicts.
        threshold: Minimum confidence a prediction needs to be shown.
    """
    # Filter by threshold
    filtered = [p for p in predictions if p["confidence"] >= threshold]

    if not filtered:
        st.warning(f"No predictions above confidence threshold {threshold:.2f}")
        return

    st.success(f"Found {len(filtered)} skills above threshold {threshold:.2f}")

    # Create DataFrame for table view; render confidence as a percentage.
    df = pd.DataFrame(filtered)
    df["confidence"] = df["confidence"].apply(lambda x: f"{x:.2%}")

    col1, col2 = st.columns([2, 1])

    with col1:
        st.subheader("Predictions Table")
        st.dataframe(
            df,
            use_container_width=True,
            hide_index=True,
            column_config={
                "skill_name": st.column_config.TextColumn("Skill", width="large"),
                "confidence": st.column_config.TextColumn("Confidence", width="medium"),
            },
        )

    with col2:
        st.subheader("Top 5 Skills")
        for i, pred in enumerate(filtered[:5], 1):
            confidence = pred["confidence"]
            # Bucket confidence into a CSS class name.
            # NOTE(review): conf_class is computed but not referenced in the
            # markup below — the styled HTML appears to have been stripped
            # from this copy of the file; confirm against the original.
            if confidence >= 0.8:
                conf_class = "confidence-high"
            elif confidence >= 0.5:
                conf_class = "confidence-medium"
            else:
                conf_class = "confidence-low"
            st.markdown(
                f"""
#{i} {pred['skill_name']}
{confidence:.2%}
""",
                unsafe_allow_html=True,
            )


def _render_results(result: Dict, threshold: float) -> None:
    """Render one prediction result: metrics row, predictions, raw JSON.

    Shared by the Quick Input and Detailed Input tabs, which previously
    duplicated this section verbatim.
    """
    st.header("Results")

    # Metadata
    col1, col2, col3 = st.columns(3)
    with col1:
        st.metric("Total Predictions", result.get("num_predictions", 0))
    with col2:
        st.metric(
            "Processing Time", f"{result.get('processing_time_ms', 0):.2f} ms"
        )
    with col3:
        st.metric("Model Version", result.get("model_version", "N/A"))

    # Predictions
    st.divider()
    display_predictions(result.get("predictions", []), threshold)

    # Raw JSON
    with st.expander("🔍 View Raw Response"):
        st.json(result)


def main():
    """Main Streamlit app."""
    if "example_text" not in st.session_state:
        st.session_state.example_text = ""

    # Header
    # NOTE(review): the surrounding HTML markup appears stripped in this copy
    # of the file — only the title text survives; confirm original markup.
    st.markdown('GitHub Skill Classifier', unsafe_allow_html=True)

    st.markdown("""
    This tool uses machine learning to predict the skills required for GitHub issues and pull requests.
    Enter the issue text below to get started!
    """)

    # Sidebar
    with st.sidebar:
        st.header("Settings")

        # API Status
        st.subheader("API Status")
        if check_api_health():
            st.success(" API is running")
        else:
            st.error(" API is not available")
            st.info(f"Make sure FastAPI is running at {API_BASE_URL}")
            st.code("fastapi dev hopcroft_skill_classification_tool_competition/main.py")

        # Confidence threshold
        threshold = st.slider(
            "Confidence Threshold",
            min_value=0.0,
            max_value=1.0,
            value=0.5,
            step=0.05,
            help="Only show predictions above this confidence level",
        )

        # Model info
        st.subheader("Model Info")
        try:
            health = requests.get(f"{API_BASE_URL}/health", timeout=2).json()
            st.metric("Version", health.get("version", "N/A"))
            # NOTE(review): both branches are empty strings as seen here —
            # status icons were likely lost in transit; confirm original.
            st.metric("Model Loaded", "" if health.get("model_loaded") else "")
        except Exception:
            st.info("API not available")

    # Main
    st.header("Input")

    # Tabs for different input modes
    tab1, tab2, tab3 = st.tabs(["Quick Input", "Detailed Input", "Examples"])

    with tab1:
        issue_text = st.text_area(
            "Issue/PR Text",
            height=150,
            placeholder="Enter the issue or pull request text here...",
            help="Required: The main text of the GitHub issue or PR",
            value=st.session_state.example_text,
        )

        if st.button("Predict Skills", type="primary", use_container_width=True):
            if not issue_text.strip():
                st.error("Please enter some text!")
            else:
                # Clear any loaded example once a prediction is requested.
                st.session_state.example_text = ""
                with st.spinner("Analyzing issue..."):
                    result = predict_skills(issue_text)
                    if result:
                        _render_results(result, threshold)

    with tab2:
        col1, col2 = st.columns(2)

        with col1:
            issue_text_detailed = st.text_area(
                "Issue Title/Text*",
                height=100,
                placeholder="e.g., Fix authentication bug in login module",
                key="issue_text_detailed",
            )
            issue_description = st.text_area(
                "Issue Description",
                height=100,
                placeholder="Optional: Detailed description of the issue",
                key="issue_description",
            )

        with col2:
            repo_name = st.text_input(
                "Repository Name",
                placeholder="e.g., owner/repository",
                help="Optional: GitHub repository name",
            )
            pr_number = st.number_input(
                "PR Number",
                min_value=0,
                value=0,
                help="Optional: Pull request number (0 = not a PR)",
            )

        if st.button("Predict Skills (Detailed)", type="primary", use_container_width=True):
            if not issue_text_detailed.strip():
                st.error("Issue text is required!")
            else:
                with st.spinner("Analyzing issue..."):
                    result = predict_skills(
                        issue_text_detailed,
                        issue_description if issue_description else None,
                        repo_name if repo_name else None,
                        pr_number if pr_number > 0 else None,
                    )
                    if result:
                        _render_results(result, threshold)

    with tab3:
        st.markdown("### Example Issues")

        examples = [
            {
                "title": "Authentication Bug",
                "text": "Fix authentication bug in login module. Users cannot login with OAuth providers.",
            },
            {
                "title": "Machine Learning Feature",
                "text": "Implement transfer learning with transformers for text classification using PyTorch and TensorFlow.",
            },
            {
                "title": "Database Issue",
                "text": "Fix database connection pooling issue causing memory leaks in production environment.",
            },
            {
                "title": "UI Enhancement",
                "text": "Add responsive design support for mobile devices with CSS media queries and flexbox layout.",
            },
        ]

        for i, example in enumerate(examples):
            if st.button(example["title"], use_container_width=True, key=f"example_btn_{i}"):
                # Stash the example and rerun so tab1's text area picks it up.
                st.session_state.example_text = example["text"]
                st.rerun()

        if st.session_state.example_text:
            st.success(" Example loaded! Switch to 'Quick Input' tab to use it.")
            with st.expander("Preview"):
                st.code(st.session_state.example_text)


if __name__ == "__main__":
    main()