File size: 839 Bytes
#!/usr/bin/env python3
from qwen_distill import QwenDistillationConfig, QwenDistillationTrainer, TextDataset, load_training_texts
from torch.utils.data import DataLoader
import torch
# Entry-point script for a distillation training run built on the
# qwen_distill helpers. Flow: build config -> pick device -> build trainer ->
# load texts -> wrap them in a shuffled DataLoader -> trainer.train(dataloader).

# Load config (default-constructed; no overrides are applied here)
config = QwenDistillationConfig()
# Use CUDA when torch reports it as available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize trainer
trainer = QwenDistillationTrainer(config, device)

# Load data
texts = load_training_texts(config.data_file)
print(f"Loaded {len(texts)} cleaned text samples from {config.data_file}")
if len(texts) == 0:
    # Fail fast with a clear message: a shuffled DataLoader over an empty
    # dataset either raises an obscure sampler error or performs zero steps
    # (depending on the torch version), after which the success messages
    # below would be misleading.
    raise SystemExit(f"No training texts loaded from {config.data_file}; aborting.")

# Create dataset & dataloader
# The tokenizer is taken from the trainer; max_length comes from the config.
dataset = TextDataset(texts, trainer.tokenizer, max_length=config.max_seq_length)
# num_workers=0 keeps batch loading in the main process.
dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True, num_workers=0)

# Train
trainer.train(dataloader)
print("✓ Training complete!")
# NOTE(review): this path is hard-coded in the message; presumably it matches
# where the trainer writes its final checkpoint — confirm against the trainer.
print("Student saved to: checkpoints/student_final.pt")
|