Spaces:
Sleeping
Sleeping
File size: 2,458 Bytes
8464aea 9b2cded 8464aea 9b2cded 8464aea 9b2cded 8464aea 9b2cded 8464aea 9b2cded 8464aea 9b2cded 8464aea 9b2cded 8464aea 9b2cded 8464aea 9b2cded 8464aea 9b2cded 8464aea 9b2cded 8464aea 9b2cded 8464aea 9b2cded 8464aea 9b2cded | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 | """CodeBERT inference Gradio app (optional HF Space for predictions)."""
from __future__ import annotations

import os

import gradio as gr

from src.hf_predict_codebert import CodeBERTSQLErrorClassifier
# Directory of the fine-tuned CodeBERT checkpoint. The SPACE_MODEL_DIR
# environment variable overrides the default; the fallback status message
# below tells users to set it, so it must actually be honoured here.
# An empty value is treated the same as "unset".
MODEL_DIR = os.environ.get("SPACE_MODEL_DIR") or "models/codebert-cross-encoder"

# Load the classifier once at import time. Any failure is turned into a
# status message (rendered in the page header) instead of crashing the app,
# so the UI still comes up and explains what is missing.
try:
    clf = CodeBERTSQLErrorClassifier(MODEL_DIR)
    model_status = f"Loaded model from `{MODEL_DIR}`"
except Exception as exc:  # top-level boundary: surface the error in the UI
    clf = None
    model_status = f"Model not loaded: {exc}. Train first or set SPACE_MODEL_DIR."
# Pre-filled demo input for the UI textboxes. The question asks for an
# average, but the student query uses SUM; the reference query uses AVG.
EXAMPLE = dict(
    question="What is the average score of students in each department?",
    schema="students(id, name, score, department_id) | departments(id, name)",
    student_sql="SELECT department_id, SUM(score) FROM students GROUP BY department_id",
    correct_sql="SELECT department_id, AVG(score) FROM students GROUP BY department_id",
)
def classify(question, schema, student_sql, correct_sql, threshold):
    """Classify one student SQL submission and format the result as Markdown.

    Args:
        question: Natural-language question the SQL is meant to answer.
        schema: Textual description of the database schema.
        student_sql: The query being graded.
        correct_sql: The reference query to compare against.
        threshold: Probability cut-off forwarded to ``clf.predict``; the
            summary lists ``result["error_labels"]`` as the labels above it.

    Returns:
        A ``(summary, probabilities)`` pair of Markdown strings for the two
        output panes. When no model is loaded, or either SQL box is empty,
        the summary holds an explanatory message and the second string is
        empty.
    """
    if clf is None:
        return "Train a model first.", ""

    # Textbox values may be None in some Gradio states; normalise to "".
    student_sql = (student_sql or "").strip()
    correct_sql = (correct_sql or "").strip()
    # Running the model on an empty query yields a meaningless but
    # confident-looking label, so ask for input instead.
    if not student_sql or not correct_sql:
        return "Enter both a student SQL query and a correct SQL query.", ""

    result = clf.predict(
        question=(question or "").strip(),
        schema=(schema or "").strip(),
        student_sql=student_sql,
        correct_sql=correct_sql,
        threshold=threshold,
    )

    labels = ", ".join(result["error_labels"]) or "none"
    summary = (
        f"**{result['primary_label']}** ({result['primary_confidence']:.1%})\n\n"
        f"All labels above threshold: {labels}"
    )
    # List the most likely labels first so the pane reads as a ranking.
    ranked = sorted(
        result["probabilities"].items(), key=lambda item: item[1], reverse=True
    )
    probs = "\n".join(f"- {label}: {p:.1%}" for label, p in ranked)
    return summary, probs
# Build the Gradio UI. Component layout follows creation order inside the
# nested `with` contexts, so statement order here is significant.
with gr.Blocks(title="SQL Error Classifier") as demo:
    # Header includes model_status so a failed model load is visible on the page.
    gr.Markdown(f"# SQL Error Classifier (CodeBERT)\n{model_status}")
    with gr.Row():
        # Left column: inputs, pre-filled from EXAMPLE.
        with gr.Column():
            question = gr.Textbox(label="Question", lines=2, value=EXAMPLE["question"])
            schema = gr.Textbox(label="Schema", lines=2, value=EXAMPLE["schema"])
            student_sql = gr.Textbox(label="Student SQL", lines=3, value=EXAMPLE["student_sql"])
            correct_sql = gr.Textbox(label="Correct SQL", lines=3, value=EXAMPLE["correct_sql"])
            # Cut-off forwarded to classify(); range 0.1-0.9, default 0.5.
            threshold = gr.Slider(0.1, 0.9, value=0.5, step=0.05, label="Threshold")
            btn = gr.Button("Classify", variant="primary")
        # Right column: the two Markdown panes that receive classify()'s
        # (summary, probabilities) return values.
        with gr.Column():
            prediction = gr.Markdown()
            probabilities = gr.Markdown()
    # Input order must match classify()'s positional parameters; output order
    # must match its returned tuple.
    btn.click(
        classify,
        [question, schema, student_sql, correct_sql, threshold],
        [prediction, probabilities],
    )
# Start the Gradio server when this file is executed directly.
if __name__ == "__main__":
    demo.launch()
|