File size: 2,458 Bytes
8464aea
9b2cded
 
 
 
 
8464aea
9b2cded
8464aea
 
 
 
 
 
 
 
9b2cded
 
 
 
8464aea
 
9b2cded
 
 
8464aea
 
 
9b2cded
 
 
8464aea
 
 
9b2cded
 
8464aea
 
9b2cded
8464aea
 
9b2cded
8464aea
9b2cded
 
 
8464aea
9b2cded
 
8464aea
 
 
 
 
 
9b2cded
8464aea
 
9b2cded
8464aea
9b2cded
8464aea
 
9b2cded
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""CodeBERT inference Gradio app (optional HF Space for predictions)."""

from __future__ import annotations

import os

import gradio as gr

from src.hf_predict_codebert import CodeBERTSQLErrorClassifier

# Directory of the fine-tuned CodeBERT checkpoint. It can be overridden with
# the SPACE_MODEL_DIR environment variable (as the error message below
# advises); an unset or blank value falls back to the local training output.
MODEL_DIR = os.environ.get("SPACE_MODEL_DIR", "").strip() or "models/codebert-cross-encoder"

# Load the classifier once at import time. Any exception raised while
# constructing it is caught on purpose, so the UI still starts and can show
# the user why no model is available (classify() checks ``clf is None``).
try:
    clf = CodeBERTSQLErrorClassifier(MODEL_DIR)
    model_status = f"Loaded model from `{MODEL_DIR}`"
except Exception as exc:  # broad on purpose: top-level, best-effort load
    clf = None
    model_status = f"Model not loaded: {exc}. Train first or set SPACE_MODEL_DIR."

# Pre-filled demo input shown when the app opens. The student query uses
# SUM where the question asks for an average (the reference uses AVG).
EXAMPLE = dict(
    question="What is the average score of students in each department?",
    schema="students(id, name, score, department_id) | departments(id, name)",
    student_sql="SELECT department_id, SUM(score) FROM students GROUP BY department_id",
    correct_sql="SELECT department_id, AVG(score) FROM students GROUP BY department_id",
)


def _clean_text(text: str | None) -> str:
    """Return *text* stripped of surrounding whitespace, treating None as ""."""
    return (text or "").strip()


def classify(
    question: str | None,
    schema: str | None,
    student_sql: str | None,
    correct_sql: str | None,
    threshold: float,
) -> tuple[str, str]:
    """Classify one student query against the reference and format the result.

    Args:
        question: Natural-language question the SQL is meant to answer.
        schema: Textual schema description shown to the model.
        student_sql: The query being checked.
        correct_sql: The reference query.
        threshold: Probability cut-off forwarded to ``clf.predict``.

    Returns:
        A ``(summary, probabilities)`` pair of Markdown strings. The summary
        shows the primary label with its confidence and every label above the
        threshold. The probabilities string has one bullet per label. If no
        model is loaded, the summary asks the user to train one and the second
        string is empty.
    """
    if clf is None:
        return "Train a model first.", ""

    # Text inputs are normalised so a missing (None) value cannot crash
    # ``.strip()``; presumably Gradio can deliver None for cleared fields.
    result = clf.predict(
        question=_clean_text(question),
        schema=_clean_text(schema),
        student_sql=_clean_text(student_sql),
        correct_sql=_clean_text(correct_sql),
        threshold=threshold,
    )

    labels_above = ", ".join(result["error_labels"]) or "none"
    summary = (
        f"**{result['primary_label']}** ({result['primary_confidence']:.1%})\n\n"
        f"All labels above threshold: {labels_above}"
    )
    probs = "\n".join(
        f"- {label}: {prob:.1%}" for label, prob in result["probabilities"].items()
    )
    return summary, probs


# Two-column layout: input widgets (pre-filled from EXAMPLE) on the left,
# the Markdown-rendered prediction and per-label probabilities on the right.
with gr.Blocks(title="SQL Error Classifier") as demo:
    gr.Markdown(f"# SQL Error Classifier (CodeBERT)\n{model_status}")

    with gr.Row():
        with gr.Column():
            question = gr.Textbox(
                label="Question", lines=2, value=EXAMPLE["question"]
            )
            schema = gr.Textbox(
                label="Schema", lines=2, value=EXAMPLE["schema"]
            )
            student_sql = gr.Textbox(
                label="Student SQL", lines=3, value=EXAMPLE["student_sql"]
            )
            correct_sql = gr.Textbox(
                label="Correct SQL", lines=3, value=EXAMPLE["correct_sql"]
            )
            threshold = gr.Slider(
                minimum=0.1, maximum=0.9, value=0.5, step=0.05, label="Threshold"
            )
            btn = gr.Button("Classify", variant="primary")

        with gr.Column():
            prediction = gr.Markdown()
            probabilities = gr.Markdown()

    # Input order must match classify()'s positional parameters; the two
    # returned strings fill the two Markdown outputs in order.
    btn.click(
        fn=classify,
        inputs=[question, schema, student_sql, correct_sql, threshold],
        outputs=[prediction, probabilities],
    )

# Start the server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()