DiffuMoE / train.py
pragadeeshv23's picture
Upload folder using huggingface_hub
05c5c96 verified
raw
history blame contribute delete
839 Bytes
#!/usr/bin/env python3
from qwen_distill import QwenDistillationConfig, QwenDistillationTrainer, TextDataset, load_training_texts
from torch.utils.data import DataLoader
import torch
# Script flow: build the distillation config, pick a device, construct the
# trainer, load the training texts, wrap them in a DataLoader, and train.

# Default configuration; every setting read below (data_file, max_seq_length,
# batch_size) comes from this object.
config = QwenDistillationConfig()

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize trainer
trainer = QwenDistillationTrainer(config, device)

# Load data
texts = load_training_texts(config.data_file)
num_texts = len(texts)
print(f"Loaded {num_texts} cleaned text samples from {config.data_file}")

# Fail fast on an empty corpus: with nothing to batch, the run would either
# error out later with a less helpful message or "complete" without having
# trained on anything.
if num_texts == 0:
    raise SystemExit(
        f"No training texts were loaded from {config.data_file}; aborting."
    )

# Create dataset & dataloader. The tokenizer is taken from the trainer so the
# dataset and the trainer share the same tokenizer object.
dataset = TextDataset(texts, trainer.tokenizer, max_length=config.max_seq_length)
dataloader = DataLoader(
    dataset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=0,  # load batches in the main process (no worker subprocesses)
)

# Train
trainer.train(dataloader)
print("✓ Training complete!")

# NOTE(review): this path is hard-coded in the message rather than read from
# the trainer/config — confirm it matches where trainer.train() actually
# writes the final student checkpoint.
print("Student saved to: checkpoints/student_final.pt")