# DiffuMoE — deepspeed_config_and_inference.py
# Uploaded via huggingface_hub by pragadeeshv23 (revision 05c5c96, verified)
"""
DeepSpeed Configuration & Inference Optimization
For RTX 2050 (4GB VRAM) with Arch Linux
"""
# ----------------------------------------------------------------------------
# DeepSpeed training configuration, serialized to deepspeed_config.json.
# Sized for a single RTX 2050 (4 GB VRAM): tiny micro-batches, fp16 with
# dynamic loss scaling, ZeRO-2 partitioning, and optimizer state offloaded
# to host RAM so only weights/gradients/activations live on the GPU.
# ----------------------------------------------------------------------------
deepspeed_config = {
    # global batch = 4 (micro batch per GPU) x 4 (grad accumulation) x 1 GPU
    "train_batch_size": 16,
    "train_micro_batch_size_per_gpu": 4,
    "gradient_accumulation_steps": 4,
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": 5e-4,
            "betas": [0.9, 0.999],
            "eps": 1e-8,
            "weight_decay": 0.01,
        },
    },
    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
            "warmup_min_lr": 0,
            "warmup_max_lr": 5e-4,  # ramp up to the optimizer lr
            "warmup_num_steps": 500,
            "total_num_steps": 10000,
        },
    },
    "fp16": {
        "enabled": True,
        "loss_scale": 0,  # 0 = dynamic loss scaling
        "loss_scale_window": 1000,
        "initial_scale_power": 15,
        "hysteresis": 2,
    },
    "zero_optimization": {
        # ZeRO-2 partitions optimizer states AND gradients across ranks;
        # the optimizer states are additionally offloaded to CPU below.
        "stage": 2,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": True,  # pinned host memory speeds CPU<->GPU copies
        },
        "allgather_partitions": True,
        "allgather_bucket_size": 5e7,
        "overlap_comm": True,  # overlap gradient reduction with backward pass
        "reduce_scatter": True,
        "reduce_bucket_size": 5e7,
        "contiguous_gradients": True,
    },
    "gradient_clipping": 1.0,
    "activation_checkpointing": {
        "partition_activations": True,
        "cpu_checkpointing": True,  # push checkpointed activations to host RAM
        "contiguous_memory_optimization": False,
        "number_checkpoints": 4,
    },
    "wall_clock_breakdown": True,  # per-step timing breakdown in the logs
}

import json

# Write the config next to the script; `deepspeed --deepspeed_config` reads it.
with open("deepspeed_config.json", "w", encoding="utf-8") as f:
    json.dump(deepspeed_config, f, indent=2)
# ============================================================================
# Optimized Inference for RTX 2050
# ============================================================================
import torch
import torch.nn as nn
from transformers import AutoTokenizer
import gc
from typing import Optional
class OptimizedStudent:
    """Inference-optimized wrapper around a distilled student checkpoint.

    Loads a checkpoint's state dict and exposes quantization helpers plus a
    sampling-based ``inference`` entry point sized for a 4 GB GPU.

    Fix vs. original: ``inference()`` read ``self.tokenizer`` which was never
    assigned, so every call raised AttributeError. The tokenizer can now be
    passed to the constructor (backward-compatible keyword) or assigned
    later; ``inference()`` raises a clear error when it is still missing.
    """

    def __init__(self, model_path: str, device: str = 'cuda', tokenizer=None):
        self.device = device
        self.model_path = model_path
        # NOTE(security): torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources.
        # This yields the raw state dict; load it into a StudentModel class
        # here before calling inference() (as in the original scaffold).
        self.model = torch.load(model_path, map_location=device)['model_state_dict']
        # Must be set (here or later) before inference() can run.
        self.tokenizer = tokenizer
        # Quantization state flags
        self.quantized = False
        self.use_flash_attn = torch.cuda.is_available()

    def quantize_int8(self):
        """INT8 quantization for 4 GB VRAM (best-effort, needs bitsandbytes)."""
        try:
            from bitsandbytes.nn import Linear8bitLt
            # Replace linear layers with INT8 versions
            self.quantized = True
            print("Model quantized to INT8")
        except ImportError:
            # Deliberate best-effort: missing dependency is not fatal.
            print("bitsandbytes not available, skipping INT8 quantization")

    def quantize_nf4(self):
        """Build an NF4 (4-bit) quantization config, or None if unavailable."""
        try:
            from transformers import BitsAndBytesConfig
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            )
            print("NF4 quantization config ready")
            return quantization_config
        except ImportError:
            print("bitsandbytes not available for NF4")
            return None

    def inference(
        self,
        prompt: str,
        max_length: int = 128,
        temperature: float = 0.7,
        top_p: float = 0.95,
    ) -> str:
        """Generate a completion for ``prompt`` with nucleus sampling.

        Raises:
            RuntimeError: if no tokenizer has been attached to this wrapper.
        """
        if self.tokenizer is None:
            raise RuntimeError(
                "OptimizedStudent has no tokenizer; pass one to the "
                "constructor or assign .tokenizer before calling inference()."
            )
        self.model.eval()
        with torch.no_grad():
            inputs = self.tokenizer(prompt, return_tensors='pt').to(self.device)
            # Generate with minimum memory overhead
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
                use_cache=True,  # KV cache for speed
            )
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Release Python + CUDA caches between generations (4 GB card).
        gc.collect()
        torch.cuda.empty_cache()
        return response
