| |
| """ |
| Inference & Evaluation for Qwen-0.8B Student Model |
| """ |
|
|
| import torch |
| import torch.nn.functional as F |
| from transformers import AutoTokenizer |
| from pathlib import Path |
| import logging |
| import time |
| from typing import Dict, List |
|
|
_LOG_LEVEL = logging.INFO
logging.basicConfig(level=_LOG_LEVEL)
# Module-level logger used by every class and the CLI entry point in this file.
logger = logging.getLogger(name=__name__)
|
|
|
|
| |
| |
| |
|
|
class StudentInference:
    """Run inference with a distilled student model.

    Loads a training checkpoint (a dict with the keys ``'config'`` and
    ``'model_state_dict'``), rebuilds ``QwenStudentModel`` from the stored
    config, and loads the tokenizer of the teacher model named in that config.
    """

    def __init__(self, checkpoint_path: str, device: str = "cuda"):
        """Load the student checkpoint and its tokenizer.

        Args:
            checkpoint_path: Path to a checkpoint dict saved with ``torch.save``.
            device: Torch device string the model is moved to (e.g. ``"cuda"``,
                ``"cpu"``).
        """
        self.device = torch.device(device)
        self.checkpoint_path = checkpoint_path

        logger.info(f"Loading checkpoint: {checkpoint_path}")
        # NOTE(security): torch.load may unpickle arbitrary objects (depending on
        # the installed torch's weights_only default); only load trusted files.
        self.checkpoint = torch.load(checkpoint_path, map_location=device)
        self.config = self.checkpoint['config']

        # Project-local import, deferred until a model is actually constructed.
        from qwen_distill import QwenDistillationConfig, QwenStudentModel

        # Start from the default config and overwrite every field stored in the checkpoint.
        config_obj = QwenDistillationConfig()
        for key, val in self.config.items():
            setattr(config_obj, key, val)

        self.model = QwenStudentModel(config_obj).to(device)
        self.model.load_state_dict(self.checkpoint['model_state_dict'])
        self.model.eval()

        # The tokenizer is loaded from the teacher model named in the config.
        self.tokenizer = AutoTokenizer.from_pretrained(
            config_obj.teacher_model_name,
            trust_remote_code=True,
        )
        # Reuse EOS as the padding token.
        self.tokenizer.pad_token = self.tokenizer.eos_token

        n_params = sum(p.numel() for p in self.model.parameters())
        logger.info(f"✓ Model loaded. Parameters: {n_params/1e6:.1f}M")

    def generate(
        self,
        prompt: str,
        max_length: int = 100,
        temperature: float = 0.7,
        top_p: float = 0.95,
    ) -> str:
        """Continue ``prompt`` with temperature + nucleus (top-p) sampling.

        Args:
            prompt: Text to continue.
            max_length: Maximum number of NEW tokens to generate (the prompt
                is not counted). Generation stops early at the EOS token.
            temperature: Softmax temperature. A value ``<= 0`` selects greedy
                (argmax) decoding instead of sampling.
            top_p: Nucleus threshold; only the smallest set of most-likely
                tokens whose cumulative probability exceeds ``top_p`` is kept.

        Returns:
            The decoded prompt plus continuation, with special tokens removed.
        """
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)

        with torch.no_grad():
            for _ in range(max_length):
                # No KV cache: the whole sequence is re-run at every step.
                logits = self.model(input_ids)['logits'][:, -1, :]

                if temperature <= 0:
                    # Dividing by a zero temperature would yield inf/NaN probabilities.
                    next_token = logits.argmax(dim=-1, keepdim=True)
                else:
                    logits = self._top_p_filter(logits / temperature, top_p)
                    next_token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)

                input_ids = torch.cat([input_ids, next_token], dim=-1)

                if next_token.item() == self.tokenizer.eos_token_id:
                    break

        return self.tokenizer.decode(input_ids[0], skip_special_tokens=True)

    @staticmethod
    def _top_p_filter(logits: torch.Tensor, top_p: float) -> torch.Tensor:
        """Return ``logits`` (batch, vocab) with tokens outside the top-p nucleus set to -inf."""
        probs = F.softmax(logits, dim=-1)
        sorted_probs, sorted_indices = torch.sort(probs, descending=True)
        cumulative = torch.cumsum(sorted_probs, dim=-1)

        # Remove tokens once the cumulative mass exceeds top_p. Shifting right by
        # one keeps the token that crosses the threshold; the top token is always kept.
        sorted_remove = cumulative > top_p
        sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
        sorted_remove[..., 0] = False

        # Map the mask from sorted order back to vocabulary order, row by row.
        remove = sorted_remove.scatter(-1, sorted_indices, sorted_remove)
        return logits.masked_fill(remove, float('-inf'))

    def inference_speed_test(self, prompt: str = "The future of AI", num_runs: int = 10) -> Dict[str, float]:
        """Benchmark the latency of a single forward pass over ``prompt``.

        Runs one untimed warm-up pass, then ``num_runs`` timed passes.

        Returns:
            ``{'avg_time_ms': mean latency in ms, 'throughput': forward passes per second}``.

        Raises:
            ValueError: if ``num_runs < 1``.
        """
        if num_runs < 1:
            raise ValueError(f"num_runs must be >= 1, got {num_runs}")

        logger.info(f"Running speed test ({num_runs} iterations)...")

        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)

        # CUDA kernels run asynchronously, so the timed region must be fenced with
        # a synchronize. On non-CUDA devices torch.cuda.synchronize() would fail.
        on_cuda = self.device.type == "cuda"

        def _sync() -> None:
            if on_cuda:
                torch.cuda.synchronize(self.device)

        times = []
        with torch.no_grad():
            # Untimed warm-up pass.
            self.model(input_ids)

            for _ in range(num_runs):
                _sync()
                start = time.perf_counter()
                self.model(input_ids)
                _sync()
                times.append(time.perf_counter() - start)

        avg_time = sum(times) / len(times) * 1000
        throughput = 1000 / avg_time if avg_time > 0 else float('inf')
        logger.info(f"Average inference time: {avg_time:.1f}ms")
        logger.info(f"Throughput: {throughput:.1f} samples/sec")

        return {
            'avg_time_ms': avg_time,
            'throughput': throughput,
        }
