""" Fixed Optimized Japanese Counseling Model Benchmark with proper DataParallel handling """ import torch import torch.nn as nn from torch.nn.parallel import DataParallel from torch.utils.data import Dataset, DataLoader from transformers import AutoModelForCausalLM, AutoTokenizer import numpy as np from typing import List, Dict, Tuple, Optional, Any import json from tqdm import tqdm import os import gc import warnings from datetime import datetime import pandas as pd from collections import defaultdict import MeCab from rouge_score import rouge_scorer from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction import re import wandb from concurrent.futures import ThreadPoolExecutor import time # Suppress warnings warnings.filterwarnings('ignore') os.environ['TOKENIZERS_PARALLELISM'] = 'false' # Suppress Pydantic warnings import logging logging.getLogger('pydantic').setLevel(logging.ERROR) class TestDataset(Dataset): """Custom dataset for efficient batch processing""" def __init__(self, data: List[Dict]): self.data = data def __len__(self): return len(self.data) def __getitem__(self, idx): return self.data[idx] def custom_collate_fn(batch): """Custom collate function to handle dictionary data properly""" return batch class OptimizedJapaneseBenchmark: """ Highly optimized benchmark suite with multi-GPU support and WandB logging """ def __init__(self, base_model_name: str = "LiquidAI/LFM2-1.2B", finetuned_model_path: str = "./merged_counselor_model", test_data_path: str = "./processed_data_score80/test.jsonl", batch_size: int = 16, # Reduced for stability num_workers: int = 0, use_wandb: bool = True): """ Initialize optimized benchmark with multi-GPU support """ self.base_model_name = base_model_name self.finetuned_model_path = finetuned_model_path self.test_data_path = test_data_path self.batch_size = batch_size self.num_workers = num_workers # Setup devices self.setup_devices() # Initialize WandB if use_wandb: self.init_wandb() else: self.wandb_enabled = False # Initialize tokenizers and scorers self.setup_tokenizers_and_scorers() # Results storage self.results = {} self.detailed_results = [] def setup_devices(self): """Setup multi-GPU configuration""" if torch.cuda.is_available(): self.num_gpus = torch.cuda.device_count() print(f"๐Ÿš€ Found {self.num_gpus} GPUs") self.device_ids = list(range(self.num_gpus)) self.device = torch.device("cuda:0") for i in range(self.num_gpus): print(f" GPU {i}: {torch.cuda.get_device_name(i)}") print(f" Memory: {torch.cuda.get_device_properties(i).total_memory / 1e9:.2f} GB") else: self.num_gpus = 0 self.device = torch.device("cpu") print("โš ๏ธ No GPU found, using CPU") def init_wandb(self): """Initialize WandB for experiment tracking""" try: run_name = f"benchmark-{datetime.now().strftime('%Y%m%d-%H%M%S')}" wandb.init( project="japanese-counseling-benchmark", name=run_name, config={ "base_model": self.base_model_name, "finetuned_model": self.finetuned_model_path, "batch_size": self.batch_size, "num_gpus": self.num_gpus, "timestamp": datetime.now().isoformat() }, tags=["benchmark", "japanese", "counseling", "multi-gpu"] ) self.wandb_enabled = True print(f"โœ… WandB initialized: {wandb.run.name}") print(f"๐Ÿ“Š View at: {wandb.run.get_url()}") except Exception as e: print(f"โš ๏ธ WandB initialization failed: {e}") self.wandb_enabled = False def setup_tokenizers_and_scorers(self): """Setup tokenizers and scoring functions""" # Initialize MeCab for Japanese tokenization try: self.mecab = MeCab.Tagger("-Owakati") print("โœ… MeCab initialized") except: print("โš ๏ธ MeCab not available, using character tokenization") self.mecab = None # Initialize ROUGE scorer self.rouge_scorer = rouge_scorer.RougeScorer( ['rouge1', 'rouge2', 'rougeL'], use_stemmer=False ) # BLEU smoothing self.smoothing = SmoothingFunction().method1 def load_test_data_fast(self, max_samples: Optional[int] = None) -> List[Dict]: """Fast loading of test data""" print(f"\n๐Ÿ“š Loading test data from {self.test_data_path}") test_data = [] if not os.path.exists(self.test_data_path): print("โš ๏ธ Test data not found, using synthetic data") return self.create_synthetic_test_data() try: with open(self.test_data_path, 'r', encoding='utf-8') as f: lines = f.readlines() if max_samples: lines = lines[:max_samples] for line in tqdm(lines, desc="Loading data"): try: data = json.loads(line) text = data.get('text', '') if "### Input:" in text and "### Response:" in text: input_part = text.split("### Input:")[1].split("### Response:")[0].strip() response_part = text.split("### Response:")[1].strip() test_data.append({ 'input': input_part, 'reference': response_part, 'score': data.get('score', 0), 'topic': data.get('topic', 'Unknown') }) except: continue except Exception as e: print(f"Error loading data: {e}") return self.create_synthetic_test_data() if not test_data: print("โš ๏ธ No valid data found, using synthetic data") return self.create_synthetic_test_data() print(f"โœ… Loaded {len(test_data)} test examples") if self.wandb_enabled: wandb.log({"test_data_size": len(test_data)}) return test_data def create_synthetic_test_data(self) -> List[Dict]: """Create synthetic test data""" return [ { 'input': f'ใ‚นใƒˆใƒฌใ‚นใ‚’ๆ„Ÿใ˜ใฆใ„ใพใ™ใ€‚', 'reference': f'ใŠๆฐ—ๆŒใกใ‚ใ‹ใ‚Šใพใ™ใ€‚ใฉใฎใ‚ˆใ†ใช็Šถๆณใงใ‚นใƒˆใƒฌใ‚นใ‚’ๆ„Ÿใ˜ใฆใ„ใพใ™ใ‹๏ผŸ', 'score': 75, 'topic': 'stress' } for i in range(10) ] def load_models_optimized(self): """Load models with optimization for multi-GPU""" print("\n๐Ÿค– Loading models with optimization...") # Load tokenizer print(" Loading tokenizer...") try: self.tokenizer = AutoTokenizer.from_pretrained( self.base_model_name, use_fast=True ) except: self.tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True) if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token # Load base model print(" Loading base model...") try: base_model = AutoModelForCausalLM.from_pretrained( self.base_model_name, torch_dtype=torch.float16, trust_remote_code=True, low_cpu_mem_usage=True ) except Exception as e: print(f" Error loading base model: {e}") print(" Using GPT2 as fallback...") base_model = AutoModelForCausalLM.from_pretrained( "gpt2", torch_dtype=torch.float16 ) # Load fine-tuned model print(" Loading fine-tuned model...") if os.path.exists(self.finetuned_model_path): try: finetuned_model = AutoModelForCausalLM.from_pretrained( self.finetuned_model_path, torch_dtype=torch.float16, trust_remote_code=True, low_cpu_mem_usage=True, local_files_only=True ) except Exception as e: print(f" Error loading fine-tuned model: {e}") finetuned_model = base_model else: print(" Fine-tuned model not found, using base model") finetuned_model = base_model # Move models to GPU base_model = base_model.to(self.device) finetuned_model = finetuned_model.to(self.device) # Setup for multi-GPU if available if self.num_gpus > 1: print(f" Setting up DataParallel for {self.num_gpus} GPUs...") self.base_model = DataParallel(base_model, device_ids=self.device_ids) self.finetuned_model = DataParallel(finetuned_model, device_ids=self.device_ids) else: self.base_model = base_model self.finetuned_model = finetuned_model self.base_model.eval() self.finetuned_model.eval() print("โœ… Models loaded and optimized!") if self.wandb_enabled: wandb.log({ "model_loaded": True, "num_gpus_used": self.num_gpus }) def generate_batch_responses(self, model, prompts: List[str], max_length: int = 150) -> List[str]: """Generate responses in batch for efficiency""" if len(prompts) == 0: return [] formatted_prompts = [ f"""### Instruction: ใ‚ใชใŸใฏๆ€ใ„ใ‚„ใ‚Šใฎใ‚ใ‚‹ๅฟƒ็†ใ‚ซใ‚ฆใƒณใ‚ปใƒฉใƒผใงใ™ใ€‚ ### Input: {prompt} ### Response: """ for prompt in prompts ] try: # Tokenize all prompts at once inputs = self.tokenizer( formatted_prompts, return_tensors="pt", truncation=True, max_length=512, padding=True, padding_side= 'left' ) inputs = {k: v.to(self.device) for k, v in inputs.items()} # Get the actual model from DataParallel if needed actual_model = model.module if isinstance(model, DataParallel) else model # Generate in batch with torch.no_grad(): with torch.cuda.amp.autocast(): outputs = actual_model.generate( **inputs, max_new_tokens=max_length, temperature=0.7, do_sample=True, top_p=0.9, num_beams=1, pad_token_id=self.tokenizer.pad_token_id, eos_token_id=self.tokenizer.eos_token_id ) # Decode all at once responses = self.tokenizer.batch_decode(outputs, skip_special_tokens=True) # Extract only generated parts extracted_responses = [] for i, response in enumerate(responses): if "### Response:" in response: extracted = response.split("### Response:")[-1].strip() else: extracted = response[len(formatted_prompts[i]):].strip() extracted_responses.append(extracted if extracted else "ๅฟœ็ญ”ใ‚’็”Ÿๆˆใงใใพใ›ใ‚“ใงใ—ใŸใ€‚") return extracted_responses except Exception as e: print(f"Error in batch generation: {e}") # Return default responses return ["็”ณใ—่จณใ”ใ–ใ„ใพใ›ใ‚“ใ€‚ๅฟœ็ญ”ใ‚’็”Ÿๆˆใงใใพใ›ใ‚“ใงใ—ใŸใ€‚"] * len(prompts) def tokenize_japanese(self, text: str) -> List[str]: """Tokenize Japanese text""" if not text: return ['empty'] if self.mecab: try: tokens = self.mecab.parse(text).strip().split() return tokens if tokens else list(text) except: pass # Fallback to character tokenization return list(text.replace(' ', '')) def calculate_metrics_batch(self, references: List[str], hypotheses: List[str]) -> Dict: """Calculate all metrics in batch""" metrics = defaultdict(list) for ref, hyp in zip(references, hypotheses): if not ref or not hyp: # Add default scores for empty strings for n in range(1, 5): metrics[f'BLEU-{n}'].append(0.0) metrics['ROUGE-1'].append(0.0) metrics['ROUGE-2'].append(0.0) metrics['ROUGE-L'].append(0.0) continue try: # Tokenize ref_tokens = self.tokenize_japanese(ref) hyp_tokens = self.tokenize_japanese(hyp) # BLEU scores for n in range(1, 5): weights = tuple([1/n] * n + [0] * (4-n)) try: score = sentence_bleu( [ref_tokens], hyp_tokens, weights=weights, smoothing_function=self.smoothing ) metrics[f'BLEU-{n}'].append(score) except: metrics[f'BLEU-{n}'].append(0.0) # ROUGE scores try: ref_spaced = ' '.join(ref_tokens) hyp_spaced = ' '.join(hyp_tokens) rouge_scores = self.rouge_scorer.score(ref_spaced, hyp_spaced) metrics['ROUGE-1'].append(rouge_scores['rouge1'].fmeasure) metrics['ROUGE-2'].append(rouge_scores['rouge2'].fmeasure) metrics['ROUGE-L'].append(rouge_scores['rougeL'].fmeasure) except: metrics['ROUGE-1'].append(0.0) metrics['ROUGE-2'].append(0.0) metrics['ROUGE-L'].append(0.0) except Exception as e: # Add zeros for failed calculations for n in range(1, 5): metrics[f'BLEU-{n}'].append(0.0) metrics['ROUGE-1'].append(0.0) metrics['ROUGE-2'].append(0.0) metrics['ROUGE-L'].append(0.0) return dict(metrics) def run_fast_benchmark(self, num_samples: Optional[int] = None): """Run optimized benchmark with batch processing""" print("\n" + "="*80) print("๐Ÿš€ Running Fast Multi-GPU Benchmark") print("="*80) start_time = time.time() # Load test data test_data = self.load_test_data_fast(max_samples=num_samples) if not test_data: raise ValueError("No test data available!") # Create DataLoader dataset = TestDataset(test_data) dataloader = DataLoader( dataset, batch_size=self.batch_size, shuffle=False, num_workers=0, collate_fn=custom_collate_fn, pin_memory=True if self.device.type == 'cuda' else False ) # Initialize metric collectors all_base_metrics = defaultdict(list) all_finetuned_metrics = defaultdict(list) print(f"\n๐Ÿ“Š Evaluating {len(test_data)} examples in {len(dataloader)} batches...") print(f" Batch size: {self.batch_size}") print(f" Using {self.num_gpus} GPU(s)") # Process batches successful_batches = 0 for batch_idx, batch in enumerate(tqdm(dataloader, desc="Processing batches")): try: # Extract batch data inputs = [item['input'] for item in batch] references = [item['reference'] for item in batch] # Generate responses in batch base_responses = self.generate_batch_responses(self.base_model, inputs) finetuned_responses = self.generate_batch_responses(self.finetuned_model, inputs) # Calculate metrics in batch base_metrics = self.calculate_metrics_batch(references, base_responses) finetuned_metrics = self.calculate_metrics_batch(references, finetuned_responses) # Aggregate metrics for key, values in base_metrics.items(): all_base_metrics[key].extend(values) for key, values in finetuned_metrics.items(): all_finetuned_metrics[key].extend(values) successful_batches += 1 # Log progress to WandB if self.wandb_enabled and batch_idx % 5 == 0: progress = (batch_idx + 1) / len(dataloader) * 100 # Calculate current averages current_bleu4_base = np.mean(all_base_metrics.get('BLEU-4', [0])) current_bleu4_finetuned = np.mean(all_finetuned_metrics.get('BLEU-4', [0])) current_rouge_l_base = np.mean(all_base_metrics.get('ROUGE-L', [0])) current_rouge_l_finetuned = np.mean(all_finetuned_metrics.get('ROUGE-L', [0])) wandb.log({ "progress": progress, "batches_processed": batch_idx + 1, "samples_processed": min((batch_idx + 1) * self.batch_size, len(test_data)), "current_bleu4_base": current_bleu4_base, "current_bleu4_finetuned": current_bleu4_finetuned, "current_rouge_l_base": current_rouge_l_base, "current_rouge_l_finetuned": current_rouge_l_finetuned }) # Store examples for analysis if batch_idx == 0 and len(inputs) > 0: for i in range(min(3, len(inputs))): self.detailed_results.append({ 'input': inputs[i], 'reference': references[i], 'base_response': base_responses[i] if i < len(base_responses) else "", 'finetuned_response': finetuned_responses[i] if i < len(finetuned_responses) else "" }) # Print sample print(f"\n๐Ÿ“ Sample Example:") print(f"Input: {inputs[0][:100]}...") print(f"Reference: {references[0][:100]}...") print(f"Base response: {base_responses[0][:100]}...") print(f"Fine-tuned response: {finetuned_responses[0][:100]}...") except Exception as e: print(f"Error processing batch {batch_idx}: {e}") continue print(f"\nโœ… Successfully processed {successful_batches}/{len(dataloader)} batches") # Calculate final statistics self.results = self.calculate_final_statistics(all_base_metrics, all_finetuned_metrics) # Calculate processing time total_time = time.time() - start_time samples_per_second = len(test_data) / total_time if total_time > 0 else 0 print(f"\nโฑ๏ธ Benchmark completed in {total_time:.2f} seconds") print(f" Processing speed: {samples_per_second:.2f} samples/second") # Log final results to WandB if self.wandb_enabled: wandb.log({ "total_time_seconds": total_time, "samples_per_second": samples_per_second, "total_samples": len(test_data), "successful_batches": successful_batches, **{f"final_{k}": v for k, v in self.results['summary'].items()} }) # Log detailed metrics for metric_name, improvements in self.results['improvements'].items(): wandb.log({f"improvement_{metric_name}": improvements}) # Create visualization if self.results['metrics']: self.create_wandb_visualizations() # Print results self.print_results() return self.results def create_wandb_visualizations(self): """Create WandB visualizations""" if not self.wandb_enabled or not self.results.get('metrics'): return try: # Create comparison table data = [] for metric in self.results['metrics']: data.append([ metric, self.results['metrics'][metric]['base']['mean'], self.results['metrics'][metric]['finetuned']['mean'], self.results['improvements'][metric] ]) table = wandb.Table( columns=["Metric", "Base", "Fine-tuned", "Improvement (%)"], data=data ) wandb.log({"results_comparison": table}) # Log bar chart of improvements wandb.log({ "improvements_chart": wandb.plot.bar( wandb.Table( data=[[m, self.results['improvements'][m]] for m in self.results['improvements']], columns=["Metric", "Improvement (%)"] ), "Metric", "Improvement (%)", title="Model Improvements" ) }) except Exception as e: print(f"Error creating visualizations: {e}") def calculate_final_statistics(self, base_metrics: Dict, finetuned_metrics: Dict) -> Dict: """Calculate final aggregate statistics""" results = { 'metrics': {}, 'improvements': {}, 'summary': {} } # Calculate statistics for each metric all_metric_names = set(base_metrics.keys()) | set(finetuned_metrics.keys()) for metric in all_metric_names: base_values = base_metrics.get(metric, [0]) finetuned_values = finetuned_metrics.get(metric, [0]) # Filter out any None values base_values = [v for v in base_values if v is not None] finetuned_values = [v for v in finetuned_values if v is not None] if not base_values: base_values = [0] if not finetuned_values: finetuned_values = [0] results['metrics'][metric] = { 'base': { 'mean': float(np.mean(base_values)), 'std': float(np.std(base_values)), 'min': float(np.min(base_values)), 'max': float(np.max(base_values)) }, 'finetuned': { 'mean': float(np.mean(finetuned_values)), 'std': float(np.std(finetuned_values)), 'min': float(np.min(finetuned_values)), 'max': float(np.max(finetuned_values)) } } # Calculate improvement base_mean = np.mean(base_values) finetuned_mean = np.mean(finetuned_values) if base_mean > 0: improvement = ((finetuned_mean - base_mean) / base_mean) * 100 else: improvement = 0 if finetuned_mean == 0 else 100 results['improvements'][metric] = improvement # Calculate summary statistics bleu_metrics = [m for m in results['metrics'] if 'BLEU' in m] rouge_metrics = [m for m in results['metrics'] if 'ROUGE' in m] results['summary'] = { 'bleu_avg_improvement': np.mean([results['improvements'][m] for m in bleu_metrics]) if bleu_metrics else 0, 'rouge_avg_improvement': np.mean([results['improvements'][m] for m in rouge_metrics]) if rouge_metrics else 0, 'overall_improvement': np.mean(list(results['improvements'].values())) if results['improvements'] else 0 } return results def print_results(self): """Print formatted results""" print("\n" + "="*80) print("๐Ÿ“Š BENCHMARK RESULTS") print("="*80) if not self.results or 'metrics' not in self.results: print("No results to display") return # BLEU scores print("\n๐Ÿ“˜ BLEU Scores:") print("-"*60) print(f"{'Metric':<15} {'Base':<15} {'Fine-tuned':<15} {'Improvement':<15}") print("-"*60) for metric in sorted([m for m in self.results['metrics'] if 'BLEU' in m]): base = self.results['metrics'][metric]['base']['mean'] finetuned = self.results['metrics'][metric]['finetuned']['mean'] improvement = self.results['improvements'][metric] print(f"{metric:<15} {base:.4f} {finetuned:.4f} {improvement:+.1f}%") # ROUGE scores print("\n๐Ÿ“• ROUGE Scores:") print("-"*60) for metric in sorted([m for m in self.results['metrics'] if 'ROUGE' in m]): base = self.results['metrics'][metric]['base']['mean'] finetuned = self.results['metrics'][metric]['finetuned']['mean'] improvement = self.results['improvements'][metric] print(f"{metric:<15} {base:.4f} {finetuned:.4f} {improvement:+.1f}%") # Summary print("\n" + "="*80) print("๐Ÿ“ˆ SUMMARY") print("="*80) print(f"BLEU Average Improvement: {self.results['summary']['bleu_avg_improvement']:+.1f}%") print(f"ROUGE Average Improvement: {self.results['summary']['rouge_avg_improvement']:+.1f}%") print(f"Overall Improvement: {self.results['summary']['overall_improvement']:+.1f}%") print("="*80) def save_results(self, output_dir: str = "./benchmark_results"): """Save results""" os.makedirs(output_dir, exist_ok=True) # Save results with open(os.path.join(output_dir, "results.json"), 'w', encoding='utf-8') as f: json.dump(self.results, f, ensure_ascii=False, indent=2, default=str) with open(os.path.join(output_dir, "examples.json"), 'w', encoding='utf-8') as f: json.dump(self.detailed_results, f, ensure_ascii=False, indent=2) # Save to WandB if self.wandb_enabled: try: artifact = wandb.Artifact( name=f"benchmark-results-{wandb.run.id}", type="benchmark_results", description="Japanese counseling model benchmark results" ) artifact.add_dir(output_dir) wandb.log_artifact(artifact) except Exception as e: print(f"Error saving to WandB: {e}") print(f"โœ… Results saved to {output_dir}/") def cleanup(self): """Clean up resources""" if self.wandb_enabled: wandb.finish() if torch.cuda.is_available(): torch.cuda.empty_cache() gc.collect() def main(): """Main execution""" import argparse parser = argparse.ArgumentParser(description='Optimized Japanese Counseling Benchmark') parser.add_argument('--base_model', type=str, default='LiquidAI/LFM2-1.2B') parser.add_argument('--finetuned_model', type=str, default='./merged_counselor_model') parser.add_argument('--test_data', type=str, default='./processed_data_score80/test.jsonl') parser.add_argument('--batch_size', type=int, default=16, help='Batch size for processing') parser.add_argument('--num_samples', type=int, default=None, help='Number of samples to evaluate') parser.add_argument('--output_dir', type=str, default='./benchmark_results_fast') parser.add_argument('--no_wandb', action='store_true', help='Disable WandB logging') args = parser.parse_args() try: # Initialize benchmark print("๐Ÿš€ Initializing Optimized Benchmark Suite") benchmark = OptimizedJapaneseBenchmark( base_model_name=args.base_model, finetuned_model_path=args.finetuned_model, test_data_path=args.test_data, batch_size=args.batch_size, use_wandb=not args.no_wandb ) # Load models benchmark.load_models_optimized() # Run benchmark results = benchmark.run_fast_benchmark(num_samples=args.num_samples) # Save results benchmark.save_results(args.output_dir) # Cleanup benchmark.cleanup() print("\nโœ… Benchmark completed successfully!") except Exception as e: print(f"\nโŒ Error: {e}") import traceback traceback.print_exc() if 'benchmark' in locals(): benchmark.cleanup() if __name__ == "__main__": main()