# ============================================================================
# Evaluation Metrics
# ============================================================================
import math
from datasets import load_dataset
class DistillationEvaluator:
    """Evaluation metrics comparing a distilled student to its teacher."""

    def __init__(self, teacher_model, student_model, tokenizer, device):
        self.teacher = teacher_model
        self.student = student_model
        self.tokenizer = tokenizer
        self.device = device

    def compute_perplexity(self, texts: list) -> float:
        """Token-weighted perplexity of the student over ``texts``.

        Fix vs. original: each text's *mean* loss was summed and then divided
        by the total token count, which is not perplexity. Here each mean
        loss is weighted by its token count, so exp() of the result is a true
        per-token perplexity. Empty/whitespace-only texts are skipped
        (wikitext contains many); returns inf when nothing was scored.
        """
        total_loss = 0.0
        num_tokens = 0
        self.student.eval()
        with torch.no_grad():
            for text in texts:
                if not text or not text.strip():
                    continue  # nothing to score
                inputs = self.tokenizer(text, return_tensors='pt').to(self.device)
                outputs = self.student(**inputs)
                loss = getattr(outputs, 'loss', None)
                if loss is None:
                    continue  # model did not return a loss for this input
                n = inputs['input_ids'].numel()
                # Weight the per-text mean loss by its token count so the
                # aggregate is a genuine per-token average.
                total_loss += loss.item() * n
                num_tokens += n
        return math.exp(total_loss / num_tokens) if num_tokens > 0 else float('inf')

    def compute_task_specific_metrics(self, dataset_name: str = "wikitext"):
        """Evaluate on specific tasks (QA, summarization, etc.)."""
        metrics = {}
        if dataset_name == "wikitext":
            # Fix: "wikitext-2" is not a valid HF config name; the published
            # configs are wikitext-2-raw-v1 / wikitext-2-v1 (and -103 variants).
            dataset = load_dataset("wikitext", "wikitext-2-raw-v1")
            perplexity = self.compute_perplexity(dataset['test']['text'][:100])
            metrics['wikitext_perplexity'] = perplexity
        return metrics

    def distillation_fidelity(self, texts: list, top_k: int = 5) -> float:
        """Mean position-wise top-k agreement between teacher and student.

        Note: indices are compared element-wise, so this measures exact rank
        agreement (rank 1 must match rank 1, etc.), not set overlap of the
        top-k candidates. Returns 0.0 for an empty ``texts`` list.
        """
        match_count = 0
        total = 0
        self.teacher.eval()
        self.student.eval()
        with torch.no_grad():
            for text in texts:
                inputs = self.tokenizer(text, return_tensors='pt').to(self.device)
                teacher_logits = self.teacher(**inputs).logits
                student_logits = self.student(**inputs)['logits']
                teacher_topk = torch.topk(teacher_logits, top_k, dim=-1).indices
                student_topk = torch.topk(student_logits, top_k, dim=-1).indices
                match = (teacher_topk == student_topk).float().mean().item()
                match_count += match
                total += 1
        return match_count / total if total > 0 else 0.0
# ============================================================================
# Training Command (with DeepSpeed)
# ============================================================================
"""
To train with DeepSpeed:
deepspeed distill_llm.py \
--deepspeed_config deepspeed_config.json \
--teacher_model mistralai/Mistral-7B-Instruct-v0.1 \
--student_hidden_dim 512 \
--student_num_layers 8 \
--batch_size 4 \
--gradient_accumulation_steps 4 \
--learning_rate 5e-4 \
--max_steps 10000 \
--temperature 4.0 \
--alpha 0.7 \
--beta 0.3
For RTX 2050 (4GB VRAM):
- Use ZeRO-2 with CPU offloading
- Batch size: 4 per GPU (with 4x accumulation)
- fp16 training
- Gradient checkpointing
- INT8 quantization after training (~4x compression vs fp32, ~2x vs fp16)
Estimated memory:
- Teacher: 14GB (load with device_map='auto' to split)
- Student: 1.2GB (fp16)
- Optimizer states: 2.4GB (offloaded to CPU)
- Gradients: 1.2GB
- Activations: 0.5GB
- Total on GPU: ~3.5GB ✓ (fits in 4GB)
"""