# File size: 10,540 Bytes
# 225af6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
import html
import os
from typing import Dict, List, Optional

import pandas as pd
import requests
import streamlit as st

# Base URL of the prediction API. Read from the API_BASE_URL environment variable,
# falling back to a local server on port 8000 when the variable is not set.
API_BASE_URL = os.getenv("API_BASE_URL", "http://localhost:8000")

# Page config: browser tab title, wide layout, sidebar expanded on first load.
st.set_page_config(
    page_title="GitHub Skill Classifier", layout="wide", initial_sidebar_state="expanded"
)

# Inject the custom CSS classes referenced by the HTML snippets rendered further down
# in this file (page header, skill cards and the three confidence colour levels).
# unsafe_allow_html is required so the <style> element is emitted as real HTML.
st.markdown(
    """
<style>
    .main-header {
        font-size: 2.5rem;
        color: #1f77b4;
        text-align: center;
        margin-bottom: 2rem;
    }
    .skill-card {
        padding: 1rem;
        border-radius: 0.5rem;
        border-left: 4px solid #1f77b4;
        background-color: #f0f2f6;
        margin-bottom: 0.5rem;
    }
    .confidence-high {
        color: #28a745;
        font-weight: bold;
    }
    .confidence-medium {
        color: #ffc107;
        font-weight: bold;
    }
    .confidence-low {
        color: #dc3545;
        font-weight: bold;
    }
</style>
""",
    unsafe_allow_html=True,
)


def check_api_health() -> bool:
    """Report whether the API's ``/health`` endpoint answers with HTTP 200.

    The probe is best-effort: any exception raised while performing the request
    (connection failure, timeout, ...) is reported as "not healthy" instead of
    being propagated to the caller.
    """
    health_url = f"{API_BASE_URL}/health"
    try:
        status = requests.get(health_url, timeout=2).status_code
    except Exception:
        return False
    return status == 200


def predict_skills(
    issue_text: str,
    issue_description: Optional[str] = None,
    repo_name: Optional[str] = None,
    pr_number: Optional[int] = None,
) -> Optional[Dict]:
    """Send an issue/PR to the prediction API and return the decoded JSON response.

    Args:
        issue_text: Main text of the issue or pull request; always sent.
        issue_description: Optional longer description; left out of the payload when falsy.
        repo_name: Optional repository name; left out of the payload when falsy.
        pr_number: Optional pull request number; left out of the payload when falsy
            (so ``0`` and ``None`` are both treated as "not provided").

    Returns:
        The decoded JSON body of the ``/predict`` response, or ``None`` when the request
        failed, the server answered with an error status, or the body could not be decoded
        as JSON. On failure the error is shown to the user through ``st.error``.
    """
    payload = {"issue_text": issue_text}

    # Only forward the optional fields that actually carry a value.
    optional_fields = {
        "issue_description": issue_description,
        "repo_name": repo_name,
        "pr_number": pr_number,
    }
    payload.update({name: value for name, value in optional_fields.items() if value})

    try:
        response = requests.post(f"{API_BASE_URL}/predict", json=payload, timeout=30)
        response.raise_for_status()
        return response.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers a non-JSON body on requests versions whose JSON decode
        # error is not a RequestException subclass; without it the page would crash.
        st.error(f"API Error: {str(e)}")
        return None


def _confidence_css_class(confidence: float) -> str:
    """Map a confidence score to the CSS class used to colour it."""
    if confidence >= 0.8:
        return "confidence-high"
    if confidence >= 0.5:
        return "confidence-medium"
    return "confidence-low"


