File size: 9,276 Bytes
82ea551
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
#!/usr/bin/env python3
# ==========================================================
# High-speed multi-GPU evaluation for GLM-4.5-Air-HS adapters
# Uses 4Γ—H200 for maximum throughput.
# ==========================================================

import os, json, math, torch, time
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from torch.utils.data import DataLoader, Dataset
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, barrier, destroy_process_group
import torch.distributed as dist


# ---------------- CONFIG ----------------
BASE_MODEL = "/workspace/Avinash/models/GLM-4.5-Air"   # local HF model directory
CHECKPOINT_DIR = "checkpoints"                          # parent dir of adapter checkpoints
DATA_PATH = "/workspace/Avinash/dataset/all_data.jsonl"  # one JSON record per line
OUTPUT_PATH = "eval_scores.json"                        # metrics file, rewritten after each checkpoint
MAX_SAMPLES = 1000       # subset for eval speed
BATCH_SIZE = 2           # safe for 80GB H200
SEQ_LEN = 2048           # tokenizer truncation length (see CodeDataset)
DTYPE = torch.bfloat16   # use bf16 for H200
# NOTE(review): DTYPE is currently unused -- evaluate_checkpoint hard-codes
# torch.bfloat16 in from_pretrained; keep the two in sync if either changes.
# ----------------------------------------


class CodeDataset(Dataset):
    """Map-style dataset that tokenizes the ``text`` field of JSON records.

    Each item is a 1-D tensor of token ids truncated to ``max_len``;
    padding is deliberately left to the collate function so batches only
    pad to their own longest sequence.
    """

    def __init__(self, data, tokenizer, max_len=2048):
        self.samples = data
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        record = self.samples[idx]
        encoded = self.tokenizer(
            record["text"],
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt",
        )
        # Strip the batch dimension added by return_tensors="pt".
        return encoded["input_ids"][0]


def collate_fn(batch, pad_token_id=0):
    """Right-pad a list of 1-D token-id tensors into a model-ready batch.

    Returns a dict with stacked ``input_ids``, a 0/1 ``attention_mask``
    (1 = real token), and ``labels`` equal to ``input_ids`` except that
    padded positions are set to -100 so the loss ignores them.
    """
    target_len = max(seq.size(0) for seq in batch)

    id_rows, mask_rows = [], []
    for seq in batch:
        n = seq.size(0)
        # Start from a fully padded row and overwrite the real tokens.
        row = torch.full((target_len,), pad_token_id, dtype=seq.dtype)
        row[:n] = seq
        mask = torch.zeros(target_len, dtype=torch.long)
        mask[:n] = 1
        id_rows.append(row)
        mask_rows.append(mask)

    input_ids = torch.stack(id_rows, dim=0)
    attention_mask = torch.stack(mask_rows, dim=0)
    labels = input_ids.clone()
    labels[attention_mask == 0] = -100

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }


def load_subset(path, limit=MAX_SAMPLES):
    """Load up to *limit* JSON records (one per line) from *path*.

    Malformed or blank lines are skipped and do NOT count toward the
    limit, so callers receive the full requested sample count whenever
    the file has enough valid records.  (Previously the limit counted
    raw lines, silently shrinking the eval subset on dirty files.)
    """
    data = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if len(data) >= limit:
                break
            line = line.strip()
            if not line:
                continue
            try:
                data.append(json.loads(line))
            except json.JSONDecodeError:
                # Tolerate the occasional corrupt line in a large dump.
                continue
    return data


def evaluate_checkpoint(ckpt_path, subset, rank, local_rank, world_size):
    """Evaluate one adapter checkpoint; returns metrics on rank 0, None elsewhere.

    Rank 0 loads the base model sharded across every visible GPU with
    ``device_map="auto"``, attaches the PEFT adapter from *ckpt_path*,
    and computes mean per-sample loss and perplexity over *subset*.
    All ranks (including rank 0) meet at a barrier before returning so
    idle ranks cannot race ahead of rank 0.
    """
    result = None

    if rank == 0:
        print(f"\nπŸš€ Evaluating {ckpt_path} on {world_size} GPUs", flush=True)
        print(f"πŸ“₯ Loading base model with device_map='auto'...", flush=True)

        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
        if tokenizer.pad_token_id is None:
            if tokenizer.eos_token is None:
                raise ValueError("Tokenizer needs a pad_token or eos_token for batching.")
            tokenizer.pad_token = tokenizer.eos_token

        # Load model with automatic device mapping across all GPUs
        base = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            torch_dtype=torch.bfloat16,
            device_map="auto",  # Automatically shard across all GPUs
            low_cpu_mem_usage=True,
            trust_remote_code=True
        )

        print(f"πŸ”§ Loading adapter from {ckpt_path}...", flush=True)
        model = PeftModel.from_pretrained(base, ckpt_path)
        model.eval()

        print(f"πŸ“Š Creating dataset and dataloader...", flush=True)
        dataset = CodeDataset(subset, tokenizer, max_len=SEQ_LEN)

        pad_token_id = tokenizer.pad_token_id

        # Bind the tokenizer's pad id into the collate function.
        def custom_collate(batch):
            return collate_fn(batch, pad_token_id=pad_token_id)

        loader = DataLoader(
            dataset,
            batch_size=BATCH_SIZE,
            shuffle=False,
            pin_memory=True,
            num_workers=0,
            collate_fn=custom_collate
        )

        total_loss = 0
        total_count = 0

        print(f"⚑ Starting evaluation...", flush=True)

        with torch.no_grad():
            for batch in tqdm(loader, ncols=100, desc="Evaluating"):
                # Feed inputs to the device holding the model's first shard;
                # accelerate moves activations between shards internally.
                first_device = next(model.parameters()).device
                batch = {k: v.to(first_device) for k, v in batch.items()}
                outputs = model(
                    input_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    labels=batch["labels"]
                )
                loss = outputs.loss.detach()
                batch_size = batch["input_ids"].size(0)
                # Weight by batch size so the final mean is per-sample.
                total_loss += loss.item() * batch_size
                total_count += batch_size

        avg_loss = total_loss / max(total_count, 1)
        ppl = math.exp(avg_loss)

        result = {
            "avg_loss": round(avg_loss, 4),
            "perplexity": round(ppl, 3)
        }

        print(f"βœ… {os.path.basename(ckpt_path)}: loss={avg_loss:.4f}, ppl={ppl:.2f}", flush=True)

        # Release everything before the next (very large) checkpoint load.
        del loader
        del dataset
        del model
        del base
        del tokenizer

        import gc
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    # FIX: synchronize all ranks here.  Previously non-zero ranks returned
    # immediately, looped through every checkpoint in an instant, and could
    # call destroy_process_group() while rank 0 was still mid-evaluation.
    if world_size > 1 and dist.is_initialized():
        dist.barrier()

    return result


