#!/usr/bin/env python3
# ==========================================================
# High-speed multi-GPU evaluation for GLM-4.5-Air-HS adapters
# Uses 4ƗH200 for maximum throughput.
# ==========================================================
import gc
import os, json, math, torch, time
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from torch.utils.data import DataLoader, Dataset
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, barrier, destroy_process_group
import torch.distributed as dist

# ---------------- CONFIG ----------------
BASE_MODEL = "/workspace/Avinash/models/GLM-4.5-Air"
CHECKPOINT_DIR = "checkpoints"
DATA_PATH = "/workspace/Avinash/dataset/all_data.jsonl"
OUTPUT_PATH = "eval_scores.json"
MAX_SAMPLES = 1000        # subset for eval speed
BATCH_SIZE = 2            # safe for 80GB H200
SEQ_LEN = 2048
DTYPE = torch.bfloat16    # use bf16 for H200
# ----------------------------------------


class CodeDataset(Dataset):
    """Wrap a list of {"text": ...} records as tokenized 1-D input_id tensors.

    Each item is truncated to ``max_len`` tokens; padding is deferred to the
    collate function so batches only pad to their own longest sequence.
    """

    def __init__(self, data, tokenizer, max_len=2048):
        self.samples = data
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        text = self.samples[idx]["text"]
        tokens = self.tokenizer(
            text, truncation=True, max_length=self.max_len, return_tensors="pt"
        )
        # [0] drops the batch dim added by return_tensors="pt".
        return tokens["input_ids"][0]


def collate_fn(batch, pad_token_id=0):
    """Pad variable-length sequences and build attention masks and labels.

    Args:
        batch: list of 1-D LongTensors of token ids (variable lengths).
        pad_token_id: id used to right-pad shorter sequences.

    Returns:
        dict with "input_ids", "attention_mask" (1 = real token, 0 = pad)
        and "labels" — a copy of input_ids with pad positions set to -100
        so they are ignored by the model's cross-entropy loss.
    """
    lengths = [seq.size(0) for seq in batch]
    max_len = max(lengths)

    input_ids = []
    attention_masks = []
    for seq, seq_len in zip(batch, lengths):
        if seq_len < max_len:
            padding = torch.full((max_len - seq_len,), pad_token_id, dtype=seq.dtype)
            padded_seq = torch.cat([seq, padding], dim=0)
        else:
            padded_seq = seq
        mask = torch.zeros(max_len, dtype=torch.long)
        mask[:seq_len] = 1
        input_ids.append(padded_seq)
        attention_masks.append(mask)

    input_ids = torch.stack(input_ids, dim=0)
    attention_mask = torch.stack(attention_masks, dim=0)
    labels = input_ids.clone()
    labels[attention_mask == 0] = -100  # mask pad tokens out of the loss
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }


def load_subset(path, limit=MAX_SAMPLES):
    """Read up to ``limit`` JSONL records from ``path``, skipping bad lines."""
    data = []
    with open(path, "r") as f:
        for i, line in enumerate(f):
            if i >= limit:
                break
            try:
                data.append(json.loads(line))
            except Exception:
                # Best-effort: silently skip malformed JSONL lines.
                continue
    return data


def evaluate_checkpoint(ckpt_path, subset, rank, local_rank, world_size):
    """Evaluate one checkpoint - only rank 0 loads the model with device_map='auto'.

    Rank 0 shards the base model across all visible GPUs, attaches the PEFT
    adapter from ``ckpt_path``, and computes mean loss / perplexity over
    ``subset``. Non-zero ranks return None immediately; the caller is
    responsible for keeping them alive (barrier) until rank 0 finishes.

    Returns:
        dict with "avg_loss" and "perplexity" on rank 0, else None.
    """
    if rank != 0:
        # Other ranks do no work for this checkpoint; caller synchronizes.
        return None

    print(f"\nšŸš€ Evaluating {ckpt_path} on {world_size} GPUs", flush=True)
    print(f"šŸ“„ Loading base model with device_map='auto'...", flush=True)

    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
    if tokenizer.pad_token_id is None:
        if tokenizer.eos_token is None:
            raise ValueError("Tokenizer needs a pad_token or eos_token for batching.")
        tokenizer.pad_token = tokenizer.eos_token

    # Load model with automatic device mapping across all GPUs.
    # Use the module-level DTYPE constant (was hardcoded torch.bfloat16).
    base = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        torch_dtype=DTYPE,
        device_map="auto",  # Automatically shard across all GPUs
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    )

    print(f"šŸ”§ Loading adapter from {ckpt_path}...", flush=True)
    model = PeftModel.from_pretrained(base, ckpt_path)
    model.eval()

    print(f"šŸ“Š Creating dataset and dataloader...", flush=True)
    dataset = CodeDataset(subset, tokenizer, max_len=SEQ_LEN)
    pad_token_id = tokenizer.pad_token_id

    # Bind the tokenizer's pad id into the collate function.
    def custom_collate(batch):
        return collate_fn(batch, pad_token_id=pad_token_id)

    loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        pin_memory=True,
        num_workers=0,
        collate_fn=custom_collate,
    )

    total_loss = 0
    total_count = 0
    print(f"⚔ Starting evaluation...", flush=True)
    with torch.no_grad():
        for batch in tqdm(loader, ncols=100, desc="Evaluating"):
            # Move batch to first device (where the sharded model starts).
            first_device = next(model.parameters()).device
            batch = {k: v.to(first_device) for k, v in batch.items()}
            outputs = model(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["labels"],
            )
            loss = outputs.loss.detach()
            # NOTE(review): HF loss is a per-token mean; weighting by batch
            # size approximates (not exactly equals) a token-weighted mean.
            batch_size = batch["input_ids"].size(0)
            total_loss += loss.item() * batch_size
            total_count += batch_size

    avg_loss = total_loss / max(total_count, 1)
    try:
        ppl = math.exp(avg_loss)
    except OverflowError:
        # A diverged checkpoint can push exp() past float range.
        ppl = float("inf")

    result = {
        "avg_loss": round(avg_loss, 4),
        "perplexity": round(ppl, 3),
    }
    print(
        f"āœ… {os.path.basename(ckpt_path)}: loss={avg_loss:.4f}, ppl={ppl:.2f}",
        flush=True,
    )

    # Clean up to free GPU/host memory before the next checkpoint.
    del loader
    del dataset
    del model
    del base
    del tokenizer
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.synchronize()

    return result