|
|
|
|
| |
| |
| |
|
|
class StudentEvaluator:
    """Evaluate student model quality against its teacher.

    Provides corpus perplexity for each model, top-k next-token agreement
    between them, and a side-by-side sampled generation comparison.
    """

    def __init__(self, student_checkpoint: str, teacher_model_name: str, device: str = "cuda"):
        """Load the student (via ``StudentInference``) and the fp16 teacher.

        Args:
            student_checkpoint: Path to the student checkpoint.
            teacher_model_name: Hugging Face model id/path of the teacher.
            device: Device for the student and its inputs. The teacher is placed
                with ``device_map="auto"`` and may end up on a different device.
        """
        self.device = torch.device(device)
        self.student_inf = StudentInference(student_checkpoint, device)

        from transformers import AutoModelForCausalLM
        logger.info(f"Loading teacher: {teacher_model_name}")

        self.teacher = AutoModelForCausalLM.from_pretrained(
            teacher_model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        )
        self.teacher.eval()

        # Both models are scored with the student's tokenizer.
        # NOTE(review): assumes teacher_model_name shares that vocabulary — confirm.
        self.tokenizer = self.student_inf.tokenizer

    def _teacher_device(self) -> torch.device:
        """Device reported by the teacher model, used for teacher inputs (falls back to ``self.device``)."""
        return getattr(self.teacher, "device", self.device)

    def _encode(self, text: str, max_length: int, device) -> torch.Tensor:
        """Tokenize ``text`` (truncated to ``max_length``) and return its ``input_ids`` on ``device``."""
        enc = self.tokenizer(
            text,
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        ).to(device)
        return enc['input_ids']

    @staticmethod
    def _next_token_nll(logits: torch.Tensor, input_ids: torch.Tensor):
        """Return ``(summed next-token NLL, number of predicted tokens)`` for one sequence.

        ``logits`` is indexed as (1, seq, vocab) and ``input_ids`` as (1, seq).
        A sequence shorter than two tokens has nothing to predict and yields ``(0.0, 0)``.
        """
        if input_ids.shape[-1] < 2:
            return 0.0, 0
        # Upcast so fp16 (teacher) logits give a numerically stable log-softmax.
        shift_logits = logits[0, :-1, :].float()
        shift_labels = input_ids[0, 1:].to(shift_logits.device)
        nll = F.cross_entropy(shift_logits, shift_labels, reduction='sum')
        return nll.item(), shift_labels.numel()

    def _corpus_perplexity(self, forward, texts: List[str], max_length: int, device, label: str) -> float:
        """Token-weighted perplexity ``exp(total NLL / total predicted tokens)`` of ``forward`` over ``texts``.

        ``forward`` maps an ``input_ids`` tensor to logits. Returns NaN (with a
        warning) when no text has at least two tokens.
        """
        total_nll = 0.0
        total_tokens = 0
        with torch.no_grad():
            for text in texts:
                input_ids = self._encode(text, max_length, device)
                nll, count = self._next_token_nll(forward(input_ids), input_ids)
                total_nll += nll
                total_tokens += count

        if total_tokens == 0:
            logger.warning(f"{label} perplexity undefined: no text has at least two tokens")
            return float('nan')

        perplexity = torch.exp(torch.tensor(total_nll / total_tokens)).item()
        logger.info(f"{label} perplexity: {perplexity:.2f}")
        return perplexity

    def compute_perplexity(self, texts: List[str], max_length: int = 256) -> float:
        """Compute the student's corpus perplexity on ``texts``.

        Computed as exp(total next-token NLL / total predicted tokens), so
        longer texts weigh proportionally more. Each text is truncated to
        ``max_length`` tokens.
        """
        self.student_inf.model.eval()
        return self._corpus_perplexity(
            lambda ids: self.student_inf.model(ids)['logits'],
            texts, max_length, self.device, "Student",
        )

    def compute_teacher_perplexity(self, texts: List[str], max_length: int = 256) -> float:
        """Compute the teacher's corpus perplexity on ``texts`` (same definition as the student's)."""
        self.teacher.eval()
        # Only logits are needed for the loss, so hidden states are not requested.
        return self._corpus_perplexity(
            lambda ids: self.teacher(ids).logits,
            texts, max_length, self._teacher_device(), "Teacher",
        )

    def top_k_agreement(self, texts: List[str], k: int = 5, max_length: int = 256) -> float:
        """Fraction of the student's top-k next-token candidates that are also in the teacher's top-k.

        Measured at every position of every text. Order within each top-k list
        is ignored (set overlap). Returns 0.0 when there are no positions.
        """
        overlap = 0
        total = 0

        self.student_inf.model.eval()
        self.teacher.eval()
        teacher_device = self._teacher_device()

        with torch.no_grad():
            for text in texts:
                input_ids = self._encode(text, max_length, self.device)

                student_logits = self.student_inf.model(input_ids)['logits']
                teacher_logits = self.teacher(input_ids.to(teacher_device)).logits

                student_topk = torch.topk(student_logits, k, dim=-1).indices
                teacher_topk = torch.topk(teacher_logits, k, dim=-1).indices.to(student_topk.device)

                # Broadcast (..., k, 1) against (..., 1, k): a student candidate counts
                # as a hit if it equals ANY teacher candidate at the same position.
                hits = (student_topk.unsqueeze(-1) == teacher_topk.unsqueeze(-2)).any(dim=-1)
                overlap += int(hits.sum().item())
                total += hits.numel()

        agreement = overlap / total if total > 0 else 0.0
        logger.info(f"Top-{k} agreement with teacher: {agreement*100:.1f}%")
        return agreement

    def generate_comparison(self, prompt: str = "The future of AI", max_length: int = 100) -> Dict[str, str]:
        """Generate from ``prompt`` with both models and log the outputs side by side.

        Both models sample with temperature 0.7 and top-p 0.95 and produce at
        most ``max_length`` NEW tokens, so the outputs are directly comparable.

        Returns:
            ``{'student': ..., 'teacher': ...}`` decoded texts (prompt included).
        """
        logger.info(f"\nPrompt: {prompt}\n")

        student_text = self.student_inf.generate(prompt, max_length=max_length)
        logger.info(f"Student:\n{student_text}\n")

        enc = self.tokenizer(prompt, return_tensors="pt").to(self._teacher_device())
        with torch.no_grad():
            outputs = self.teacher.generate(
                enc['input_ids'],
                attention_mask=enc['attention_mask'],
                # Without do_sample=True, HF generate decodes greedily and ignores
                # temperature/top_p. max_new_tokens matches the student's budget,
                # whereas HF's max_length would also count the prompt tokens.
                do_sample=True,
                max_new_tokens=max_length,
                temperature=0.7,
                top_p=0.95,
                pad_token_id=self.tokenizer.eos_token_id,
            )
        teacher_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        logger.info(f"Teacher:\n{teacher_text}\n")

        return {'student': student_text, 'teacher': teacher_text}