def main():
    """Entry point: evaluate a fixed list of adapter checkpoints.

    Intended to be launched with ``torchrun``; RANK / LOCAL_RANK /
    WORLD_SIZE come from the environment it sets.
    """
    rank = int(os.environ.get("RANK", 0))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))

    # Bind this process to its GPU before NCCL initializes.
    torch.cuda.set_device(local_rank)

    if not dist.is_initialized():
        init_process_group(backend="nccl")

    if rank == 0:
        print("πŸ” Loading subset of dataset...", flush=True)

    subset = load_subset(DATA_PATH)

    if rank == 0:
        print(f"Loaded {len(subset)} samples.", flush=True)

    # Rank 0 discovers checkpoints.  FIX: when nothing is found we keep an
    # EMPTY list instead of returning early -- previously rank 0 destroyed
    # the process group and returned before the barrier/broadcast, leaving
    # every other rank blocked at dist.barrier() forever.
    checkpoints = []
    if rank == 0:
        target_checkpoints = ["checkpoint-5000", "checkpoint-6000", "checkpoint-7000", "final-checkpoint"]
        for ckpt_name in target_checkpoints:
            ckpt_path = os.path.join(CHECKPOINT_DIR, ckpt_name)
            if os.path.isdir(ckpt_path):
                checkpoints.append(ckpt_path)
            else:
                print(f"⚠️  Warning: {ckpt_name} not found", flush=True)

        if checkpoints:
            print(f"πŸ“ Found {len(checkpoints)} checkpoints to evaluate", flush=True)
            print(f"πŸ“‹ Checkpoints: {checkpoints}", flush=True)
        else:
            print(f"⚠️  No target checkpoints found in {CHECKPOINT_DIR}", flush=True)

    # Share the (possibly empty) checkpoint list with every rank so all
    # ranks take the same code path from here on.
    if world_size > 1:
        if rank == 0:
            print("πŸ”„ Broadcasting checkpoint list to all ranks...", flush=True)
        dist.barrier()
        checkpoint_obj = [checkpoints if rank == 0 else None]
        dist.broadcast_object_list(checkpoint_obj, src=0)
        checkpoints = checkpoint_obj[0]

    if not checkpoints:
        # Every rank reaches this together; clean shutdown, no hang.
        destroy_process_group()
        return

    if rank == 0:
        print(f"βœ… All ranks have checkpoint list", flush=True)

    all_results = {}
    start_time = time.time()

    for ckpt in checkpoints:
        result = evaluate_checkpoint(ckpt, subset, rank, local_rank, world_size)

        # Only rank 0 receives metrics and persists them.
        if rank == 0 and result is not None:
            ckpt_name = os.path.basename(ckpt)
            all_results[ckpt_name] = result

            # Rewrite the results file after every checkpoint so a crash
            # mid-run loses nothing already computed.
            with open(OUTPUT_PATH, "w") as f:
                json.dump(all_results, f, indent=2)
            print(f"πŸ’Ύ Interim results saved to {OUTPUT_PATH}", flush=True)

    if rank == 0:
        total_mins = (time.time() - start_time) / 60
        print(f"\n🏁 All evaluations done in {total_mins:.1f} min.")
        print(f"πŸ“Š Final results saved at {OUTPUT_PATH}")
        print("\nπŸ“ˆ Results sorted by perplexity:")
        for ckpt_name, metrics in sorted(all_results.items(), key=lambda x: x[1]["perplexity"]):
            print(f"  {ckpt_name}: loss={metrics['avg_loss']}, ppl={metrics['perplexity']}")

    # Clean up
    destroy_process_group()


# Launch under torchrun (e.g. `torchrun --nproc_per_node=4 this_file.py`)
# so RANK / LOCAL_RANK / WORLD_SIZE are set for main().
if __name__ == "__main__":
    main()