| """ | |
| Common evaluation utilities for benchmark tasks. | |
| This module provides shared functions for model evaluation across different benchmarks | |
| like MMLU-Redux and MMMLU. | |
| """ | |
| import re | |
| import os | |
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| from typing import Dict, Any, List, Tuple, Optional | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from rosetta.model.projector import load_projector | |
| from rosetta.model.wrapper import RosettaModel | |
| from rosetta.model.oracle import OracleRosettaModel | |
def build_prompt(dataset: str, locale: str, question: str, choices: str,
                 use_cot: bool, use_template: bool = True) -> str:
    """
    Build a localized prompt for a given dataset and locale.

    Currently supports:
    - dataset: "mmmlu"
    - locale: "SW_KE" (Swahili). Other locales fall back to English.

    Args:
        dataset: Dataset identifier (e.g., "mmmlu")
        locale: Locale/subject code (e.g., "SW_KE")
        question: Question text
        choices: Formatted choices string
        use_cot: Whether to include a chain-of-thought instruction
        use_template: If False, return only the question and choices without
            the instruction template

    Returns:
        Localized prompt string
    """
    # If no template is requested, return the bare question and choices.
    if not use_template:
        return question + "\n\nChoices:\n" + choices

    # Unified default English templates (shared by MMLU and MMMLU)
    if not use_cot:
        template = """Accurately answer the following question:
{{question}}
Choices:
{{choices}}
Instructions:
- Carefully read the question and all options.
- Select the single most correct answer.
- Respond ONLY in the following format: "The correct answer is A/B/C/D".
- Do not include any explanations, additional text, or punctuation besides the answer.
The correct answer is"""
    else:
        template = """Accurately answer the following question:
{{question}}
Choices:
{{choices}}
Instructions:
- Carefully read the question and all options.
- Let's think step by step and explain your reasoning briefly.
- Then give the final answer starting with "The correct answer is"."""

    prompt = template.replace("{{question}}", question)
    prompt = prompt.replace("{{choices}}", choices)
    return prompt
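
# Usage sketch (the question and choices below are hypothetical placeholders):
#
#     choices = "A. 3\nB. 4\nC. 5\nD. 6"
#     prompt = build_prompt("mmmlu", "EN_US", "What is 2 + 2?", choices, use_cot=False)
#     # The non-CoT prompt ends with "The correct answer is", so the model is
#     # expected to continue with a single option letter.
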
def parse_answer(answer_str: str) -> List[str]:
    """
    Parse answer string to extract valid answer options.

    Converts digits 0/1/2/3 to letters A/B/C/D.

    Args:
        answer_str: String containing answer digits

    Returns:
        Sorted list of parsed answer letters (duplicates removed)
    """
    if not isinstance(answer_str, str):
        return []
    valid_digits = [c for c in answer_str if c in {'0', '1', '2', '3'}]
    return sorted({
        chr(65 + int(d))  # 0->A, 1->B, 2->C, 3->D
        for d in valid_digits
    })
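
# Examples (illustrative):
#     parse_answer("2")     -> ['C']
#     parse_answer("0, 3")  -> ['A', 'D']
#     parse_answer("31")    -> ['B', 'D']   # digits are deduplicated and sorted
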
def extract_answer_from_content(text: str) -> Optional[str]:
    """
    Extract an answer letter from model output with robust multi-pattern matching.

    Supports multiple languages and response formats.

    Args:
        text: Model output text

    Returns:
        Extracted answer letter or None
    """
    text = text.strip()
    if not text:
        return None

    # Answer patterns for different languages and formats.
    # Matching below uses re.IGNORECASE, so each prefix is listed once.
    answer_patterns = [
        # English patterns
        r'Answer:\s*(.*)',
        r'Your answer:\s*(.*)',
        r'The answer is\s*(.*)',
        r'Correct answer is\s*(.*)',
        r'Correct answer is:\s*(.*)',
        r'Correct answer:\s*(.*)',
        # Swahili patterns
        r'Jibu lako:\s*(.*)',
        r'Jibu:\s*(.*)',
        r'Jibu sahihi:\s*(.*)',
        # Other common patterns
        r'Response:\s*(.*)',
        r'Choice:\s*(.*)',
        r'Option:\s*(.*)',
    ]

    # 1. Try to match any of the answer patterns
    for pattern in answer_patterns:
        match = re.search(pattern, text, re.IGNORECASE)
        if match:
            answer_part = match.group(1).strip()
            # Search for the first A-D letter in the matched part
            for char in answer_part:
                if char in {'A', 'B', 'C', 'D'}:
                    return char

    # Heuristics shared by steps 2 and 3: letters that appear alongside mathematical
    # notation and without any answer-like phrasing are likely not answers.
    math_indicators = ['+', '-', '*', '/', '=', '^', 'x^', 'y^', 'z^', 'mod', 'sqrt', 'sin', 'cos', 'tan']
    answer_phrases = ['jibu', 'answer', 'choice', 'option', 'response', 'correct', 'sahihi']
    has_math = any(indicator in text for indicator in math_indicators)
    has_answer_indicators = any(phrase in text.lower() for phrase in answer_phrases)

    # 2. Look for standalone A-D letters that are likely answers.
    # Prioritize letters at the end of the text or with clear answer-like context.
    standalone_patterns = [
        r'\b([A-D])(?:\s*[.,!?:)]?\s*$)',  # A-D at end of text with optional punctuation
        r'\b([A-D])(?:\s*[.,!?:)]\s)',     # A-D followed by punctuation and space
        r'(?:^|\s)([A-D])(?:\s*$)',        # A-D at start or with word boundary at end
    ]
    for pattern in standalone_patterns:
        matches = re.findall(pattern, text, re.IGNORECASE)
        if matches:
            # If the text has math indicators but no answer indicators, the match is
            # likely mathematical notation rather than an answer.
            if has_math and not has_answer_indicators:
                continue  # Skip this match, try next pattern
            return matches[-1].upper()

    # 3. Fallback: find all A-D letters, but stay selective
    all_letters = re.findall(r'\b([A-D])\b', text, re.IGNORECASE)
    if all_letters:
        if has_math and not has_answer_indicators:
            return None
        # Otherwise, return the last letter found
        return all_letters[-1].upper()

    # 4. Last resort: search backwards for any uppercase A-D character
    for char in reversed(text):
        if char in {'A', 'B', 'C', 'D'}:
            return char

    return None
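
# Examples (illustrative model outputs):
#     extract_answer_from_content("The correct answer is B")  -> 'B'
#     extract_answer_from_content("Jibu sahihi: D")            -> 'D'
#     extract_answer_from_content("A + B = C")                 -> None  (letters read as math, not as an answer)
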
def apply_generation_config(model: Any, generation_config: Optional[Dict[str, Any]] = None) -> None:
    """
    Apply a generation configuration to a model and handle sampling parameters.

    This function applies the provided generation config to the model and clears
    the sampling parameters (temperature, top_p, top_k, min_p, repetition_penalty)
    when do_sample=False to avoid warnings from the transformers library. If no
    config is provided, it defaults to greedy decoding with cleared sampling
    parameters.

    Args:
        model: Model object with a generation_config attribute
        generation_config: Optional generation configuration dictionary.
            If None, defaults to greedy decoding (do_sample=False).
    """
    if not hasattr(model, 'generation_config'):
        return

    # If no config is provided, default to greedy decoding
    if not generation_config:
        generation_config = {'do_sample': False}

    # Apply all configuration parameters
    for key, value in generation_config.items():
        setattr(model.generation_config, key, value)

    # Clear sampling parameters if do_sample=False to avoid warnings.
    # They are set to None instead of deleted, since some model code may
    # access these attributes unconditionally.
    if not generation_config.get('do_sample', True):
        sampling_params = ['temperature', 'top_p', 'top_k', 'min_p', 'repetition_penalty']
        for param in sampling_params:
            try:
                setattr(model.generation_config, param, None)
            except Exception:
                # If the backend does not allow setting the attribute, ignore silently
                pass
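
# Usage sketch (hypothetical config values):
#
#     apply_generation_config(model, {'do_sample': False, 'max_new_tokens': 32})
#     # Greedy decoding is enabled and temperature/top_p/top_k/min_p/repetition_penalty
#     # are cleared to None so transformers does not warn about unused sampling settings.
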
def set_default_chat_template(tokenizer, model_name: str):
    """
    Set a default chat template for models that do not ship one.

    Args:
        tokenizer: Tokenizer object
        model_name: Name of the model
    """
    if tokenizer.chat_template is None:
        if "ulizallama3" in model_name.lower():
            # Llama-3-style header/eot template for UlizaLlama3
            tokenizer.chat_template = (
                "{%- for message in messages %}"
                "{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n' }}"
                "{{- message['content'] }}"
                "{{- '<|eot_id|>' }}"
                "{%- endfor %}"
                "{%- if add_generation_prompt %}"
                "{{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}"
                "{%- endif %}"
            )
        else:
            print(f"Model {model_name} has no chat template, setting default template...")
            # Generic "### Human / ### Assistant" template
            default_template = (
                "{% for message in messages %}"
                "{% if message['role'] == 'user' %}### Human: {{ message['content'] }}"
                "{% elif message['role'] == 'assistant' %}### Assistant: {{ message['content'] }}"
                "{% endif %}{% if not loop.last %}\n{% endif %}"
                "{% endfor %}"
                "{% if add_generation_prompt %}\n### Assistant:{% endif %}"
            )
            tokenizer.chat_template = default_template
            print("Default chat template has been set.")
    else:
        print(f"Model {model_name} already has a chat template.")
def load_hf_model(model_name: str, device: torch.device,
                  generation_config: Optional[Dict[str, Any]] = None) -> Tuple[Any, Any]:
    """
    Load a Hugging Face model and tokenizer.

    Args:
        model_name: Model name or path
        device: Device to load the model on
        generation_config: Optional generation configuration to apply

    Returns:
        Tuple of (model, tokenizer)
    """
    tokenizer = AutoTokenizer.from_pretrained(
        str(model_name),
        trust_remote_code=True,
        padding_side='left'
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Check and set chat template
    set_default_chat_template(tokenizer, model_name)

    if model_name == "google/gemma-3-1b-it":
        # Special-case settings for Gemma 3 (dynamo cache size and explicit sliding window)
        torch._dynamo.config.cache_size_limit = 64
        model = AutoModelForCausalLM.from_pretrained(
            str(model_name),
            torch_dtype=torch.bfloat16,
            device_map={"": device},
            sliding_window=4096
        ).eval()
    else:
        model = AutoModelForCausalLM.from_pretrained(
            str(model_name),
            torch_dtype=torch.bfloat16,
            device_map={"": device}
        ).eval()

    # Apply generation config
    apply_generation_config(model, generation_config)

    return model, tokenizer
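
# Usage sketch (the model name is an assumed public checkpoint, not one required by this module):
#
#     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#     model, tokenizer = load_hf_model("Qwen/Qwen2.5-0.5B-Instruct", device,
#                                      generation_config={'do_sample': False})
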
def load_rosetta_model(model_config: Dict[str, Any], eval_config: Dict[str, Any],
                       device: torch.device,
                       generation_config: Optional[Dict[str, Any]] = None) -> Tuple[Any, Any]:
    """
    Load a Rosetta model with its projectors and aggregators.

    Args:
        model_config: Model configuration dict
        eval_config: Evaluation configuration dict
        device: Device to load the model on
        generation_config: Optional generation configuration to apply

    Returns:
        Tuple of (rosetta_model, tokenizer)
    """
    # Prefer checkpoints_dir under model.rosetta_config; fall back to the eval config
    # for backward compatibility
    rosetta_config = model_config["rosetta_config"]
    checkpoint_dir = rosetta_config.get("checkpoints_dir", eval_config.get("checkpoints_dir"))
    if checkpoint_dir is None:
        raise KeyError("checkpoints_dir must be provided under model.rosetta_config (preferred) or eval config (legacy)")

    slm_model_path = rosetta_config["base_model"]
    llm_model_path = rosetta_config["teacher_model"]

    # Load tokenizer
    slm_tokenizer = AutoTokenizer.from_pretrained(str(slm_model_path))
    set_default_chat_template(slm_tokenizer, slm_model_path)

    # Load models
    slm_model = AutoModelForCausalLM.from_pretrained(
        str(slm_model_path),
        torch_dtype=torch.bfloat16,
        device_map={"": device}
    ).eval()
    # Apply generation config to the SLM
    apply_generation_config(slm_model, generation_config)

    if llm_model_path == "google/gemma-3-1b-it":
        llm_model = AutoModelForCausalLM.from_pretrained(
            str(llm_model_path),
            torch_dtype=torch.bfloat16,
            device_map={"": device},
            sliding_window=4096
        ).eval()
    else:
        llm_model = AutoModelForCausalLM.from_pretrained(
            str(llm_model_path),
            torch_dtype=torch.bfloat16,
            device_map={"": device}
        ).eval()
    # Apply generation config to the LLM
    apply_generation_config(llm_model, generation_config)

    # Load projectors
    num_projectors = len([f for f in os.listdir(checkpoint_dir) if re.match(r"projector_\d+\.pt", f)])
    projector_list = []
    for t in range(num_projectors):
        json_cfg = os.path.join(checkpoint_dir, f"projector_{t}.json")
        proj = load_projector(json_cfg)
        proj = proj.to(device)
        pt_path = os.path.join(checkpoint_dir, f"projector_{t}.pt")
        if os.path.exists(pt_path):
            state_dict = torch.load(pt_path, map_location=device)
            proj.load_state_dict(state_dict, strict=False)
        projector_list.append(proj)

    aggregator_list = []

    # Initialize the Rosetta model
    rosetta_model = RosettaModel(
        model_list=[slm_model, llm_model],
        base_model_idx=0,
        projector_list=projector_list,
        aggregator_list=aggregator_list,
    ).to(device).eval()

    # Load projector/aggregator mapping configs
    proj_cfg_path = os.path.join(checkpoint_dir, "projector_config.json")
    agg_cfg_path = os.path.join(checkpoint_dir, "aggregator_config.json")
    rosetta_model.load_projector_config(proj_cfg_path)
    rosetta_model.load_aggregator_config(agg_cfg_path)

    return rosetta_model, slm_tokenizer
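
# Expected configuration shape (a minimal sketch; the paths and model names below
# are hypothetical placeholders):
#
#     model_config = {
#         "rosetta_config": {
#             "base_model": "Qwen/Qwen2.5-0.5B-Instruct",  # base SLM; its tokenizer is returned
#             "teacher_model": "google/gemma-3-1b-it",     # teacher LLM
#             "checkpoints_dir": "/path/to/checkpoints",   # holds projector_{i}.pt / .json
#         }
#     }
#     rosetta_model, tokenizer = load_rosetta_model(model_config, eval_config={}, device=device)
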
def load_oracle_rosetta_model(model_config: Dict[str, Any], eval_config: Dict[str, Any],
                              device: torch.device) -> Tuple[Any, Any]:
    """
    Load an Oracle Rosetta model with its projectors and aggregators.

    Args:
        model_config: Model configuration dict
        eval_config: Evaluation configuration dict
        device: Device to load the model on

    Returns:
        Tuple of (rosetta_model, tokenizer)
    """
    # Prefer checkpoints_dir under model.rosetta_config; fall back to the eval config
    # for backward compatibility
    rosetta_config = model_config["rosetta_config"]
    checkpoint_dir = rosetta_config.get("checkpoints_dir", eval_config.get("checkpoints_dir"))
    if checkpoint_dir is None:
        raise KeyError("checkpoints_dir must be provided under model.rosetta_config (preferred) or eval config (legacy)")

    slm_model_path = rosetta_config["base_model"]
    llm_model_path = rosetta_config["teacher_model"]

    # Load tokenizer
    slm_tokenizer = AutoTokenizer.from_pretrained(str(slm_model_path))
    set_default_chat_template(slm_tokenizer, slm_model_path)

    # Load models
    slm_model = AutoModelForCausalLM.from_pretrained(
        str(slm_model_path),
        torch_dtype=torch.bfloat16,
        device_map={"": device}
    ).eval()
    llm_model = AutoModelForCausalLM.from_pretrained(
        str(llm_model_path),
        torch_dtype=torch.bfloat16,
        device_map={"": device}
    ).eval()

    # Load projectors
    num_projectors = len([f for f in os.listdir(checkpoint_dir) if re.match(r"projector_\d+\.pt", f)])
    projector_list = []
    for t in range(num_projectors):
        json_cfg = os.path.join(checkpoint_dir, f"projector_{t}.json")
        proj = load_projector(json_cfg)
        proj = proj.to(device)
        pt_path = os.path.join(checkpoint_dir, f"projector_{t}.pt")
        if os.path.exists(pt_path):
            state_dict = torch.load(pt_path, map_location=device)
            proj.load_state_dict(state_dict, strict=False)
        projector_list.append(proj)

    aggregator_list = []

    # Initialize the Oracle Rosetta model
    rosetta_model = OracleRosettaModel(
        model_list=[slm_model, llm_model],
        base_model_idx=0,
        projector_list=projector_list,
        aggregator_list=aggregator_list,
    ).to(device).eval()

    # Load projector/aggregator mapping configs
    proj_cfg_path = os.path.join(checkpoint_dir, "projector_config.json")
    agg_cfg_path = os.path.join(checkpoint_dir, "aggregator_config.json")
    rosetta_model.load_projector_config(proj_cfg_path)
    rosetta_model.load_aggregator_config(agg_cfg_path)

    return rosetta_model, slm_tokenizer
def get_option_token_ids(tokenizer, num_options: int = 4) -> List[int]:
    """
    Get token IDs for options A, B, C, D (or more, up to J).

    Args:
        tokenizer: Tokenizer object
        num_options: Number of options to get (default 4 for A-D, max 10 for A-J)

    Returns:
        List of token IDs for the option letters
    """
    # Limit to a maximum of 10 options (A-J)
    num_options = min(num_options, 10)
    option_ids = []
    for i in range(num_options):
        letter = chr(65 + i)  # A=65, B=66, etc.
        # Encode with a leading space, matching how the letter follows
        # "The correct answer is" in the prompt
        ids = tokenizer.encode(" " + letter, add_special_tokens=False)
        option_ids.append(ids[0] if ids else tokenizer.eos_token_id)
    return option_ids
| """ | |
| Deprecated | |
| """ | |
def generate_answer_with_logits(model, tokenizer, prompt: str, option_ids: List[int],
                                device: torch.device, model_type: str = "hf") -> Tuple[str, np.ndarray]:
    """
    Generate an answer by comparing the logits of the option tokens.

    Args:
        model: Model object
        tokenizer: Tokenizer object
        prompt: Input prompt
        option_ids: Token IDs for options A, B, C, D
        device: Device to run on
        model_type: Type of model ("rosetta", "qwen", or "hf")

    Returns:
        Tuple of (predicted_answer, probabilities)
    """
    messages = [{
        "role": "user",
        "content": prompt
    }]

    # Try to apply the chat template
    try:
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False if model_type == "qwen" else None
        )
    except Exception as e:
        print(f"Failed to apply chat template for {model_type} model: {e}")
        text = f"### Human: {prompt}\n### Assistant:"

    # Prefill the answer prefix so the next token is the option letter
    text += "The correct answer is"

    input_ids = tokenizer(text, return_tensors="pt").to(device)['input_ids']
    attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(device)
    position_ids = attention_mask.long().cumsum(-1) - 1

    if model_type == "rosetta":
        instruction_index = torch.tensor([1, 0], dtype=torch.long).repeat(input_ids.shape[1] - 1, 1).unsqueeze(0).to(device)
        response_index = torch.tensor([[-1, 0]], dtype=torch.long).unsqueeze(0)
        outputs = model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            kv_cache_index=[instruction_index, response_index]
        )
    else:
        outputs = model(input_ids)

    # Compare the logits of the four option tokens at the last position
    logits = outputs.logits[0, -1]
    option_logits = torch.tensor([
        logits[option_ids[0]].item(),
        logits[option_ids[1]].item(),
        logits[option_ids[2]].item(),
        logits[option_ids[3]].item()
    ])
    probs = torch.nn.functional.softmax(option_logits, dim=0).numpy()
    pred = chr(65 + np.argmax(probs))

    return pred, probs
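
# Usage sketch (assumes `model`, `tokenizer`, `prompt`, and `device` were produced
# by load_hf_model and build_prompt above):
#
#     option_ids = get_option_token_ids(tokenizer, num_options=4)
#     pred, probs = generate_answer_with_logits(model, tokenizer, prompt, option_ids, device)
#     # pred is one of 'A'-'D'; probs is the softmax over the four option logits.
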
def generate_answer_with_generate(model, tokenizer, prompt: str, device: torch.device,
                                  model_type: str = "hf") -> Tuple[str, np.ndarray, int, int, str]:
    """
    Generate an answer using free-form text generation.

    Args:
        model: Model object
        tokenizer: Tokenizer object
        prompt: Input prompt
        device: Device to run on
        model_type: Type of model ("rosetta" or "hf")

    Returns:
        Tuple of (predicted_answer, probabilities, input_length, generation_length, generated_text)
    """
    messages = [{
        "role": "user",
        "content": prompt
    }]

    # Apply the chat template
    try:
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False
        )
    except Exception as e:
        print(f"Failed to apply chat template: {e}")
        text = f"### Human: {prompt}\n### Assistant:"

    # Prepare model input
    inputs = tokenizer(text, return_tensors="pt").to(device)

    # Generation parameters
    sampling_params = {
        'do_sample': True,
        'temperature': 0.7,
        'top_p': 0.8,
        'top_k': 20,
        'min_p': 0.0,
        'repetition_penalty': 1.2,
        'max_new_tokens': 1024
    }

    # Generate text
    outputs = model.generate(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        **sampling_params
    )

    # Parse the output: RosettaModel returns only the newly generated tokens,
    # while Hugging Face models return the prompt followed by the generation.
    if isinstance(model, RosettaModel):
        generated_ids = outputs[0]
    else:
        generated_ids = outputs[0][inputs.input_ids.shape[1]:]
    content = tokenizer.decode(generated_ids, skip_special_tokens=True).strip("\n")

    # Extract the answer letter
    pred = extract_answer_from_content(content)

    # The generate method does not expose per-option probabilities, so return a
    # uniform distribution
    probs = np.array([0.25, 0.25, 0.25, 0.25])

    input_length = inputs.input_ids.shape[1]
    gen_length = generated_ids.shape[0]

    return pred, probs, input_length, gen_length, content
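
# Minimal end-to-end smoke test for the generate-based path. The model name below is
# an assumed public checkpoint and the question is a placeholder; adjust as needed.
if __name__ == "__main__":
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model, tokenizer = load_hf_model("Qwen/Qwen2.5-0.5B-Instruct", device,
                                     generation_config={'do_sample': False})

    choices = "A. 3\nB. 4\nC. 5\nD. 6"
    prompt = build_prompt("mmmlu", "EN_US", "What is 2 + 2?", choices, use_cot=False)

    pred, probs, n_input, n_generated, content = generate_answer_with_generate(
        model, tokenizer, prompt, device
    )
    print(f"prediction={pred} input_tokens={n_input} generated_tokens={n_generated}")
    print(content)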