#!/usr/bin/env python3
"""
Train model-based intent classifier for Query Router.

Replaces rule-based heuristics with TF-IDF + LogisticRegression (or FastText/DistilBERT).
Uses synthetic seed data; extend with real labeled queries via --data CSV.

Usage:
    python scripts/model/train_intent_router.py
    python scripts/model/train_intent_router.py --data data/intent_labels.csv
    python scripts/model/train_intent_router.py --backend fasttext
    python scripts/model/train_intent_router.py --backend distilbert

Output:
    data/model/intent_classifier.pkl (or .bin for fasttext)
"""

import argparse
import logging
import sys
from pathlib import Path

# The script lives three levels below the project root; the root must be on
# sys.path before the project-local `src` package can be imported.
PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
sys.path.insert(0, str(PROJECT_ROOT))

import joblib  # noqa: E402
import pandas as pd  # noqa: E402

from src.core.intent_classifier import train_classifier, INTENTS  # noqa: E402

logging.basicConfig(level=logging.INFO, format="%(message)s")
logger = logging.getLogger(__name__)

# Synthetic training data: (query, intent)
# Extend with real user queries for better generalization
SEED_DATA = [
    # small_to_big: detail-oriented, plot/review focused
    ("book with twist ending", "small_to_big"),
    ("unreliable narrator", "small_to_big"),
    ("spoiler about the ending", "small_to_big"),
    ("what did readers think", "small_to_big"),
    ("opinion on the book", "small_to_big"),
    ("hidden details in the story", "small_to_big"),
    ("did anyone cry reading this", "small_to_big"),
    ("review of the book", "small_to_big"),
    ("plot twist reveal", "small_to_big"),
    ("unreliable narrator twist", "small_to_big"),
    ("readers who loved the ending", "small_to_big"),
    ("spoiler what happens at the end", "small_to_big"),
    # fast: short keyword queries
    ("AI book", "fast"),
    ("Python", "fast"),
    ("romance", "fast"),
    ("machine learning", "fast"),
    ("science fiction", "fast"),
    ("best AI book", "fast"),
    ("Python programming", "fast"),
    ("self help", "fast"),
    ("business", "fast"),
    ("fiction", "fast"),
    ("thriller", "fast"),
    ("mystery novel", "fast"),
    ("finance", "fast"),
    ("history", "fast"),
    ("psychology", "fast"),
    ("data science", "fast"),
    ("cooking", "fast"),
    ("music", "fast"),
    ("art", "fast"),
    ("philosophy", "fast"),
    # fast: book titles (keyword-like, BM25 works well)
    ("War and Peace", "fast"),
    ("The Lord of the Rings", "fast"),
    ("Harry Potter", "fast"),
    ("1984", "fast"),
    ("To Kill a Mockingbird", "fast"),
    ("The Great Gatsby", "fast"),
    ("Pride and Prejudice", "fast"),
    ("Dune", "fast"),
    ("Sapiens", "fast"),
    ("Atomic Habits", "fast"),
    ("Deep Work", "fast"),
    # deep: natural language, complex queries
    ("What are the best books about artificial intelligence for beginners", "deep"),
    ("I'm looking for something similar to Harry Potter", "deep"),
    ("Books that help you understand machine learning", "deep"),
    ("Recommend me a book like Sapiens but about technology", "deep"),
    ("I want to learn about psychology and human behavior", "deep"),
    ("What should I read if I liked 1984", "deep"),
    ("Looking for books on startup founding and entrepreneurship", "deep"),
    ("Can you suggest books about climate change and sustainability", "deep"),
    ("I need a book that explains quantum physics simply", "deep"),
    ("Books for someone who wants to improve their writing skills", "deep"),
    ("What are some good fiction books set in Japan", "deep"),
    ("Recommendations for someone getting into philosophy", "deep"),
    ("Books that discuss the future of work and automation", "deep"),
    ("I'm interested in biographies of scientists", "deep"),
    ("Something light and funny for a long flight", "deep"),
    ("Books about the history of mathematics", "deep"),
    ("Recommend me novels with strong female protagonists", "deep"),
    ("What to read to understand economics", "deep"),
    ("Books on meditation and mindfulness", "deep"),
    # deep: natural language with book references (need context, not just keyword)
    ("books like War and Peace", "deep"),
    ("similar to The Lord of the Rings", "deep"),
    ("recommend something like Harry Potter", "deep"),
    ("what to read after 1984", "deep"),
    ("books similar to Sapiens", "deep"),
]


