import base64
import os
import warnings
from io import BytesIO
from typing import Dict, List, Optional, Tuple, Union

import torch
from accelerate import Accelerator, DistributedType
from loguru import logger as eval_logger
from PIL import Image
from tqdm import tqdm
from transformers import AutoProcessor, AutoTokenizer, Gemma3ForConditionalGeneration

from lmms_eval import utils
from lmms_eval.api.instance import Instance
from lmms_eval.api.model import lmms
from lmms_eval.api.registry import register_model

warnings.simplefilter("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore")

# Constants for default pixel values
DEFAULT_MIN_PIXELS = 256 * 28 * 28
DEFAULT_MAX_PIXELS = 1605632  # 2048 * 28 * 28
DEFAULT_MAX_FRAMES = 32


@register_model("gemma3")
class Gemma3(lmms):
    """
    Gemma3 Model
    https://huggingface.co/google/gemma-3-27b-it
    """

    def __init__(
        self,
        pretrained: str = "google/gemma-3-27b-it",
        device: Optional[str] = "cuda",
        device_map: Optional[str] = "auto",
        batch_size: Optional[Union[int, str]] = 1,
        trust_remote_code: Optional[bool] = True,
        use_cache=True,
        attn_implementation: Optional[str] = None,
        min_pixels: int = DEFAULT_MIN_PIXELS,
        max_pixels: int = DEFAULT_MAX_PIXELS,
        max_num_frames: int = DEFAULT_MAX_FRAMES,
        interleave_visuals: Optional[bool] = False,
        system_prompt: Optional[str] = "You are a helpful assistant.",
        reasoning_prompt: Optional[str] = None,
        **kwargs,
    ) -> None:
        super().__init__()
        # Do not use kwargs for now
        assert kwargs == {}, f"Unexpected kwargs: {kwargs}"

        accelerator = Accelerator()
        if accelerator.num_processes > 1:
            self._device = torch.device(f"cuda:{accelerator.local_process_index}")
            self.device_map = f"cuda:{accelerator.local_process_index}"
        else:
            self._device = torch.device(device)
            self.device_map = device_map if device_map else device

        # Prepare model loading arguments
        model_kwargs = {
            "torch_dtype": torch.bfloat16,
            "device_map": self.device_map,
        }

        # Add attention implementation if specified
        if attn_implementation is not None:
            model_kwargs["attn_implementation"] = attn_implementation

        # Minimal, generation-capable loader: use the dedicated Gemma3 class
        self._model = Gemma3ForConditionalGeneration.from_pretrained(pretrained, **model_kwargs).eval()
        # Tokenizers are device-agnostic, so no device_map is passed here
        self._tokenizer = AutoTokenizer.from_pretrained(pretrained, trust_remote_code=trust_remote_code)
        self.processor = AutoProcessor.from_pretrained(pretrained, max_pixels=max_pixels, min_pixels=min_pixels)

        self._config = self._model.config
        # kwargs is empty here (asserted above), so this always resolves to 2048
        self._max_length = kwargs.get("max_length", 2048)
        self._model.tie_weights()
        self.batch_size_per_gpu = int(batch_size)
        self.use_cache = use_cache
        self.system_prompt = system_prompt
        self.interleave_visuals = interleave_visuals
        self.max_pixels = max_pixels
        self.min_pixels = min_pixels
        self.max_num_frames = max_num_frames
        if reasoning_prompt:
            self.reasoning_prompt = reasoning_prompt.replace("\\n", "\n")
        else:
            self.reasoning_prompt = None

        if accelerator.num_processes > 1:
            assert accelerator.distributed_type in [
                DistributedType.FSDP,
                DistributedType.MULTI_GPU,
            ], "Unsupported distributed type provided. Only DDP and FSDP are supported."
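            # NOTE: each rank holds a full model replica and lmms-eval shards the
            # requests across ranks (data parallelism); the prepare calls below
            # handle device placement only, since we run in evaluation mode.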
            if accelerator.distributed_type == DistributedType.FSDP:
                self._model = accelerator.prepare(self.model)
            else:
                self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
            self.accelerator = accelerator
            if self.accelerator.is_local_main_process:
                eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
            self._rank = self.accelerator.local_process_index
            self._world_size = self.accelerator.num_processes
        else:
            self.model.to(self._device)
            self._rank = 0
            self._world_size = 1

        self.model.eval()

    @property
    def config(self):
        # return the associated transformers.AutoConfig for the given pretrained model.
        return self._config

    @property
    def tokenizer(self):
        return self._tokenizer

    @property
    def model(self):
        # returns the model, unwrapping it if using Accelerate
        if hasattr(self, "accelerator"):
            return self.accelerator.unwrap_model(self._model)
        else:
            return self._model

    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        # return self.tokenizer.eod_id
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        return self._max_length

    @property
    def batch_size(self):
        return self.batch_size_per_gpu

    @property
    def device(self):
        return self._device

    @property
    def rank(self):
        return self._rank

    @property
    def world_size(self):
        return self._world_size

    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
        raise NotImplementedError("Not implemented for Gemma3.")

    def flatten(self, input: List[List]) -> List:
        """Flatten a nested list into a single list.

        Args:
            input: A nested list structure

        Returns:
            A flattened single-level list
        """
        new_list = []
        for i in input:
            for j in i:
                new_list.append(j)
        return new_list

    def generate_until(self, requests: List[Instance]) -> List[str]:
        """Generate text completions for given requests.

        Args:
            requests: List of Instance objects containing generation requests

        Returns:
            List of generated text responses
        """
        res = []

        def _collate(x):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
            toks = self.tokenizer.encode(x[0])
            return -len(toks), x[0]

        pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")

        # we group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
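        # The Collator sorts requests by descending context length (see _collate)
        # and groups identical gen_kwargs together; re_ords.get_original() at the
        # end restores the original request order.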
        re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True)
        chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
        for chunk in chunks:
            contexts, all_gen_kwargs, doc_to_visual, doc_id, task, split = zip(*chunk)
            task = task[0]
            split = split[0]
            visual_list = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id]
            gen_kwargs = all_gen_kwargs[0]

            # Set default until or update values from gen_kwargs if present
            until = gen_kwargs.get("until", [self.tokenizer.decode(self.eot_token_id)])
            if isinstance(until, str):
                until = [until]
            elif not isinstance(until, list):
                raise ValueError(f"Expected `gen_kwargs['until']` to be of type Union[str, list], but got {type(until)}")
            # Avoid using '\n\n' as a stopper to prevent truncation, which can lead to incorrect results
            until = [item for item in until if item != "\n\n"]

            if isinstance(contexts, tuple):
                contexts = list(contexts)

            # Strip any "<image>" placeholder tokens; visuals are passed separately
            for i in range(len(contexts)):
                if "<image>" in contexts[i]:
                    contexts[i] = contexts[i].replace("<image>", "")

            batched_messages = []
            for i, context in enumerate(contexts):
                if "<image>" in context:
                    context = context.replace("<image>", "")

                message = [{"role": "system", "content": [{"type": "text", "text": self.system_prompt}]}]
                if self.reasoning_prompt:
                    context = context.strip() + self.reasoning_prompt
                    contexts[i] = context

                processed_visuals = []
                for visual in visual_list[i]:
                    try:
                        if isinstance(visual, str) and visual.endswith((".mp4", ".avi", ".mov")):  # Video file
                            if not os.path.exists(visual):
                                eval_logger.warning(f"Video file not found: {visual}")
                                continue
                            processed_visuals.append({"type": "video", "video": visual, "max_pixels": self.max_pixels, "min_pixels": self.min_pixels})
                        elif isinstance(visual, Image.Image):  # Handle both single and multiple images
                            base64_image = visual.convert("RGB")
                            buffer = BytesIO()
                            base64_image.save(buffer, format="JPEG")
                            base64_bytes = base64.b64encode(buffer.getvalue())
                            base64_string = base64_bytes.decode("utf-8")
                            processed_visuals.append({"type": "image", "image": f"data:image/jpeg;base64,{base64_string}", "max_pixels": self.max_pixels, "min_pixels": self.min_pixels})
                    except Exception as e:
                        eval_logger.error(f"Failed to process visual: {e}")
                        continue

                message.append(
                    {
                        "role": "user",
                        "content": processed_visuals + [{"type": "text", "text": context}],
                    }
                )
                batched_messages.append(message)

            # .to(...) moves the batch to the model's device and casts only the
            # floating-point features (e.g. pixel values) to bfloat16
            inputs = self.processor.apply_chat_template(
                batched_messages,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
                padding="max_length",
                pad_to_multiple_of=8,
                max_length=self.max_length,
            ).to(self.model.device, dtype=torch.bfloat16)

            # Set default generation kwargs
            default_gen_kwargs = {
                "max_new_tokens": 128,
                "temperature": 0.0,  # Set to 0 for greedy default
                "top_p": None,
                "num_beams": 1,
            }
            # Update with provided kwargs
            current_gen_kwargs = {**default_gen_kwargs, **gen_kwargs}

            if current_gen_kwargs["temperature"] > 0:
                current_gen_kwargs["do_sample"] = True
            else:
                current_gen_kwargs["do_sample"] = False
                current_gen_kwargs["temperature"] = None
                current_gen_kwargs["top_p"] = None

            cont = self.model.generate(
                **inputs,
                do_sample=current_gen_kwargs["do_sample"],
                temperature=current_gen_kwargs["temperature"],
                top_p=current_gen_kwargs["top_p"],
                num_beams=current_gen_kwargs["num_beams"],
                max_new_tokens=current_gen_kwargs["max_new_tokens"],
                use_cache=self.use_cache,
            )

            generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, cont)]
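            # Every row of inputs.input_ids is padded to the same length, so slicing
            # off len(in_ids) tokens leaves only the newly generated continuation.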
            answers = self.processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)
            for i, ans in enumerate(answers):
                # print(f"Raw answer {i}: {ans}")
                for term in until:
                    if len(term) > 0:
                        ans = ans.split(term)[0]
                answers[i] = ans

            for ans, context in zip(answers, contexts):
                res.append(ans)
                self.cache_hook.add_partial("generate_until", (context, gen_kwargs), ans)
                pbar.update(1)

        # reorder this group of results back to original unsorted form
        res = re_ords.get_original(res)
        pbar.close()
        return res

    def generate_until_multi_round(self, requests: List[Instance]) -> List[str]:
        """Generate text in a multi-round conversation format.

        Args:
            requests: List of Instance objects for multi-round generation

        Returns:
            List of generated responses

        Raises:
            NotImplementedError: This method is not yet implemented
        """
        raise NotImplementedError("TODO: Implement multi-round generation")
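
# Example invocation (a sketch, assuming lmms-eval's standard CLI; the flag names
# mirror lm-evaluation-harness conventions, and available task names depend on
# your install):
#
#   python -m lmms_eval \
#       --model gemma3 \
#       --model_args pretrained=google/gemma-3-27b-it,attn_implementation=flash_attention_2 \
#       --tasks mme \
#       --batch_size 1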