Spaces:
Sleeping
Sleeping
import os
from typing import Dict, List, Optional

import pandas as pd
import requests
import streamlit as st
# Base URL of the prediction API; the API_BASE_URL environment variable
# overrides the local default.
API_BASE_URL = os.environ.get("API_BASE_URL", "http://localhost:8000")

# Stylesheet injected into the page. It defines the classes referenced by the
# HTML snippets this app renders: the page header, the per-skill card and the
# three confidence colour levels.
_CUSTOM_CSS = """
<style>
.main-header {
font-size: 2.5rem;
color: #1f77b4;
text-align: center;
margin-bottom: 2rem;
}
.skill-card {
padding: 1rem;
border-radius: 0.5rem;
border-left: 4px solid #1f77b4;
background-color: #f0f2f6;
margin-bottom: 0.5rem;
}
.confidence-high {
color: #28a745;
font-weight: bold;
}
.confidence-medium {
color: #ffc107;
font-weight: bold;
}
.confidence-low {
color: #dc3545;
font-weight: bold;
}
</style>
"""

# Page config
st.set_page_config(
    page_title="GitHub Skill Classifier",
    layout="wide",
    initial_sidebar_state="expanded",
)

# unsafe_allow_html is required for the <style> tag to be emitted as markup
# rather than shown as text.
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
def check_api_health() -> bool:
    """Check if the API is running and healthy.

    Sends a GET request to ``{API_BASE_URL}/health`` with a 2 second timeout.

    Returns:
        True if the endpoint answered with HTTP 200; False for any other
        status code or if the request itself failed (connection error,
        timeout, malformed URL, ...).
    """
    try:
        response = requests.get(f"{API_BASE_URL}/health", timeout=2)
    except requests.exceptions.RequestException:
        # Only request-level failures mean "API unavailable"; anything else is
        # a programming error and should surface instead of being hidden.
        return False
    return response.status_code == 200
def predict_skills(
    issue_text: str,
    issue_description: Optional[str] = None,
    repo_name: Optional[str] = None,
    pr_number: Optional[int] = None,
) -> Optional[Dict]:
    """Call the prediction API.

    POSTs a JSON payload to ``{API_BASE_URL}/predict`` with a 30 second
    timeout. ``issue_text`` is always sent; the other fields are included
    only when they are truthy (so an empty string or a PR number of 0 is
    omitted from the payload).

    Args:
        issue_text: Main text of the issue or pull request.
        issue_description: Optional longer description.
        repo_name: Optional repository name.
        pr_number: Optional pull request number.

    Returns:
        The decoded JSON response, or None if the request failed, the server
        answered with an error status, or the body was not valid JSON. In the
        failure case the error is shown to the user via ``st.error``.
    """
    payload = {"issue_text": issue_text}
    optional_fields = {
        "issue_description": issue_description,
        "repo_name": repo_name,
        "pr_number": pr_number,
    }
    payload.update({key: value for key, value in optional_fields.items() if value})

    try:
        response = requests.post(f"{API_BASE_URL}/predict", json=payload, timeout=30)
        response.raise_for_status()
        return response.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers a non-JSON body on requests versions whose
        # Response.json() raises a plain ValueError rather than a
        # RequestException subclass.
        st.error(f"API Error: {str(e)}")
        return None
# HTML for one entry of the "Top 5 Skills" panel; filled in with str.format.
_SKILL_CARD_HTML = """
<div class="skill-card">
<strong>#{rank} {skill}</strong><br>
<span class="{css_class}">{confidence:.2%}</span>
</div>
"""


def _confidence_css_class(confidence: float) -> str:
    """Return the CSS class used to colour a confidence value."""
    if confidence >= 0.8:
        return "confidence-high"
    if confidence >= 0.5:
        return "confidence-medium"
    return "confidence-low"


