#!/usr/bin/env python3
"""
Inference & Evaluation for Qwen-0.8B Student Model
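
Usage (flags defined in the argparse block at the bottom of this file):
    python qwen_inference.py --checkpoint checkpoints/student_final.pt
    python qwen_inference.py --speed    # benchmark inference latency
    python qwen_inference.py --eval     # compare against the teacher model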
"""
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from pathlib import Path
import logging
import time
from typing import Dict, List
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ============================================================================
# INFERENCE
# ============================================================================
class StudentInference:
"""Run inference with distilled student model"""
def __init__(self, checkpoint_path: str, device: str = "cuda"):
self.device = torch.device(device)
self.checkpoint_path = checkpoint_path
logger.info(f"Loading checkpoint: {checkpoint_path}")
self.checkpoint = torch.load(checkpoint_path, map_location=device)
self.config = self.checkpoint['config']
# Reconstruct student model
from qwen_distill import QwenDistillationConfig, QwenStudentModel
config_obj = QwenDistillationConfig()
for key, val in self.config.items():
setattr(config_obj, key, val)
self.model = QwenStudentModel(config_obj).to(device)
self.model.load_state_dict(self.checkpoint['model_state_dict'])
self.model.eval()
# Load tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(
config_obj.teacher_model_name,
trust_remote_code=True,
)
self.tokenizer.pad_token = self.tokenizer.eos_token
logger.info(f"✓ Model loaded. Parameters: {sum(p.numel() for p in self.model.parameters())/1e6:.1f}M")
    def generate(
        self,
        prompt: str,
        max_length: int = 100,
        temperature: float = 0.7,
        top_p: float = 0.95,
    ) -> str:
        """Generate text from prompt (max_length counts newly generated tokens)"""
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)

        with torch.no_grad():
            for _ in range(max_length):
                outputs = self.model(input_ids)
                logits = outputs['logits'][:, -1, :]

                # Temperature scaling
                logits = logits / temperature

                # Top-p (nucleus) sampling: keep the smallest set of tokens
                # whose cumulative probability exceeds top_p
                probs = F.softmax(logits, dim=-1)
                sorted_probs, sorted_indices = torch.sort(probs, descending=True)
                cumsum_probs = torch.cumsum(sorted_probs, dim=-1)

                # Remove tokens with cumulative probability > top_p
                sorted_indices_to_remove = cumsum_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0

                # Note: this masking assumes batch size 1
                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                logits[0, indices_to_remove] = -float('inf')

                # Sample the next token and append it to the running sequence
                next_token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
                input_ids = torch.cat([input_ids, next_token], dim=-1)

                if next_token.item() == self.tokenizer.eos_token_id:
                    break

        return self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
    def inference_speed_test(self, prompt: str = "The future of AI", num_runs: int = 10) -> Dict[str, float]:
        """Benchmark inference speed"""
        logger.info(f"Running speed test ({num_runs} iterations)...")
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)

        # Warmup
        with torch.no_grad():
            _ = self.model(input_ids)

        # Measure (synchronize around each forward pass only when running on GPU)
        times = []
        with torch.no_grad():
            for _ in range(num_runs):
                if self.device.type == "cuda":
                    torch.cuda.synchronize()
                start = time.time()
                _ = self.model(input_ids)
                if self.device.type == "cuda":
                    torch.cuda.synchronize()
                times.append(time.time() - start)

        avg_time = sum(times) / len(times) * 1000  # ms
        logger.info(f"Average inference time: {avg_time:.1f}ms")
        logger.info(f"Throughput: {1000/avg_time:.1f} samples/sec")
        return {
            'avg_time_ms': avg_time,
            'throughput': 1000 / avg_time,
        }
# ============================================================================
# EVALUATION
# ============================================================================
class StudentEvaluator:
"""Evaluate student model quality"""
def __init__(self, student_checkpoint: str, teacher_model_name: str, device: str = "cuda"):
self.device = torch.device(device)
self.student_inf = StudentInference(student_checkpoint, device)
# Load teacher
from transformers import AutoModelForCausalLM
logger.info(f"Loading teacher: {teacher_model_name}")
self.teacher = AutoModelForCausalLM.from_pretrained(
teacher_model_name,
torch_dtype=torch.float16,
device_map="auto",
trust_remote_code=True,
)
self.teacher.eval()
self.tokenizer = self.student_inf.tokenizer
    def compute_perplexity(self, texts: List[str], max_length: int = 256) -> float:
        """Compute student perplexity on text samples"""
        total_loss = 0.0
        num_tokens = 0
        self.student_inf.model.eval()

        with torch.no_grad():
            for text in texts:
                enc = self.tokenizer(
                    text,
                    max_length=max_length,
                    truncation=True,
                    return_tensors="pt",
                ).to(self.device)
                outputs = self.student_inf.model(enc['input_ids'])
                logits = outputs['logits']

                # Next-token cross-entropy: shift logits against the input ids
                loss = F.cross_entropy(
                    logits[0, :-1, :],
                    enc['input_ids'][0, 1:],
                    reduction='sum',
                )
                total_loss += loss.item()
                num_tokens += enc['input_ids'].size(1) - 1

        # Token-weighted perplexity over all samples
        perplexity = torch.exp(torch.tensor(total_loss / max(num_tokens, 1))).item()
        logger.info(f"Student perplexity: {perplexity:.2f}")
        return perplexity
    def compute_teacher_perplexity(self, texts: List[str], max_length: int = 256) -> float:
        """Compute teacher perplexity for comparison"""
        total_loss = 0.0
        num_tokens = 0
        self.teacher.eval()

        with torch.no_grad():
            for text in texts:
                enc = self.tokenizer(
                    text,
                    max_length=max_length,
                    truncation=True,
                    return_tensors="pt",
                ).to(self.device)
                outputs = self.teacher(enc['input_ids'])
                logits = outputs.logits

                # Cast fp16 logits to fp32 for a numerically stable loss
                loss = F.cross_entropy(
                    logits[0, :-1, :].float(),
                    enc['input_ids'][0, 1:],
                    reduction='sum',
                )
                total_loss += loss.item()
                num_tokens += enc['input_ids'].size(1) - 1

        perplexity = torch.exp(torch.tensor(total_loss / max(num_tokens, 1))).item()
        logger.info(f"Teacher perplexity: {perplexity:.2f}")
        return perplexity
    def top_k_agreement(self, texts: List[str], k: int = 5) -> float:
        """Measure how well student matches teacher top-k predictions"""
        match_count = 0
        total = 0
        self.student_inf.model.eval()
        self.teacher.eval()

        with torch.no_grad():
            for text in texts:
                enc = self.tokenizer(
                    text,
                    return_tensors="pt",
                    max_length=256,
                    truncation=True,
                ).to(self.device)
                student_out = self.student_inf.model(enc['input_ids'])
                student_logits = student_out['logits']
                teacher_out = self.teacher(enc['input_ids'])
                teacher_logits = teacher_out.logits

                # Top-k tokens at every position
                _, student_topk = torch.topk(student_logits, k, dim=-1)
                _, teacher_topk = torch.topk(teacher_logits, k, dim=-1)

                # Count student top-k tokens that also appear in the teacher's
                # top-k at the same position (set overlap, not exact rank order)
                matches = (
                    student_topk.unsqueeze(-1) == teacher_topk.unsqueeze(-2)
                ).any(dim=-1).float().sum().item()
                match_count += matches
                total += student_topk.numel()

        agreement = match_count / total if total > 0 else 0.0
        logger.info(f"Top-{k} agreement with teacher: {agreement*100:.1f}%")
        return agreement
    def generate_comparison(self, prompt: str = "The future of AI", max_length: int = 100):
        """Compare student vs teacher generation"""
        logger.info(f"\nPrompt: {prompt}\n")

        # Student generation
        student_text = self.student_inf.generate(prompt, max_length=max_length)
        logger.info(f"Student:\n{student_text}\n")

        # Teacher generation (do_sample=True so temperature/top_p actually apply)
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.teacher.generate(
                input_ids,
                max_new_tokens=max_length,
                do_sample=True,
                temperature=0.7,
                top_p=0.95,
                pad_token_id=self.tokenizer.eos_token_id,
            )
        teacher_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        logger.info(f"Teacher:\n{teacher_text}\n")
# ============================================================================
# MAIN
# ============================================================================
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", default="checkpoints/student_final.pt", help="Student checkpoint path")
    parser.add_argument("--teacher", default="Qwen/Qwen2.5-0.5B", help="Teacher model name")
    parser.add_argument("--prompt", default="The future of artificial intelligence", help="Generation prompt")
    parser.add_argument("--speed", action="store_true", help="Run speed test")
    parser.add_argument("--eval", action="store_true", help="Run evaluation")
    args = parser.parse_args()

    # Simple generation
    logger.info("Loading student model...")
    inference = StudentInference(args.checkpoint)
    logger.info(f"Generating from prompt: {args.prompt}\n")
    text = inference.generate(args.prompt, max_length=100)
    print(text)

    if args.speed:
        logger.info("\nBenchmarking speed...")
        inference.inference_speed_test()

    if args.eval:
        logger.info("\nRunning evaluation...")
        evaluator = StudentEvaluator(args.checkpoint, args.teacher)

        # Test data
        test_texts = [
            "Artificial intelligence is transforming industries.",
            "Machine learning models require careful tuning.",
            "Distillation compresses large models efficiently.",
        ]
        evaluator.compute_perplexity(test_texts)
        evaluator.compute_teacher_perplexity(test_texts)
        evaluator.top_k_agreement(test_texts, k=5)
        evaluator.generate_comparison(args.prompt, max_length=100)