def display_predictions(predictions: List[Dict], threshold: float = 0.5) -> None:
    """Render the predictions at or above ``threshold`` as a table plus a top-5 card list.

    Args:
        predictions: Prediction dicts; each must provide a numeric ``"confidence"`` and a
            ``"skill_name"`` (a missing key raises ``KeyError``).
        threshold: Minimum confidence (inclusive) a prediction needs to be shown.

    When nothing passes the threshold a warning is shown and nothing else is rendered.
    """
    # Keep only predictions that pass the threshold and rank them by confidence so that
    # the "Top 5" list really is the top 5 regardless of the order the API returned.
    # sorted() is stable, so predictions with equal confidence keep their incoming order.
    ranked = sorted(
        (p for p in predictions if p["confidence"] >= threshold),
        key=lambda p: p["confidence"],
        reverse=True,
    )

    if not ranked:
        st.warning(f"No predictions above confidence threshold {threshold:.2f}")
        return

    st.success(f"Found {len(ranked)} skills above threshold {threshold:.2f}")

    # Table view: confidence is turned into a percentage string for display only;
    # the dicts in ``ranked`` keep their numeric values for the card list below.
    df = pd.DataFrame(ranked)
    df["confidence"] = df["confidence"].apply(lambda x: f"{x:.2%}")

    col1, col2 = st.columns([2, 1])

    with col1:
        st.subheader("Predictions Table")
        st.dataframe(
            df,
            use_container_width=True,
            hide_index=True,
            column_config={
                "skill_name": st.column_config.TextColumn("Skill", width="large"),
                "confidence": st.column_config.TextColumn("Confidence", width="medium"),
            },
        )

    with col2:
        st.subheader("Top 5 Skills")
        for i, pred in enumerate(ranked[:5], 1):
            confidence = pred["confidence"]
            conf_class = _confidence_css_class(confidence)
            # The snippet is rendered as raw HTML, so the API-provided name is escaped
            # to stop characters such as "<" or "&" from being interpreted as markup.
            skill = html.escape(str(pred["skill_name"]))

            st.markdown(
                f"""
            <div class="skill-card">
                <strong>#{i} {skill}</strong><br>
                <span class="{conf_class}">{confidence:.2%}</span>
            </div>
            """,
                unsafe_allow_html=True,
            )


def _render_results(result: Dict, threshold: float) -> None:
    """Render the summary metrics, filtered predictions and raw JSON of one API response."""
    st.header("Results")

    # Metadata
    col1, col2, col3 = st.columns(3)
    with col1:
        st.metric("Total Predictions", result.get("num_predictions", 0))
    with col2:
        st.metric("Processing Time", f"{result.get('processing_time_ms', 0):.2f} ms")
    with col3:
        st.metric("Model Version", result.get("model_version", "N/A"))

    # Predictions
    st.divider()
    display_predictions(result.get("predictions", []), threshold)

    # Raw JSON
    with st.expander("🔍 View Raw Response"):
        st.json(result)