def display_predictions(predictions: List[Dict], threshold: float = 0.5):
    """Display predictions with visual formatting.

    Shows a table of every prediction whose confidence is at or above
    ``threshold`` next to a "Top 5 Skills" panel of styled cards. If nothing
    passes the threshold, a warning is shown instead.

    Args:
        predictions: Prediction dicts; each one is read for its
            ``"confidence"`` and ``"skill_name"`` keys.
        threshold: Minimum confidence (inclusive) for a prediction to be shown.
    """
    import html  # local import: only needed to escape skill names below

    # Filter by threshold, then rank by confidence so the "Top 5" panel really
    # holds the five strongest predictions regardless of the order the API
    # returned them in. sorted() is stable, so ties keep their incoming order.
    ranked = sorted(
        (p for p in predictions if p["confidence"] >= threshold),
        key=lambda p: p["confidence"],
        reverse=True,
    )
    if not ranked:
        st.warning(f"No predictions above confidence threshold {threshold:.2f}")
        return
    st.success(f"Found {len(ranked)} skills above threshold {threshold:.2f}")

    # Create DataFrame for table view; confidence is shown as a percentage
    # string, which is why the column is configured as text below.
    df = pd.DataFrame(ranked)
    df["confidence"] = df["confidence"].apply(lambda x: f"{x:.2%}")

    table_col, top_col = st.columns([2, 1])
    with table_col:
        st.subheader("Predictions Table")
        st.dataframe(
            df,
            use_container_width=True,
            hide_index=True,
            column_config={
                "skill_name": st.column_config.TextColumn("Skill", width="large"),
                "confidence": st.column_config.TextColumn("Confidence", width="medium"),
            },
        )
    with top_col:
        st.subheader("Top 5 Skills")
        for rank, pred in enumerate(ranked[:5], 1):
            confidence = pred["confidence"]
            card = _SKILL_CARD_HTML.format(
                rank=rank,
                # The card is rendered as raw HTML, so the skill name must be
                # escaped to keep characters such as "<" or "&" literal.
                skill=html.escape(str(pred["skill_name"])),
                css_class=_confidence_css_class(confidence),
                confidence=confidence,
            )
            st.markdown(card, unsafe_allow_html=True)
# Introductory text shown under the page header.
_INTRO_MARKDOWN = """
This tool uses machine learning to predict the skills required for GitHub issues and pull requests.
Enter the issue text below to get started!
"""

# Sample issues offered in the "Examples" tab.
_EXAMPLE_ISSUES = [
    {
        "title": "Authentication Bug",
        "text": "Fix authentication bug in login module. Users cannot login with OAuth providers.",
    },
    {
        "title": "Machine Learning Feature",
        "text": "Implement transfer learning with transformers for text classification using PyTorch and TensorFlow.",
    },
    {
        "title": "Database Issue",
        "text": "Fix database connection pooling issue causing memory leaks in production environment.",
    },
    {
        "title": "UI Enhancement",
        "text": "Add responsive design support for mobile devices with CSS media queries and flexbox layout.",
    },
]


def _render_sidebar() -> float:
    """Draw the settings sidebar and return the selected confidence threshold."""
    with st.sidebar:
        st.header("Settings")

        # API Status
        st.subheader("API Status")
        if check_api_health():
            st.success(" API is running")
        else:
            st.error(" API is not available")
            st.info(f"Make sure FastAPI is running at {API_BASE_URL}")
            st.code("fastapi dev hopcroft_skill_classification_tool_competition/main.py")

        # Confidence threshold
        threshold = st.slider(
            "Confidence Threshold",
            min_value=0.0,
            max_value=1.0,
            value=0.5,
            step=0.05,
            help="Only show predictions above this confidence level",
        )

        # Model info
        st.subheader("Model Info")
        try:
            health = requests.get(f"{API_BASE_URL}/health", timeout=2).json()
            st.metric("Version", health.get("version", "N/A"))
            # Both branches used to be the empty string, so the metric was
            # blank whatever the API reported.
            st.metric("Model Loaded", "Yes" if health.get("model_loaded") else "No")
        except Exception:
            # Best effort: the sidebar must never take the page down, so any
            # failure here (network, bad JSON, unexpected payload shape) is
            # reduced to an informational note.
            st.info("API not available")
    return threshold


