#!/usr/bin/env python3
"""Run a distillation training pass with ``QwenDistillationTrainer``.

The script builds a default ``QwenDistillationConfig``, loads the training
texts named by ``config.data_file``, wraps them in a ``TextDataset`` /
``DataLoader`` pair and hands the loader to ``trainer.train``.
"""

import torch
from torch.utils.data import DataLoader

from qwen_distill import (
    QwenDistillationConfig,
    QwenDistillationTrainer,
    TextDataset,
    load_training_texts,
)


def main() -> None:
    """Load the config and data, then train the student model.

    Raises:
        SystemExit: If ``load_training_texts`` returns no samples, since
            there is nothing to train on.
    """
    # Load config and pick the GPU when one is available.
    config = QwenDistillationConfig()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize trainer; its tokenizer is reused by the dataset below.
    trainer = QwenDistillationTrainer(config, device)

    # Load data.
    texts = load_training_texts(config.data_file)
    num_texts = len(texts)
    print(f"Loaded {num_texts} cleaned text samples from {config.data_file}")

    # Fail early with a clear message instead of passing an empty dataset
    # to the DataLoader / trainer.
    if num_texts == 0:
        raise SystemExit(f"No training texts found in {config.data_file}; aborting.")

    # Create dataset & dataloader. num_workers=0 keeps batch loading in the
    # main process.
    dataset = TextDataset(texts, trainer.tokenizer, max_length=config.max_seq_length)
    dataloader = DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=0,
    )

    # Train.
    trainer.train(dataloader)

    print("✓ Training complete!")
    # NOTE(review): this path is hard-coded in the message only; confirm it
    # matches where the trainer actually writes the final checkpoint.
    print("Student saved to: checkpoints/student_final.pt")


if __name__ == "__main__":
    main()