query-intent-entity-ner / inference.py
syeedalireza's picture
Upload folder using huggingface_hub
fe67d4b verified
"""
Predict query intent for one or more queries.
"""
import argparse
from pathlib import Path
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from config import MODEL_DIR, INTENT_LABELS
def main(argv: "list[str] | None" = None) -> None:
    """Predict the intent of a single query or of every query in a CSV file.

    With ``--query`` the prediction is printed as a dict. With ``--input`` the
    CSV is read, an ``intent`` column is added, and the result is written to
    ``--output``. If both are given, ``--query`` wins and ``--input`` is
    ignored. With neither, a short usage hint is printed.

    Args:
        argv: Command-line arguments to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, so ``main()`` behaves as a normal
            script entry point.

    Raises:
        FileNotFoundError: If the ``--input`` file does not exist, or if
            ``MODEL_DIR`` contains no ``config.json``.
        ValueError: If the input CSV has no ``query`` column.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--query", type=str, help="Single query")
    parser.add_argument("--input", type=str, help="CSV with 'query' column")
    parser.add_argument("--output", type=str, default="predictions.csv")
    parser.add_argument(
        "--batch-size",
        type=int,
        default=32,
        help="Number of queries per forward pass",
    )
    args = parser.parse_args(argv)

    if args.batch_size < 1:
        parser.error("--batch-size must be >= 1")

    # Validate the request before touching the model, so argument problems
    # are reported immediately and without requiring a trained model.
    if not args.query and not args.input:
        print("Use --query 'text' or --input file.csv")
        return

    input_path = None
    if not args.query:
        input_path = Path(args.input)
        if not input_path.is_file():
            # Report the real problem instead of falling through to the
            # generic usage message.
            raise FileNotFoundError(f"Input CSV not found: {input_path}")

    if not (MODEL_DIR / "config.json").exists():
        raise FileNotFoundError(f"Train first. No model in {MODEL_DIR}")

    tokenizer = AutoTokenizer.from_pretrained(str(MODEL_DIR))
    model = AutoModelForSequenceClassification.from_pretrained(str(MODEL_DIR))
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    model.eval()  # inference only: make evaluation mode explicit

    def predict(queries: list[str]) -> list[int]:
        """Return the argmax class index for each query, in input order."""
        indices: list[int] = []
        # Run in mini-batches so memory use does not grow with the CSV size.
        # An empty list skips the loop and yields [].
        for start in range(0, len(queries), args.batch_size):
            batch = queries[start:start + args.batch_size]
            encoded = tokenizer(
                batch,
                truncation=True,
                max_length=128,
                padding=True,
                return_tensors="pt",
            )
            encoded = {k: v.to(device) for k, v in encoded.items()}
            with torch.no_grad():
                logits = model(**encoded).logits
            indices.extend(logits.argmax(dim=1).tolist())
        return indices

    if args.query:
        idx = predict([args.query])[0]
        print({"query": args.query, "intent": INTENT_LABELS[idx]})
        return

    df = pd.read_csv(input_path)
    if "query" not in df.columns:
        raise ValueError("CSV must have 'query' column")
    # Missing cells would otherwise become the literal text "nan" after
    # astype(str); treat them as empty queries instead.
    queries = df["query"].fillna("").astype(str).tolist()
    df["intent"] = [INTENT_LABELS[i] for i in predict(queries)]
    df.to_csv(args.output, index=False)
    print(f"Saved to {args.output}")
# Run the command-line entry point only when this file is executed as a
# script; importing the module does not trigger it.
if __name__ == "__main__":
    main()