#!/usr/bin/env python3
"""
Clean Multimodal Gemma3 Loader - no Unsloth dependency.

Pure transformers + PEFT implementation.
"""

import argparse
import os

import torch
import torchvision.transforms as transforms
from PIL import Image
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel  # noqa: F401 - kept for the currently disabled LoRA adapter path

from multigemma3 import VisionEncoder, VisionProjector, MultimodalGemma3

# Directory that holds the tokenizer, the language model weights and
# projector.pth when the caller does not supply one.
DEFAULT_MODEL_DIR = "./saved_models_clean/best"


class MultimodalGemma3Inference:
    """Clean inference class without Unsloth dependencies"""

    def __init__(self, device='auto', model_dir=DEFAULT_MODEL_DIR):
        """
        Initialize the inference model

        Args:
            device: Device to run on ('auto', 'cuda', 'cpu'). 'auto' picks
                'cuda' when torch reports it as available, otherwise 'cpu'.
            model_dir: Directory containing saved model components: the
                tokenizer and language model files read by ``from_pretrained``
                and the ``projector.pth`` state dict.
        """
        if device == 'auto':
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device

        # Normalise str / PathLike to a plain string so every loader below
        # receives the same value.
        model_dir = os.fspath(model_dir)
        self.model_dir = model_dir

        # Load tokenizer and language model straight from the checkpoint
        # directory. LoRA adapter loading through PeftModel is currently
        # disabled; the directory is loaded directly as a causal LM.
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.language_model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            torch_dtype=torch.bfloat16,
            device_map=device,
        )

        # Load vision encoder
        print("Loading vision encoder...")
        self.vision_encoder = VisionEncoder().to(device)

        # Load projector (maps vision encoder outputs to the language model's
        # hidden size).
        projector_path = os.path.join(model_dir, "projector.pth")
        print(f"Loading projector from {projector_path}")
        self.projector = VisionProjector(
            self.vision_encoder.output_dim,
            self.language_model.config.hidden_size,
        ).to(device=device, dtype=torch.bfloat16)
        # NOTE(security): torch.load unpickles the file, so model_dir must only
        # point at trusted checkpoints.
        self.projector.load_state_dict(
            torch.load(projector_path, map_location=device)
        )

        # Create multimodal model
        self.model = MultimodalGemma3(
            self.language_model,
            self.projector,
            self.tokenizer,
        ).to(device)

        # This class only runs inference, so switch dropout / batch-norm
        # layers to evaluation behaviour. The isinstance guard keeps this a
        # no-op for any component that is not a torch module.
        for module in (self.vision_encoder, self.projector, self.model):
            if isinstance(module, torch.nn.Module):
                module.eval()

        # Image preprocessing: fixed 224x224 resize, tensor conversion and
        # per-channel normalisation (these are the usual ImageNet statistics;
        # presumably what the vision encoder expects - confirm against
        # multigemma3.VisionEncoder).
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

        print("Model loaded successfully!")

    def encode_image(self, image_path_or_pil):
        """
        Encode an image to vision embeddings

        Args:
            image_path_or_pil: Path to image file (``str`` or ``os.PathLike``)
                or PIL Image

        Returns:
            Vision embeddings tensor (vision encoder output for the single
            image, with the leading batch dimension squeezed away)
        """
        # Load image
        if isinstance(image_path_or_pil, (str, os.PathLike)):
            # The context manager closes the file handle promptly; convert()
            # returns an independent, fully loaded copy.
            with Image.open(image_path_or_pil) as opened:
                image = opened.convert('RGB')
        else:
            image = image_path_or_pil.convert('RGB')

        # Preprocess: add a batch dimension of 1 and move to the model device
        image_tensor = self.transform(image).unsqueeze(0).to(self.device)

        # Extract vision embeddings without tracking gradients
        with torch.no_grad():
            vision_embeds = self.vision_encoder(image_tensor).squeeze(0)

        return vision_embeds

    def predict(self, image_path_or_pil, prompt="IMG", max_new_tokens=10):
        """
        Predict text response for an image

        Args:
            image_path_or_pil: Path to image file or PIL Image
            prompt: Text prompt (default: "IMG")
            max_new_tokens: Max tokens to generate

        Returns:
            Generated text response
        """
        # Encode image
        vision_embeds = self.encode_image(image_path_or_pil)

        # Generate response
        return self.model.generate_response(
            vision_embeds,
            prompt=prompt,
            max_new_tokens=max_new_tokens,
        )

    def predict_batch(self, images, prompt="IMG", max_new_tokens=10):
        """
        Predict for a batch of images

        Images are processed one at a time, in order; an empty input yields
        an empty list.

        Args:
            images: List of image paths or PIL Images
            prompt: Text prompt
            max_new_tokens: Max tokens to generate

        Returns:
            List of generated responses, one per input image
        """
        return [
            self.predict(image, prompt, max_new_tokens)
            for image in images
        ]

    def generate_text(self, prompt, max_new_tokens=50):
        """
        Generate pure text response (no vision)

        Args:
            prompt: Text prompt
            max_new_tokens: Max tokens to generate

        Returns:
            Generated text response
        """
        return self.model.generate_response(
            vision_embeds=None,  # No vision
            prompt=prompt,
            max_new_tokens=max_new_tokens,
        )


def main():
    """Command-line entry point: load the model and print the response for one image."""
    parser = argparse.ArgumentParser(description='Load and test Clean Multimodal Gemma3')
    parser.add_argument('image', type=str, help='Path to image file')
    parser.add_argument('--model_dir', type=str, default=DEFAULT_MODEL_DIR,
                        help='Directory containing saved model')
    parser.add_argument('--prompt', type=str, default='IMG', help='Text prompt')
    parser.add_argument('--max_tokens', type=int, default=10, help='Max tokens to generate')
    parser.add_argument('--device', type=str, default='auto', help='Device (auto/cuda/cpu)')

    args = parser.parse_args()

    # Load model
    inference_model = MultimodalGemma3Inference(
        device=args.device,
        model_dir=args.model_dir,
    )

    # Process image
    print(f"Processing image: {args.image}")
    response = inference_model.predict(
        args.image,
        prompt=args.prompt,
        max_new_tokens=args.max_tokens,
    )

    print(response)


if __name__ == "__main__":
    main()