def main():
    """Entry point: set up torch.distributed and evaluate each checkpoint.

    Launched via torchrun; RANK / LOCAL_RANK / WORLD_SIZE come from the env.
    Only rank 0 performs the actual evaluation (sharding the model across
    all GPUs with device_map='auto'); other ranks synchronize at barriers.
    """
    rank = int(os.environ.get("RANK", 0))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))

    # Set device BEFORE initializing process group.
    torch.cuda.set_device(local_rank)
    if not dist.is_initialized():
        init_process_group(backend="nccl")

    if rank == 0:
        print("šŸ” Loading subset of dataset...", flush=True)
    subset = load_subset(DATA_PATH)
    if rank == 0:
        print(f"Loaded {len(subset)} samples.", flush=True)

    # Rank 0 discovers checkpoints; the list (possibly empty) is broadcast
    # to every rank so all ranks take the same exit path.  (The original
    # returned early on rank 0 only, deadlocking other ranks at the barrier.)
    if rank == 0:
        target_checkpoints = [
            "checkpoint-5000",
            "checkpoint-6000",
            "checkpoint-7000",
            "final-checkpoint",
        ]
        checkpoints = []
        for ckpt_name in target_checkpoints:
            ckpt_path = os.path.join(CHECKPOINT_DIR, ckpt_name)
            if os.path.isdir(ckpt_path):
                checkpoints.append(ckpt_path)
            else:
                print(f"āš ļø Warning: {ckpt_name} not found", flush=True)
        if not checkpoints:
            print(f"āš ļø No target checkpoints found in {CHECKPOINT_DIR}", flush=True)
        else:
            print(f"šŸ“ Found {len(checkpoints)} checkpoints to evaluate", flush=True)
            print(f"šŸ“‹ Checkpoints: {checkpoints}", flush=True)
    else:
        checkpoints = None

    if rank == 0:
        print("šŸ”„ Broadcasting checkpoint list to all ranks...", flush=True)
    dist.barrier()

    if world_size > 1:
        checkpoint_obj = [checkpoints] if rank == 0 else [None]
        dist.broadcast_object_list(checkpoint_obj, src=0)
        checkpoints = checkpoint_obj[0]
        if rank == 0:
            print(f"āœ… All ranks have checkpoint list", flush=True)

    # Every rank sees the same (possibly empty) list and exits together.
    if not checkpoints:
        destroy_process_group()
        return

    all_results = {}
    start_time = time.time()
    for ckpt in checkpoints:
        result = evaluate_checkpoint(ckpt, subset, rank, local_rank, world_size)
        # Only rank 0 saves results.
        if rank == 0 and result is not None:
            ckpt_name = os.path.basename(ckpt)
            all_results[ckpt_name] = result
            # Save interim results so a crash doesn't lose finished evals.
            with open(OUTPUT_PATH, "w") as f:
                json.dump(all_results, f, indent=2)
            print(f"šŸ’¾ Interim results saved to {OUTPUT_PATH}", flush=True)

    if rank == 0:
        total_mins = (time.time() - start_time) / 60
        print(f"\nšŸ All evaluations done in {total_mins:.1f} min.")
        print(f"šŸ“Š Final results saved at {OUTPUT_PATH}")
        print("\nšŸ“ˆ Results sorted by perplexity:")
        for ckpt_name, metrics in sorted(
            all_results.items(), key=lambda x: x[1]["perplexity"]
        ):
            print(f" {ckpt_name}: loss={metrics['avg_loss']}, ppl={metrics['perplexity']}")

    # Keep non-zero ranks alive until rank 0 finishes evaluating, then
    # tear down together (the original destroyed without synchronizing).
    dist.barrier()
    destroy_process_group()


if __name__ == "__main__":
    main()