| |
| """ |
| Clean Multimodal Gemma3 Loader - no Unsloth dependency |
| Pure transformers + PEFT implementation |
| """ |
|
|
| import os |
| import torch |
| import torchvision.transforms as transforms |
| from PIL import Image |
| from transformers import AutoTokenizer, AutoModelForCausalLM |
| from peft import PeftModel |
| import argparse |
|
|
| from multigemma3 import VisionEncoder, VisionProjector, MultimodalGemma3 |
|
|
|
|
class MultimodalGemma3Inference:
    """Inference wrapper around a saved multimodal Gemma3 checkpoint.

    On construction this loads the tokenizer, the causal language model, a
    ``VisionEncoder`` and a ``VisionProjector`` (weights read from
    ``projector.pth``), and combines language model, projector and tokenizer
    into a ``MultimodalGemma3`` instance. The vision encoder is kept
    separately and is used by :meth:`encode_image` to produce the embeddings
    handed to the model. No Unsloth dependency is involved.
    """

    # Checkpoint directory used when the caller does not supply one.
    DEFAULT_MODEL_DIR = "./saved_models_clean/best"

    def __init__(self, device='auto', model_dir=DEFAULT_MODEL_DIR):
        """
        Initialize the inference model.

        Args:
            device: Device to run on ('auto', 'cuda', 'cpu'). 'auto' selects
                'cuda' when ``torch.cuda.is_available()`` and 'cpu' otherwise.
            model_dir: Directory containing the saved model components: the
                tokenizer and language model files plus ``projector.pth``.
                May be a ``str`` or any ``os.PathLike``.
        """
        if device == 'auto':
            device = "cuda" if torch.cuda.is_available() else 'cpu'
        self.device = device

        # Normalise to a plain string so every loader sees the same path.
        model_dir = os.fspath(model_dir)
        self.model_dir = model_dir

        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.language_model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            torch_dtype=torch.bfloat16,
            device_map=device,
        )

        print("Loading vision encoder...")
        self.vision_encoder = VisionEncoder().to(device)

        projector_path = os.path.join(model_dir, "projector.pth")
        print(f"Loading projector from {projector_path}")
        # The projector maps the vision encoder's output width to the
        # language model's hidden size and is kept in bfloat16 like the LM.
        self.projector = VisionProjector(
            self.vision_encoder.output_dim,
            self.language_model.config.hidden_size,
        ).to(device=device, dtype=torch.bfloat16)

        # NOTE(review): torch.load unpickles the file, so model_dir must only
        # point at trusted checkpoints. Consider weights_only=True if the
        # minimum supported torch version allows it.
        self.projector.load_state_dict(
            torch.load(projector_path, map_location=device)
        )

        self.model = MultimodalGemma3(
            self.language_model, self.projector, self.tokenizer
        ).to(device)

        # Preprocessing applied to every image before the vision encoder:
        # fixed 224x224 resize, tensor conversion, then per-channel
        # normalisation with the mean/std values commonly used for
        # ImageNet-pretrained torchvision models.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

        print("Model loaded successfully!")

    def encode_image(self, image_path_or_pil):
        """
        Encode an image to vision embeddings.

        Args:
            image_path_or_pil: Path to an image file (``str`` or
                ``os.PathLike``) or an already opened PIL Image.

        Returns:
            The vision encoder output for the image with the leading batch
            dimension squeezed away.
        """
        if isinstance(image_path_or_pil, (str, os.PathLike)):
            # Image.open is lazy; convert() forces the pixel data to load, so
            # the file can be closed as soon as the RGB copy exists.
            with Image.open(image_path_or_pil) as opened:
                image = opened.convert('RGB')
        else:
            image = image_path_or_pil.convert('RGB')

        # Add a batch dimension of 1 and move the tensor to the model device.
        image_tensor = self.transform(image).unsqueeze(0).to(self.device)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            vision_embeds = self.vision_encoder(image_tensor).squeeze(0)

        return vision_embeds

    def predict(self, image_path_or_pil, prompt="IMG", max_new_tokens=10):
        """
        Predict a text response for an image.

        Args:
            image_path_or_pil: Path to image file or PIL Image.
            prompt: Text prompt (default: "IMG").
            max_new_tokens: Max tokens to generate.

        Returns:
            The value returned by ``self.model.generate_response`` for the
            encoded image (the generated text response).
        """
        vision_embeds = self.encode_image(image_path_or_pil)
        return self.model.generate_response(
            vision_embeds, prompt=prompt, max_new_tokens=max_new_tokens
        )

    def predict_batch(self, images, prompt="IMG", max_new_tokens=10):
        """
        Predict for a batch of images.

        Images are processed one at a time through :meth:`predict`; an empty
        input yields an empty list.

        Args:
            images: Iterable of image paths or PIL Images.
            prompt: Text prompt applied to every image.
            max_new_tokens: Max tokens to generate per image.

        Returns:
            List of generated responses, in input order.
        """
        return [self.predict(image, prompt, max_new_tokens) for image in images]

    def generate_text(self, prompt, max_new_tokens=50):
        """
        Generate a pure text response (no vision input).

        Args:
            prompt: Text prompt.
            max_new_tokens: Max tokens to generate.

        Returns:
            The value returned by ``self.model.generate_response`` when called
            with ``vision_embeds=None`` (the generated text response).
        """
        return self.model.generate_response(
            vision_embeds=None,
            prompt=prompt,
            max_new_tokens=max_new_tokens,
        )
|
|
|
|
def main():
    """Command-line entry point: caption a single image and print the result.

    Parses the image path plus the optional ``--prompt``, ``--max_tokens`` and
    ``--device`` flags, builds a ``MultimodalGemma3Inference`` on the chosen
    device, runs ``predict`` on the image and prints the response.
    """
    cli = argparse.ArgumentParser(description='Load and test Clean Multimodal Gemma3')
    cli.add_argument('image', type=str, help='Path to image file')

    # Optional flags, registered in the order they appear in --help.
    optional_flags = (
        ('--prompt', str, 'IMG', 'Text prompt'),
        ('--max_tokens', int, 10, 'Max tokens to generate'),
        ('--device', str, 'auto', 'Device (auto/cuda/cpu)'),
    )
    for flag, value_type, fallback, description in optional_flags:
        cli.add_argument(flag, type=value_type, default=fallback, help=description)

    options = cli.parse_args()

    # Loading the model is the expensive step; it happens once per invocation.
    engine = MultimodalGemma3Inference(device=options.device)

    print(f"Processing image: {options.image}")
    answer = engine.predict(
        options.image,
        prompt=options.prompt,
        max_new_tokens=options.max_tokens,
    )
    print(answer)


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|