""" Predict query intent for one or more queries. """ import argparse from pathlib import Path import pandas as pd import torch from transformers import AutoTokenizer, AutoModelForSequenceClassification from config import MODEL_DIR, INTENT_LABELS def main(): parser = argparse.ArgumentParser() parser.add_argument("--query", type=str, help="Single query") parser.add_argument("--input", type=str, help="CSV with 'query' column") parser.add_argument("--output", type=str, default="predictions.csv") args = parser.parse_args() if not (MODEL_DIR / "config.json").exists(): raise FileNotFoundError(f"Train first. No model in {MODEL_DIR}") tokenizer = AutoTokenizer.from_pretrained(str(MODEL_DIR)) model = AutoModelForSequenceClassification.from_pretrained(str(MODEL_DIR)) device = "cuda" if torch.cuda.is_available() else "cpu" model = model.to(device) def predict(queries: list[str]): inp = tokenizer(queries, truncation=True, max_length=128, padding=True, return_tensors="pt") inp = {k: v.to(device) for k, v in inp.items()} with torch.no_grad(): out = model(**inp) return out.logits.argmax(dim=1).tolist() if args.query: idx = predict([args.query])[0] print({"query": args.query, "intent": INTENT_LABELS[idx]}) return if args.input and Path(args.input).exists(): df = pd.read_csv(args.input) if "query" not in df.columns: raise ValueError("CSV must have 'query' column") indices = predict(df["query"].astype(str).tolist()) df["intent"] = [INTENT_LABELS[i] for i in indices] df.to_csv(args.output, index=False) print(f"Saved to {args.output}") return print("Use --query 'text' or --input file.csv") if __name__ == "__main__": main()