# train_intent_classifier.py
# MODIFIED: this script now loads data from the persistent HF Dataset
# using the central dataset_utils.
"""Fine-tune a DistilBERT sequence classifier on intent-labelled examples.

Examples are fetched through ``dataset_utils.load_fine_tune_examples()``. Each
example is read as a mapping with a ``"text"`` key and an optional ``"label"``
key; a missing (or ``None``) label falls back to ``DEFAULT_LABEL``. After
training, the model is saved to ``MODEL_OUT`` together with ``label_map.json``.

Configuration is read from environment variables:
    DATA_DIR, MODEL_OUT, BASE_MODEL, TRAIN_BATCH, TRAIN_EPOCHS, MIN_EXAMPLES
"""
import os
import json
from pathlib import Path

from transformers import (
    DistilBertTokenizerFast,
    DistilBertForSequenceClassification,
    TrainingArguments,
    Trainer,
)
import torch  # noqa: F401  (kept from the original file; not referenced directly)
from datasets import Dataset

# Import the new data loader
import dataset_utils

DATA_DIR = Path(os.getenv("DATA_DIR", "./data"))
MODEL_OUT = Path(os.getenv("MODEL_OUT", "./models/intent-classifier"))
BASE_MODEL = os.getenv("BASE_MODEL", "distilbert-base-uncased")
BATCH_SIZE = int(os.getenv("TRAIN_BATCH", "8"))
EPOCHS = int(os.getenv("TRAIN_EPOCHS", "1"))
# Minimum number of examples required before training is attempted. The
# "not enough examples" message tells users to reduce MIN_EXAMPLES, so it is
# configurable here (it used to be a hard-coded 4, which is still the default).
MIN_EXAMPLES = int(os.getenv("MIN_EXAMPLES", "4"))
# Label given to examples that carry no label of their own.
DEFAULT_LABEL = "general_guidance"
# Token cap passed to the tokenizer; longer texts are truncated.
MAX_LENGTH = 128


def _label_of(example):
    """Return the label of *example*, falling back to DEFAULT_LABEL.

    A ``"label"`` key that is present but set to ``None`` is treated like a
    missing key; otherwise ``sorted()`` in build_label_map would fail when
    comparing ``None`` with ``str`` labels.
    """
    label = example.get("label")
    return DEFAULT_LABEL if label is None else label


def load_examples():
    """Loads examples from the central HF Dataset."""
    print("Downloading examples from HF Dataset...")
    return dataset_utils.load_fine_tune_examples()


def build_label_map(examples):
    """Map each distinct label in *examples* to a contiguous integer id.

    Labels are sorted before numbering so the same label set always yields
    the same mapping.
    """
    labs = sorted({_label_of(ex) for ex in examples})
    return {lab: idx for idx, lab in enumerate(labs)}


def main():
    """Load examples, fine-tune the classifier, then save model and label map."""
    examples = load_examples()
    if len(examples) < MIN_EXAMPLES:
        print(
            f"Not enough examples to train (found {len(examples)}, need {MIN_EXAMPLES}). "
            "Add more or reduce MIN_EXAMPLES."
        )
        return
    print(f"Loaded {len(examples)} examples.")

    label_map = build_label_map(examples)
    print("Label map:", label_map)
    # With num_labels == 1 DistilBERT switches to a regression objective, and a
    # one-class classifier is meaningless anyway, so refuse to train.
    if len(label_map) < 2:
        print(f"Need at least 2 distinct labels to train a classifier (found {len(label_map)}).")
        return

    texts = [ex["text"] for ex in examples]
    labels = [label_map[_label_of(ex)] for ex in examples]

    tokenizer = DistilBertTokenizerFast.from_pretrained(BASE_MODEL)
    # padding=True pads every example to the longest one in this batch of
    # texts (i.e. the whole dataset), after truncation to MAX_LENGTH.
    enc = tokenizer(texts, padding=True, truncation=True, max_length=MAX_LENGTH)
    ds = Dataset.from_dict({
        "input_ids": enc["input_ids"],
        "attention_mask": enc["attention_mask"],
        "labels": labels,
    })

    id2label = {idx: lab for lab, idx in label_map.items()}
    model = DistilBertForSequenceClassification.from_pretrained(
        BASE_MODEL,
        num_labels=len(label_map),
        # Store the mapping in the model config so the saved model reports
        # real label names instead of LABEL_0, LABEL_1, ...
        id2label=id2label,
        label2id=label_map,
    )

    training_args = TrainingArguments(
        output_dir=str(MODEL_OUT),
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        save_total_limit=2,
        logging_steps=10,
        # The dataset already holds only the columns the model consumes.
        remove_unused_columns=False,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=ds,
        # NOTE(review): newer transformers releases deprecate `tokenizer=` in
        # favour of `processing_class=`; kept as-is for the installed version.
        tokenizer=tokenizer,
    )
    trainer.train()

    MODEL_OUT.mkdir(parents=True, exist_ok=True)
    trainer.save_model(str(MODEL_OUT))
    # Persist the label -> id mapping alongside the model artifacts.
    with open(MODEL_OUT / "label_map.json", "w", encoding="utf-8") as f:
        json.dump(label_map, f, ensure_ascii=False, indent=2)
    print("Training complete. Model saved to", MODEL_OUT)


if __name__ == "__main__":
    main()