File size: 8,361 Bytes
05c5c96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
"""
DeepSpeed Configuration & Inference Optimization
For RTX 2050 (4GB VRAM) with Arch Linux
"""

# deepspeed_config.json
deepspeed_config = {
    "train_batch_size": 16,  # global batch size (4 per GPU × 4 accumulation)
    "train_micro_batch_size_per_gpu": 4,
    "gradient_accumulation_steps": 4,
    
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": 5e-4,
            "betas": [0.9, 0.999],
            "eps": 1e-8,
            "weight_decay": 0.01,
        }
    },
    
    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
            "warmup_min_lr": 0,
            "warmup_max_lr": 5e-4,
            "warmup_num_steps": 500,
            "total_num_steps": 10000,
        }
    },
    
    "fp16": {
        "enabled": True,
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 15,
        "hysteresis": 2,
    },
    
    "zero_optimization": {
        "stage": 2,  # ZeRO-2 (optimizer states + gradients on CPU)
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": True,
        },
        "allgather_partitions": True,
        "allgather_bucket_size": 5e7,
        "overlap_comm": True,
        "reduce_scatter": True,
        "reduce_bucket_size": 5e7,
        "contiguous_gradients": True,
    },
    
    "gradient_clipping": 1.0,
    
    "activation_checkpointing": {
        "partition_activations": True,
        "cpu_checkpointing": True,
        "contiguous_memory_optimization": False,
        "number_checkpoints": 4,
    },
    
    "wall_clock_breakdown": True,
}

import json
with open("deepspeed_config.json", "w") as f:
    json.dump(deepspeed_config, f, indent=2)


# ============================================================================
# Optimized Inference for RTX 2050
# ============================================================================

import torch
import torch.nn as nn
from transformers import AutoTokenizer
import gc
from typing import Optional


class OptimizedStudent:
    """Inference-optimized student model wrapper"""
    
    def __init__(self, model_path: str, device: str = 'cuda'):
        self.device = device
        self.model_path = model_path
        
        # Load with optimizations
        self.model = torch.load(model_path, map_location=device)['model_state_dict']
        # Note: You'd load into StudentModel class here
        
        # Quantization options
        self.quantized = False
        self.use_flash_attn = torch.cuda.is_available()
    
    def quantize_int8(self):
        """INT8 quantization for 4GB VRAM"""
        # Using bitsandbytes for INT8 quantization
        try:
            from bitsandbytes.nn import Linear8bitLt
            # Replace linear layers with INT8 versions
            self.quantized = True
            print("Model quantized to INT8")
        except ImportError:
            print("bitsandbytes not available, skipping INT8 quantization")
    
    def quantize_nf4(self):
        """NF4 quantization (4-bit, even more efficient)"""
        try:
            from transformers import BitsAndBytesConfig
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            )
            print("NF4 quantization config ready")
            return quantization_config
        except ImportError:
            print("bitsandbytes not available for NF4")
            return None
    
    def inference(
        self,
        prompt: str,
        max_length: int = 128,
        temperature: float = 0.7,
        top_p: float = 0.95,
    ) -> str:
        """Optimized inference with KV cache"""
        self.model.eval()
        
        with torch.no_grad():
            # Tokenize
            inputs = self.tokenizer(prompt, return_tensors='pt').to(self.device)
            
            # Generate with minimum memory overhead
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
                use_cache=True,  # KV cache for speed
            )
            
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        
        # Cleanup
        gc.collect()
        torch.cuda.empty_cache()
        
        return response


# ============================================================================
# Evaluation Metrics
# ============================================================================

import math
from datasets import load_dataset


class DistillationEvaluator:
    """Comprehensive evaluation metrics"""
    
    def __init__(self, teacher_model, student_model, tokenizer, device):
        self.teacher = teacher_model
        self.student = student_model
        self.tokenizer = tokenizer
        self.device = device
    
    def compute_perplexity(self, texts: list) -> float:
        """Perplexity on evaluation set"""
        total_loss = 0.0
        num_tokens = 0
        
        self.student.eval()
        with torch.no_grad():
            for text in texts:
                inputs = self.tokenizer(text, return_tensors='pt').to(self.device)
                outputs = self.student(**inputs)
                loss = outputs.loss if hasattr(outputs, 'loss') else 0.0
                
                if loss > 0:
                    total_loss += loss.item()
                    num_tokens += inputs['input_ids'].numel()
        
        perplexity = math.exp(total_loss / num_tokens) if num_tokens > 0 else float('inf')
        return perplexity
    
    def compute_task_specific_metrics(self, dataset_name: str = "wikitext"):
        """Evaluate on specific tasks (QA, summarization, etc.)"""
        metrics = {}
        
        if dataset_name == "wikitext":
            dataset = load_dataset("wikitext", "wikitext-2")
            perplexity = self.compute_perplexity(dataset['test']['text'][:100])
            metrics['wikitext_perplexity'] = perplexity
        
        return metrics
    
    def distillation_fidelity(self, texts: list, top_k: int = 5) -> float:
        """Measure how well student matches teacher predictions"""
        match_count = 0
        total = 0
        
        self.teacher.eval()
        self.student.eval()
        
        with torch.no_grad():
            for text in texts:
                inputs = self.tokenizer(text, return_tensors='pt').to(self.device)
                
                teacher_logits = self.teacher(**inputs).logits
                student_logits = self.student(**inputs)['logits']
                
                # Top-k agreement
                teacher_topk = torch.topk(teacher_logits, top_k, dim=-1).indices
                student_topk = torch.topk(student_logits, top_k, dim=-1).indices
                
                match = (teacher_topk == student_topk).float().mean().item()
                match_count += match
                total += 1
        
        fidelity = match_count / total if total > 0 else 0.0
        return fidelity


# ============================================================================
# Training Command (with DeepSpeed)
# ============================================================================

"""
To train with DeepSpeed:

    deepspeed distill_llm.py \
        --deepspeed_config deepspeed_config.json \
        --teacher_model mistralai/Mistral-7B-Instruct-v0.1 \
        --student_hidden_dim 512 \
        --student_num_layers 8 \
        --batch_size 4 \
        --gradient_accumulation_steps 4 \
        --learning_rate 5e-4 \
        --max_steps 10000 \
        --temperature 4.0 \
        --alpha 0.7 \
        --beta 0.3

For RTX 2050 (4GB VRAM):
- Use ZeRO-2 with CPU offloading
- Batch size: 4 per GPU (with 4x accumulation)
- fp16 training
- Gradient checkpointing
- INT8 quantization after training (8x compression)

Estimated memory:
- Teacher: 14GB (load with device_map='auto' to split)
- Student: 1.2GB (fp16)
- Optimizer states: 2.4GB (offloaded to CPU)
- Gradients: 1.2GB
- Activations: 0.5GB
- Total on GPU: ~3.5GB ✓ (fits in 4GB)
"""