|
|
|
|
| |
| |
| |
|
|
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Inference & evaluation for a distilled Qwen student model")
    parser.add_argument("--checkpoint", default="checkpoints/student_final.pt", help="Student checkpoint path")
    parser.add_argument("--teacher", default="Qwen/Qwen2.5-0.5B", help="Teacher model name")
    parser.add_argument("--prompt", default="The future of artificial intelligence", help="Generation prompt")
    parser.add_argument("--speed", action="store_true", help="Run speed test")
    parser.add_argument("--eval", action="store_true", help="Run evaluation")
    parser.add_argument(
        "--device",
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Torch device (default: cuda if available, else cpu)",
    )
    parser.add_argument("--max-length", type=int, default=100, help="Maximum number of new tokens to generate")

    args = parser.parse_args()

    logger.info("Loading student model...")
    inference = StudentInference(args.checkpoint, device=args.device)

    logger.info(f"Generating from prompt: {args.prompt}\n")
    text = inference.generate(args.prompt, max_length=args.max_length)
    print(text)

    if args.speed:
        logger.info("\nBenchmarking speed...")
        inference.inference_speed_test()

    if args.eval:
        logger.info("\nRunning evaluation...")
        # StudentEvaluator loads its own copy of the student; drop this one first
        # so the two copies are not held at the same time.
        del inference
        evaluator = StudentEvaluator(args.checkpoint, args.teacher, device=args.device)

        test_texts = [
            "Artificial intelligence is transforming industries.",
            "Machine learning models require careful tuning.",
            "Distillation compresses large models efficiently.",
        ]

        evaluator.compute_perplexity(test_texts)
        evaluator.compute_teacher_perplexity(test_texts)
        evaluator.top_k_agreement(test_texts, k=5)
        evaluator.generate_comparison(args.prompt, max_length=args.max_length)
|
|