#!/usr/bin/env python3
"""
CardVault+ Inference Example

Simple example showing how to use the CardVault+ model for card extraction.
"""

import json
import os

import torch
from PIL import Image, ImageDraw
from transformers import AutoModelForVision2Seq, AutoProcessor

MODEL_ID = "sugiv/cardvaultplus"
EXTRACTION_PROMPT = "Extract structured information from this card/document in JSON format."


def create_sample_card():
    """Create a sample credit card image for testing.

    Returns:
        A 400x250 RGB PIL image with a light-blue background and four text
        fields (bank name, card number, holder name, expiry) drawn in black.
    """
    # Create card-like image
    img = Image.new('RGB', (400, 250), color='lightblue')
    draw = ImageDraw.Draw(img)

    # Add card elements
    draw.text((20, 50), "SAMPLE BANK", fill='black')
    draw.text((20, 100), "1234 5678 9012 3456", fill='black')
    draw.text((20, 150), "JOHN DOE", fill='black')
    draw.text((300, 150), "12/25", fill='black')

    return img


def _load_image(image_path_or_pil):
    """Resolve the user-supplied argument into an image for the processor."""
    if image_path_or_pil is None:
        print("Creating sample card image...")
        return create_sample_card()

    if isinstance(image_path_or_pil, (str, os.PathLike)):
        # Image.open() is lazy and keeps the file open; copy the pixel data
        # inside the context manager so the handle is released right away.
        with Image.open(image_path_or_pil) as source:
            return source.copy()

    # Anything else is assumed to already be an image object.
    return image_path_or_pil


def _extract_json(response):
    """Return the first JSON object embedded in *response*, or None."""
    decoder = json.JSONDecoder()
    start = response.find('{')
    while start != -1:
        try:
            # raw_decode() parses one JSON document from the start of the
            # string and ignores whatever follows it, so trailing model
            # chatter (even with stray braces) does not break extraction.
            parsed, _ = decoder.raw_decode(response[start:])
        except json.JSONDecodeError:
            start = response.find('{', start + 1)
        else:
            return parsed
    return None


def extract_card_info(image_path_or_pil=None):
    """Extract structured information from a card image.

    Args:
        image_path_or_pil: ``None`` to use the generated sample card, a
            filesystem path (``str`` or ``os.PathLike``) to open with PIL,
            or an already-loaded image object, which is handed to the
            processor unchanged.

    Returns:
        A dict with three keys:
            'full_response': text decoded from the first generated sequence.
            'extracted_json': the JSON object parsed out of that text, or
                ``None`` when no parseable object was found.
            'success': ``True`` when 'extracted_json' is not ``None``.
    """
    # Load the model
    print("Loading CardVault+ model...")
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    model = AutoModelForVision2Seq.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float16,
        device_map="auto"
    )

    # Load image
    image = _load_image(image_path_or_pil)

    # Process the image and prompt
    inputs = processor(text=EXTRACTION_PROMPT, images=image, return_tensors="pt")

    # Put the inputs on whichever device the model's parameters live on
    device = next(model.parameters()).device
    inputs = {k: v.to(device) if hasattr(v, 'to') else v for k, v in inputs.items()}

    # Generate extraction
    print("Extracting information...")
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=150,
            do_sample=False,
            pad_token_id=processor.tokenizer.eos_token_id
        )

    # Decode response
    response = processor.decode(outputs[0], skip_special_tokens=True)

    # Extract JSON if present
    extracted_json = _extract_json(response)

    return {
        'full_response': response,
        'extracted_json': extracted_json,
        'success': extracted_json is not None
    }


def main():
    """Run the extraction on the built-in sample card and print the result."""
    # Example usage
    result = extract_card_info()  # Uses sample card

    print("="*50)
    print("CardVault+ Extraction Results")
    print("="*50)
    print(f"Success: {result['success']}")
    print(f"Full Response: {result['full_response']}")

    # Gate on 'success' rather than truthiness so an empty JSON object that
    # was reported as a success is still shown.
    if result['success']:
        print("Extracted JSON:")
        print(json.dumps(result['extracted_json'], indent=2))

    # Example with your own image:
    # result = extract_card_info("path/to/your/card.jpg")


if __name__ == "__main__":
    main()