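"""
Predict query intent for one or more queries.

Pass ``--query`` to classify a single query and print the result, or
``--input`` with a CSV that has a ``query`` column to classify every row and
write the predictions to ``--output``.
"""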
| import argparse | |
| from pathlib import Path | |
| import pandas as pd | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| from config import MODEL_DIR, INTENT_LABELS | |
def _predict_batched(tokenizer, model, device, queries: list[str], batch_size: int = 32) -> list[int]:
    """Return the argmax class index for each query, running the model in chunks.

    Args:
        tokenizer: Tokenizer loaded for the model; called on a list of strings.
        model: Sequence-classification model already moved to ``device``.
        device: Device string the input tensors are moved to ("cuda" or "cpu").
        queries: Query strings to classify. May be empty.
        batch_size: Maximum number of queries per forward pass.

    Returns:
        One predicted class index per query, in input order. An empty input
        yields an empty list without invoking the tokenizer or the model.
    """
    indices: list[int] = []
    # Chunking keeps peak memory bounded by batch_size instead of by the size
    # of the whole input file.
    for start in range(0, len(queries), batch_size):
        batch = queries[start:start + batch_size]
        inp = tokenizer(batch, truncation=True, max_length=128, padding=True, return_tensors="pt")
        inp = {k: v.to(device) for k, v in inp.items()}
        with torch.no_grad():
            out = model(**inp)
        indices.extend(out.logits.argmax(dim=1).tolist())
    return indices


def main():
    """Command-line entry point for intent prediction.

    With ``--query``, classifies that single query and prints a dict with the
    query and its intent label. Otherwise, with ``--input``, reads the CSV,
    adds an ``intent`` column and writes the result to ``--output``. If both
    are given, ``--query`` takes precedence. With neither, a usage hint is
    printed and nothing else happens.

    Raises:
        FileNotFoundError: If the ``--input`` file does not exist, or if
            ``MODEL_DIR`` contains no ``config.json`` (model not trained yet).
        ValueError: If the input CSV has no ``query`` column.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--query", type=str, help="Single query")
    parser.add_argument("--input", type=str, help="CSV with 'query' column")
    parser.add_argument("--output", type=str, default="predictions.csv")
    parser.add_argument("--batch-size", type=int, default=32, help="Queries per forward pass")
    args = parser.parse_args()

    if args.batch_size < 1:
        parser.error("--batch-size must be a positive integer")

    # Validate the cheap things first so that a bad invocation fails before
    # the (slow) model load rather than after it.
    if not args.query and not args.input:
        print("Use --query 'text' or --input file.csv")
        return

    df = None
    if not args.query:
        input_path = Path(args.input)
        if not input_path.exists():
            # Report the real problem instead of falling through to the
            # generic usage hint.
            raise FileNotFoundError(f"Input file not found: {input_path}")
        df = pd.read_csv(input_path)
        if "query" not in df.columns:
            raise ValueError("CSV must have 'query' column")

    if not (MODEL_DIR / "config.json").exists():
        raise FileNotFoundError(f"Train first. No model in {MODEL_DIR}")

    tokenizer = AutoTokenizer.from_pretrained(str(MODEL_DIR))
    model = AutoModelForSequenceClassification.from_pretrained(str(MODEL_DIR))
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    # Harmless if the model is already in eval mode; guarantees that
    # training-only layers such as dropout are disabled for inference.
    model.eval()

    if df is None:
        idx = _predict_batched(tokenizer, model, device, [args.query])[0]
        print({"query": args.query, "intent": INTENT_LABELS[idx]})
        return

    # NOTE: astype(str) also turns missing values into the literal string
    # "nan", so such rows are classified as that text rather than skipped.
    queries = df["query"].astype(str).tolist()
    indices = _predict_batched(tokenizer, model, device, queries, args.batch_size)
    df["intent"] = [INTENT_LABELS[i] for i in indices]
    df.to_csv(args.output, index=False)
    print(f"Saved to {args.output}")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()