"""Merge a LoRA fine-tune of Qwen2.5-Math-1.5B into its base model, then evaluate
the merged model on MATH-500 with lighteval's vLLM backend."""

import os

# ===== SET CUDA_VISIBLE_DEVICES FIRST (BEFORE ANY TORCH/CUDA INITIALISATION) =====
# The CUDA runtime reads this variable only once, when CUDA is first initialised
# in the process. The module-level Accelerator() below may trigger that
# initialisation, so assigning the variable inside main() can be too late: torch
# would still count every GPU (feeding a wrong tensor_parallel_size to vLLM)
# while vLLM worker processes only see the GPU listed here.
# The imports below therefore intentionally come after this assignment.
# Option 1: Use only GPU 2
VISIBLE_GPUS = "2"
os.environ["CUDA_VISIBLE_DEVICES"] = VISIBLE_GPUS

import gc

import lighteval
import torch
from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.models.vllm.vllm_model import VLLMModelConfig
from lighteval.pipeline import ParallelismManager, Pipeline, PipelineParameters
from lighteval.utils.imports import is_package_available
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

if is_package_available("accelerate"):
    from datetime import timedelta

    from accelerate import Accelerator, InitProcessGroupKwargs

    accelerator = Accelerator(
        kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(seconds=3000))]
    )
else:
    accelerator = None

# ----- Model locations -----
_SFT_SAVE_DIR = "/public/home/lshi/yoAI/projects/Online_CL/train/model_sft_save"
BASE_MODEL_NAME = "Qwen/Qwen2.5-Math-1.5B"
LORA_PATH = os.path.join(
    _SFT_SAVE_DIR, "Qwen2.5-Math-1.5B-DeepScaleR-Lora", "checkpoint-2834"
)
MERGED_PATH = os.path.join(_SFT_SAVE_DIR, "Qwen2.5-Math-1.5B-DeepScaleR-Merged")

# ----- Evaluation settings -----
EVAL_TASK = "lighteval|math_500|0"  # alternatives: aime24, aime24_gpassk
MAX_SAMPLES = 500
MAX_MODEL_LENGTH = 4096
GENERATION_SIZE = 2048

# Files that must all exist for a merge to count as finished. The tokenizer is
# saved last (after the weights), so its config marks a completed merge;
# config.json alone is written by save_pretrained *before* the weight shards.
_MERGE_DONE_MARKERS = ("config.json", "tokenizer_config.json")


def merge_lora_if_needed(
    base_model_name=BASE_MODEL_NAME,
    lora_path=LORA_PATH,
    merged_path=MERGED_PATH,
):
    """Merge the LoRA adapter into the base model unless a merged copy already exists.

    Args:
        base_model_name: Hub id or local path of the base model; also the
            source of the tokenizer saved alongside the merged weights.
        lora_path: Directory of the PEFT LoRA adapter checkpoint.
        merged_path: Output directory for the merged model and tokenizer.

    Returns:
        ``merged_path``, which holds a plain (non-PEFT) model checkpoint that
        vLLM can load directly.
    """
    if all(
        os.path.exists(os.path.join(merged_path, marker))
        for marker in _MERGE_DONE_MARKERS
    ):
        print(f"Merged model already exists at {merged_path}")
        return merged_path

    print("=" * 100)
    print("Merged model not found. Starting merge process...")
    print("=" * 100)

    print("\n[1/5] Loading base model...")
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map="auto",
    )

    print("\n[2/5] Loading LoRA adapter...")
    model = PeftModel.from_pretrained(base_model, lora_path)

    print("\n[3/5] Merging LoRA weights with base model...")
    merged_model = model.merge_and_unload()

    print(f"\n[4/5] Saving merged model to {merged_path}...")
    os.makedirs(merged_path, exist_ok=True)
    merged_model.save_pretrained(merged_path, safe_serialization=True)

    # Saved last on purpose: tokenizer_config.json is the completion marker
    # checked at the top of this function.
    print("\n[5/5] Saving tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
    tokenizer.save_pretrained(merged_path)

    print("\n" + "=" * 100)
    print("✓ Merge completed successfully!")
    print(f"✓ Merged model saved to: {merged_path}")
    print("=" * 100 + "\n")

    # Free GPU memory before vLLM allocates its own. gc.collect() clears any
    # reference cycles still pinning CUDA tensors, so empty_cache() can
    # actually return their blocks.
    del base_model, model, merged_model, tokenizer
    gc.collect()
    torch.cuda.empty_cache()

    return merged_path


def main():
    """Ensure the merged model exists, then evaluate it on MATH-500 via lighteval + vLLM.

    GPU selection is done at import time (VISIBLE_GPUS above), before any CUDA
    initialisation, so the GPU count seen here matches what vLLM workers see.

    Raises:
        RuntimeError: if no CUDA GPU is visible; vLLM cannot run with a
            tensor-parallel size of 0.
    """
    print("Checking for merged model...")
    merged_model_path = merge_lora_if_needed()

    # ===== DETECT NUMBER OF GPUs AFTER SETTING CUDA_VISIBLE_DEVICES =====
    num_gpus = torch.cuda.device_count()
    print(f"\n{'='*100}")
    print(f"Detected {num_gpus} GPU(s) (after CUDA_VISIBLE_DEVICES filtering)")
    for i in range(num_gpus):
        print(f"  GPU {i}: {torch.cuda.get_device_name(i)}")
    print(f"{'='*100}\n")

    if num_gpus == 0:
        raise RuntimeError(
            "No CUDA GPU is visible "
            f"(CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')!r}); "
            "vLLM evaluation needs at least one GPU."
        )

    print("Setting up evaluation pipeline...")
    evaluation_tracker = EvaluationTracker(
        output_dir="./results",
        save_details=True,
        push_to_hub=False,
        # hub_results_org="your_username",
    )

    pipeline_params = PipelineParameters(
        launcher_type=ParallelismManager.ACCELERATE,
        custom_tasks_directory=None,
        max_samples=MAX_SAMPLES,
    )

    model_config = VLLMModelConfig(
        model_name=merged_model_path,
        dtype="bfloat16",
        max_model_length=MAX_MODEL_LENGTH,
        trust_remote_code=True,
        # Matches the GPUs left visible by CUDA_VISIBLE_DEVICES.
        tensor_parallel_size=num_gpus,
    )

    task = EVAL_TASK
    print(f"Using {num_gpus} GPU(s) with tensor parallelism")
    print(f"Task: {task}\n")

    print("Creating pipeline...")
    pipeline = Pipeline(
        tasks=task,
        pipeline_parameters=pipeline_params,
        evaluation_tracker=evaluation_tracker,
        model_config=model_config,
    )

    # Fix generation_size.
    # NOTE(review): tasks_dict and _docs are lighteval internals (_docs is
    # private); this override may silently stop working after a lighteval
    # upgrade — confirm against the installed version.
    print("Configuring generation parameters...")
    for task_obj in pipeline.tasks_dict.values():
        for doc in task_obj._docs:
            doc.generation_size = GENERATION_SIZE

    print("\nStarting evaluation...")
    print("=" * 100)
    pipeline.evaluate()

    print("\nSaving results...")
    pipeline.save_and_push_results()

    print("\nShowing results...")
    pipeline.show_results()

    print("\n" + "=" * 100)
    print("✓ Evaluation completed!")
    print("=" * 100)


if __name__ == "__main__":
    main()