def _resolve_column(df: pd.DataFrame, preferred: str, position: int) -> str:
    """Return *preferred* if it is a column of *df*, else the column at *position*.

    Raises:
        ValueError: if *preferred* is absent and *df* has too few columns to
            fall back on *position*.
    """
    if preferred in df.columns:
        return preferred
    if position >= len(df.columns):
        raise ValueError(
            f"CSV has no {preferred!r} column and fewer than {position + 1} column(s)"
        )
    return df.columns[position]


def load_training_data(data_path: Path | None) -> tuple[list[str], list[str]]:
    """Load (queries, labels) from SEED_DATA + optional CSV.

    Args:
        data_path: Optional CSV file. Columns named ``query`` and ``intent``
            are used when present; otherwise the first and second columns
            are taken positionally.

    Returns:
        Two parallel lists: query strings and their intent labels. The seed
        samples always come first, followed by any CSV rows.

    Raises:
        ValueError: if the CSV lacks the columns needed for query/intent.
    """
    queries = [query for query, _ in SEED_DATA]
    labels = [label for _, label in SEED_DATA]

    if data_path is None:
        return queries, labels
    if not data_path.exists():
        # Best-effort fallback is kept, but an explicitly requested file that
        # is missing should not go unnoticed.
        logger.warning("Data file %s not found; using seed data only", data_path)
        return queries, labels

    df = pd.read_csv(data_path)
    q_col = _resolve_column(df, "query", 0)
    l_col = _resolve_column(df, "intent", 1)

    # Without this, astype(str) would turn missing cells into the literal
    # string "nan" and feed it to the classifier as a query or a class label.
    rows = df.dropna(subset=[q_col, l_col])
    dropped = len(df) - len(rows)
    if dropped:
        logger.warning("Skipped %d row(s) with missing query/intent in %s", dropped, data_path)

    extra_q = rows[q_col].astype(str).tolist()
    # Strip stray whitespace so that e.g. " fast" does not become its own class.
    extra_l = rows[l_col].astype(str).str.strip().tolist()

    unknown = sorted(set(extra_l) - set(INTENTS))
    if unknown:
        logger.warning("Labels in %s not present in INTENTS: %s", data_path, unknown)

    queries.extend(extra_q)
    labels.extend(extra_l)
    logger.info("Loaded %d extra samples from %s", len(extra_q), data_path)
    return queries, labels


def main():
    """Parse CLI arguments, train the classifier, save it, and log a sanity check."""
    parser = argparse.ArgumentParser(description="Train intent classifier")
    parser.add_argument("--data", type=Path, default=None, help="CSV with query,intent columns")
    parser.add_argument("--backend", choices=["tfidf", "fasttext", "distilbert"], default="tfidf")
    args = parser.parse_args()

    out_dir = PROJECT_ROOT / "data" / "model"
    out_dir.mkdir(parents=True, exist_ok=True)

    queries, labels = load_training_data(args.data)
    logger.info("Training intent classifier (%s) on %d samples...", args.backend, len(queries))

    result = train_classifier(queries, labels, backend=args.backend)

    if args.backend == "fasttext":
        out_path = out_dir / "intent_classifier.bin"
        result.save_model(str(out_path))
    else:
        out_path = out_dir / "intent_classifier.pkl"
        if args.backend == "distilbert":
            joblib.dump(result, out_path)  # dict with pipeline, backend, etc.
        else:
            joblib.dump({"pipeline": result, "backend": "tfidf"}, out_path)

    logger.info("Saved to %s", out_path)

    # Quick sanity check: predict one training sample per intent.
    zero_shot = None  # built lazily, once, and only for the distilbert backend
    for intent in INTENTS:
        sample = next((q for q, label in zip(queries, labels) if label == intent), None)
        if not sample:
            continue

        if args.backend == "fasttext":
            pred = result.predict(sample)[0][0].replace("__label__", "")
        elif args.backend == "distilbert":
            if zero_shot is None:
                from transformers import pipeline

                # NOTE(review): this scores a fresh zero-shot pipeline on the
                # base model, not the trained `result` - confirm that is intended.
                zero_shot = pipeline(
                    "zero-shot-classification",
                    model="distilbert-base-uncased",
                    device=-1,
                )
            pred = zero_shot(sample, INTENTS, multi_label=False)["labels"][0]
        else:
            pred = result.predict([sample])[0]

        ok = "✓" if pred == intent else "✗"
        logger.info(" %s %s: %r -> %s", ok, intent, sample[:40], pred)


if __name__ == "__main__":
    main()