def main() -> None:
    """Build the Streamlit page: sidebar settings, three input tabs and prediction results.

    The sidebar shows the API status, the confidence threshold slider and model info taken
    from the ``/health`` endpoint. The "Quick Input" and "Detailed Input" tabs call the
    prediction API and render its response; the "Examples" tab stores a sample text in
    ``st.session_state.example_text``, which the "Quick Input" text area uses as its value.
    """

    if "example_text" not in st.session_state:
        st.session_state.example_text = ""

    # Header
    st.markdown('<h1 class="main-header"> GitHub Skill Classifier</h1>', unsafe_allow_html=True)

    st.markdown("""
    This tool uses machine learning to predict the skills required for GitHub issues and pull requests.
    Enter the issue text below to get started!
    """)

    # Sidebar
    with st.sidebar:
        st.header("Settings")

        # API Status
        st.subheader("API Status")
        if check_api_health():
            st.success(" API is running")
        else:
            st.error(" API is not available")
            st.info(f"Make sure FastAPI is running at {API_BASE_URL}")
            st.code("fastapi dev hopcroft_skill_classification_tool_competition/main.py")

        # Confidence threshold
        threshold = st.slider(
            "Confidence Threshold",
            min_value=0.0,
            max_value=1.0,
            value=0.5,
            step=0.05,
            help="Only show predictions above this confidence level",
        )

        # Model info
        st.subheader("Model Info")
        try:
            health = requests.get(f"{API_BASE_URL}/health", timeout=2).json()
            st.metric("Version", health.get("version", "N/A"))
            # Both branches used to be the empty string, so the metric never showed
            # whether the model was loaded; use explicit labels instead.
            st.metric("Model Loaded", "Yes" if health.get("model_loaded") else "No")
        except Exception:
            # Best-effort panel: any failure while fetching or reading the health
            # payload is shown as "not available" instead of breaking the page.
            st.info("API not available")

    # Main
    st.header("Input")

    # Tabs for different input modes
    tab1, tab2, tab3 = st.tabs(["Quick Input", "Detailed Input", "Examples"])

    with tab1:
        issue_text = st.text_area(
            "Issue/PR Text",
            height=150,
            placeholder="Enter the issue or pull request text here...",
            help="Required: The main text of the GitHub issue or PR",
            value=st.session_state.example_text,
        )

        if st.button("Predict Skills", type="primary", use_container_width=True):
            if not issue_text.strip():
                st.error("Please enter some text!")
            else:
                # Reset the stored example text once a prediction has been requested.
                st.session_state.example_text = ""
                with st.spinner("Analyzing issue..."):
                    result = predict_skills(issue_text)

                    if result:
                        _render_results(result, threshold)

    with tab2:
        col1, col2 = st.columns(2)

        with col1:
            issue_text_detailed = st.text_area(
                "Issue Title/Text*",
                height=100,
                placeholder="e.g., Fix authentication bug in login module",
                key="issue_text_detailed",
            )

            issue_description = st.text_area(
                "Issue Description",
                height=100,
                placeholder="Optional: Detailed description of the issue",
                key="issue_description",
            )

        with col2:
            repo_name = st.text_input(
                "Repository Name",
                placeholder="e.g., owner/repository",
                help="Optional: GitHub repository name",
            )

            pr_number = st.number_input(
                "PR Number",
                min_value=0,
                value=0,
                help="Optional: Pull request number (0 = not a PR)",
            )

        if st.button("Predict Skills (Detailed)", type="primary", use_container_width=True):
            if not issue_text_detailed.strip():
                st.error("Issue text is required!")
            else:
                with st.spinner("Analyzing issue..."):
                    # Empty optional inputs (and PR number 0) are passed as None so
                    # they are not sent to the API.
                    result = predict_skills(
                        issue_text_detailed,
                        issue_description if issue_description else None,
                        repo_name if repo_name else None,
                        pr_number if pr_number > 0 else None,
                    )

                    if result:
                        _render_results(result, threshold)

    with tab3:
        st.markdown("### Example Issues")

        examples = [
            {
                "title": "Authentication Bug",
                "text": "Fix authentication bug in login module. Users cannot login with OAuth providers.",
            },
            {
                "title": "Machine Learning Feature",
                "text": "Implement transfer learning with transformers for text classification using PyTorch and TensorFlow.",
            },
            {
                "title": "Database Issue",
                "text": "Fix database connection pooling issue causing memory leaks in production environment.",
            },
            {
                "title": "UI Enhancement",
                "text": "Add responsive design support for mobile devices with CSS media queries and flexbox layout.",
            },
        ]

        for i, example in enumerate(examples):
            if st.button(example["title"], use_container_width=True, key=f"example_btn_{i}"):
                # Store the sample and rerun so the Quick Input text area is rebuilt
                # with it as its value.
                st.session_state.example_text = example["text"]
                st.rerun()

        if st.session_state.example_text:
            st.success(" Example loaded! Switch to 'Quick Input' tab to use it.")
            with st.expander("Preview"):
                st.code(st.session_state.example_text)


# Entry point: build the UI only when this module is executed as the main script,
# not when it is imported from elsewhere.
if __name__ == "__main__":
    main()