# train_intent_classifier.py
# MODIFIED
# This script now loads data from the persistent HF Dataset
# using the central dataset_utils.
import os
import json
from pathlib import Path

from datasets import Dataset
from transformers import (
    DistilBertTokenizerFast,
    DistilBertForSequenceClassification,
    TrainingArguments,
    Trainer,
)

# Import the new data loader
import dataset_utils
DATA_DIR = Path(os.getenv("DATA_DIR", "./data"))
MODEL_OUT = Path(os.getenv("MODEL_OUT", "./models/intent-classifier"))
BASE_MODEL = os.getenv("BASE_MODEL", "distilbert-base-uncased")
BATCH_SIZE = int(os.getenv("TRAIN_BATCH", "8"))
EPOCHS = int(os.getenv("TRAIN_EPOCHS", "1"))
MIN_EXAMPLES = int(os.getenv("MIN_EXAMPLES", "4"))

def load_examples():
    """Load fine-tuning examples from the central HF Dataset."""
    print("Downloading examples from HF Dataset...")
    return dataset_utils.load_fine_tune_examples()

def build_label_map(examples):
    """Map each distinct label string to a contiguous integer id."""
    # ... (this function is unchanged) ...
    labs = sorted({ex.get("label", "general_guidance") for ex in examples})
    return {lab: idx for idx, lab in enumerate(labs)}

def main():
    examples = load_examples()
    if len(examples) < MIN_EXAMPLES:
        print(f"Not enough examples to train (found {len(examples)}, need {MIN_EXAMPLES}). Add more or reduce MIN_EXAMPLES.")
        return
    print(f"Loaded {len(examples)} examples.")
    label_map = build_label_map(examples)
    print("Label map:", label_map)
    # ... (rest of your main() function is unchanged) ...
    texts = [ex["text"] for ex in examples]
    labels = [label_map[ex.get("label", "general_guidance")] for ex in examples]

    tokenizer = DistilBertTokenizerFast.from_pretrained(BASE_MODEL)
    enc = tokenizer(texts, padding=True, truncation=True, max_length=128)
    # Build a Dataset with the tokenized inputs and integer labels.
    ds = Dataset.from_dict({
        "input_ids": enc["input_ids"],
        "attention_mask": enc["attention_mask"],
        "labels": labels,
    })
    model = DistilBertForSequenceClassification.from_pretrained(BASE_MODEL, num_labels=len(label_map))
    training_args = TrainingArguments(
        output_dir=str(MODEL_OUT),
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        save_total_limit=2,
        logging_steps=10,
        remove_unused_columns=False,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=ds,
        tokenizer=tokenizer,
    )
    trainer.train()

    MODEL_OUT.mkdir(parents=True, exist_ok=True)
    trainer.save_model(str(MODEL_OUT))
    # Save the label map so inference code can translate predicted ids back to label names.
    with open(MODEL_OUT / "label_map.json", "w", encoding="utf-8") as f:
        json.dump(label_map, f, ensure_ascii=False, indent=2)
    print("Training complete. Model saved to", MODEL_OUT)

if __name__ == "__main__":
    main()
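
For reference, `dataset_utils` is not shown here; the script only assumes it exposes `load_fine_tune_examples()` returning a list of dicts with `text` and `label` keys. A minimal sketch of such a loader, assuming the examples live in a Hugging Face dataset repo (the repo id below is a placeholder, not the project's real one), could look like:

# dataset_utils.py -- hypothetical sketch, not the project's actual module
import os
from datasets import load_dataset

# Placeholder repo id; the real project presumably configures its own.
DATASET_REPO = os.getenv("DATASET_REPO", "your-org/intent-examples")

def load_fine_tune_examples():
    """Return training examples as a list of {"text": ..., "label": ...} dicts."""
    ds = load_dataset(DATASET_REPO, split="train")
    return [{"text": row["text"], "label": row.get("label", "general_guidance")}
            for row in ds]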
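As a quick usage sketch (the example sentence and script name are made up): since a tokenizer is passed to `Trainer`, `trainer.save_model` also writes it to `MODEL_OUT`, so the saved model plus `label_map.json` can be loaded back for inference like this:

# classify_intent.py -- hypothetical usage sketch for the saved artifacts
import json
from pathlib import Path
import torch
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification

model_dir = Path("./models/intent-classifier")
tokenizer = DistilBertTokenizerFast.from_pretrained(model_dir)
model = DistilBertForSequenceClassification.from_pretrained(model_dir)
model.eval()

# Invert the saved label map to decode predicted ids back to label names.
label_map = json.loads((model_dir / "label_map.json").read_text(encoding="utf-8"))
id2label = {idx: lab for lab, idx in label_map.items()}

inputs = tokenizer("how do I get started?", return_tensors="pt", truncation=True, max_length=128)
with torch.no_grad():
    logits = model(**inputs).logits
print(id2label[int(logits.argmax(dim=-1))])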