def _render_results(result: Dict, threshold: float) -> None:
    """Render the summary metrics, predictions and raw JSON of one API response."""
    st.header("Results")

    # Metadata
    total_col, time_col, version_col = st.columns(3)
    with total_col:
        st.metric("Total Predictions", result.get("num_predictions", 0))
    with time_col:
        st.metric("Processing Time", f"{result.get('processing_time_ms', 0):.2f} ms")
    with version_col:
        st.metric("Model Version", result.get("model_version", "N/A"))

    # Predictions
    st.divider()
    display_predictions(result.get("predictions", []), threshold)

    # Raw JSON
    with st.expander("🔍 View Raw Response"):
        st.json(result)


def main() -> None:
    """Main Streamlit app.

    Lays out the header, the settings sidebar and three input tabs:
    a quick single-field form, a detailed form with optional metadata, and a
    list of example issues that can be loaded into the quick form.
    """
    # Holds the example chosen in the "Examples" tab so it can pre-fill the
    # "Quick Input" text area on the next rerun.
    if "example_text" not in st.session_state:
        st.session_state.example_text = ""

    # Header
    st.markdown('<h1 class="main-header"> GitHub Skill Classifier</h1>', unsafe_allow_html=True)
    st.markdown(_INTRO_MARKDOWN)

    # Sidebar
    threshold = _render_sidebar()

    # Main
    st.header("Input")

    # Tabs for different input modes
    tab1, tab2, tab3 = st.tabs(["Quick Input", "Detailed Input", "Examples"])

    with tab1:
        issue_text = st.text_area(
            "Issue/PR Text",
            height=150,
            placeholder="Enter the issue or pull request text here...",
            help="Required: The main text of the GitHub issue or PR",
            value=st.session_state.example_text,
        )
        if st.button("Predict Skills", type="primary", use_container_width=True):
            if not issue_text.strip():
                st.error("Please enter some text!")
            else:
                # The loaded example has been consumed; clear it.
                st.session_state.example_text = ""
                with st.spinner("Analyzing issue..."):
                    result = predict_skills(issue_text)
                if result:
                    _render_results(result, threshold)

    with tab2:
        col1, col2 = st.columns(2)
        with col1:
            issue_text_detailed = st.text_area(
                "Issue Title/Text*",
                height=100,
                placeholder="e.g., Fix authentication bug in login module",
                key="issue_text_detailed",
            )
            issue_description = st.text_area(
                "Issue Description",
                height=100,
                placeholder="Optional: Detailed description of the issue",
                key="issue_description",
            )
        with col2:
            repo_name = st.text_input(
                "Repository Name",
                placeholder="e.g., owner/repository",
                help="Optional: GitHub repository name",
            )
            pr_number = st.number_input(
                "PR Number",
                min_value=0,
                value=0,
                help="Optional: Pull request number (0 = not a PR)",
            )
        if st.button("Predict Skills (Detailed)", type="primary", use_container_width=True):
            if not issue_text_detailed.strip():
                st.error("Issue text is required!")
            else:
                with st.spinner("Analyzing issue..."):
                    # Empty optional inputs are passed as None so they are
                    # left out of the request.
                    result = predict_skills(
                        issue_text_detailed,
                        issue_description if issue_description else None,
                        repo_name if repo_name else None,
                        pr_number if pr_number > 0 else None,
                    )
                if result:
                    _render_results(result, threshold)

    with tab3:
        st.markdown("### Example Issues")
        for i, example in enumerate(_EXAMPLE_ISSUES):
            if st.button(example["title"], use_container_width=True, key=f"example_btn_{i}"):
                st.session_state.example_text = example["text"]
                # Rerun so the "Quick Input" text area picks up the new value.
                st.rerun()
        if st.session_state.example_text:
            st.success(" Example loaded! Switch to 'Quick Input' tab to use it.")
            with st.expander("Preview"):
                st.code(st.session_state.example_text)


if __name__ == "__main__":
    main()