"""Inference API for SQL error classification.""" from __future__ import annotations import argparse import json from dataclasses import asdict, dataclass from pathlib import Path from typing import List, Optional from src.categories import id_to_name, load_categories from src.model import DEFAULT_MODEL_PATH, combine_features, load_model from src.cross_encoder_model import ( CrossEncoderClassifier, FineTunedCrossEncoderClassifier, ) from src.multi_tower_model import MultiTowerClassifier, QueryContext CONTEXT_MODELS = ( CrossEncoderClassifier, FineTunedCrossEncoderClassifier, MultiTowerClassifier, ) @dataclass class Prediction: label_id: int label_name: str confidence: float top_k: List[dict] similarities: Optional[dict] = None pair_scores: Optional[dict] = None class SQLErrorClassifier: """Classifier wrapper for playground integration.""" def __init__(self, model_path: Path = DEFAULT_MODEL_PATH): self.model = load_model(model_path) self.label_map = id_to_name(load_categories()) def predict( self, query: str, error_message: Optional[str] = None, schema: Optional[str] = None, question: Optional[str] = None, correct_query: Optional[str] = None, top_k: int = 3, ) -> Prediction: if isinstance(self.model, CONTEXT_MODELS): if not all([schema, question, correct_query]): raise ValueError( "context models require schema, question, and correct_query" ) ctx = QueryContext( question=question, schema=schema, correct_query=correct_query, student_query=query, error_message=error_message, ) proba = self.model.predict_proba([ctx])[0] similarities = ( self.model.explain_similarities(ctx) if isinstance(self.model, MultiTowerClassifier) else None ) pair_scores = ( self.model.explain_pair_scores(ctx) if isinstance(self.model, CrossEncoderClassifier) else None ) else: pair_scores = None similarities = None text = combine_features( queries=[query], error_messages=[error_message] if error_message else None, schemas=[schema] if schema else None, questions=[question] if question else None, )[0] proba = self.model.predict_proba([text])[0] similarities = None classes = self.model.classes_ ranked = sorted(zip(classes, proba), key=lambda x: x[1], reverse=True) best_id = int(ranked[0][0]) return Prediction( label_id=best_id, label_name=self.label_map[best_id], confidence=float(ranked[0][1]), top_k=[ { "label_id": int(cls), "label_name": self.label_map[int(cls)], "confidence": float(p), } for cls, p in ranked[:top_k] ], similarities=similarities, pair_scores=pair_scores, ) def main() -> None: parser = argparse.ArgumentParser(description="Classify SQL error type") parser.add_argument("--query", type=str, required=True) parser.add_argument("--correct-query", type=str, default=None) parser.add_argument("--error-message", type=str, default=None) parser.add_argument("--schema", type=str, default=None) parser.add_argument("--question", type=str, default=None) parser.add_argument("--model", type=Path, default=DEFAULT_MODEL_PATH) parser.add_argument("--top-k", type=int, default=3) args = parser.parse_args() clf = SQLErrorClassifier(args.model) result = clf.predict( args.query, args.error_message, args.schema, args.question, args.correct_query, top_k=args.top_k, ) print(json.dumps(asdict(result), indent=2)) if __name__ == "__main__": main()