# train_intent_classifier.py
# MODIFIED
# This script now loads data from the persistent HF Dataset
# using the central dataset_utils.

import os
import json
from pathlib import Path
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification, TrainingArguments, Trainer
import torch
from datasets import Dataset

# Import the new data loader
import dataset_utils
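
# NOTE (assumption): judging from how its result is consumed below,
# load_fine_tune_examples() is expected to return a list of dicts such as
#     {"text": "how do I reset my password?", "label": "account_help"}
# where "label" is optional and defaults to "general_guidance" downstream.
# The example values above are illustrative only, not taken from the dataset.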

DATA_DIR = Path(os.getenv("DATA_DIR", "./data"))
MODEL_OUT = Path(os.getenv("MODEL_OUT", "./models/intent-classifier"))
BASE_MODEL = os.getenv("BASE_MODEL", "distilbert-base-uncased")
BATCH_SIZE = int(os.getenv("TRAIN_BATCH", "8"))
EPOCHS = int(os.getenv("TRAIN_EPOCHS", "1"))
MIN_EXAMPLES = int(os.getenv("MIN_EXAMPLES", "4"))

def load_examples():
    """Loads examples from the central HF Dataset."""
    print("Downloading examples from HF Dataset...")
    return dataset_utils.load_fine_tune_examples()

def build_label_map(examples):
    # ... (this function is unchanged) ...
    # Map each distinct label to a stable integer id (sorted for determinism),
    # e.g. labels {"billing", "refund"} -> {"billing": 0, "refund": 1}.
    labs = sorted({ex.get("label", "general_guidance") for ex in examples})
    return {lab: idx for idx, lab in enumerate(labs)}

def main():
    examples = load_examples()
    if len(examples) < MIN_EXAMPLES:
        print(f"Not enough examples to train (found {len(examples)}, need {MIN_EXAMPLES}). Add more or reduce MIN_EXAMPLES.")
        return

    print(f"Loaded {len(examples)} examples.")
    label_map = build_label_map(examples)
    print("Label map:", label_map)

    # ... (rest of your main() function is unchanged) ...
    texts = [ex["text"] for ex in examples]
    labels = [label_map[ex.get("label", "general_guidance")] for ex in examples]

    tokenizer = DistilBertTokenizerFast.from_pretrained(BASE_MODEL)
    enc = tokenizer(texts, padding=True, truncation=True, max_length=128)

    ds = Dataset.from_dict({
        "input_ids": enc["input_ids"],
        "attention_mask": enc["attention_mask"],
        "labels": labels
    })

    model = DistilBertForSequenceClassification.from_pretrained(BASE_MODEL, num_labels=len(label_map))

    training_args = TrainingArguments(
        output_dir=str(MODEL_OUT),
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        save_total_limit=2,
        logging_steps=10,
        remove_unused_columns=False  # keep the pre-built input columns as-is
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=ds,
        tokenizer=tokenizer  # so save_model() also persists the tokenizer
    )

    trainer.train()
    MODEL_OUT.mkdir(parents=True, exist_ok=True)
    trainer.save_model(str(MODEL_OUT))
    
    # save label map
    with open(MODEL_OUT / "label_map.json", "w", encoding="utf-8") as f:
        json.dump(label_map, f, ensure_ascii=False, indent=2)
    print("Training complete. Model saved to", MODEL_OUT)

if __name__ == "__main__":
    main()