File size: 3,194 Bytes
24dc13b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
#!/usr/bin/env python3
"""
CardVault+ Inference Example
Simple example showing how to use the CardVault+ model for card extraction
"""

import torch
from transformers import AutoProcessor, AutoModelForVision2Seq
from PIL import Image, ImageDraw
import json

def create_sample_card():
    """Build a synthetic credit-card image (400x250) with placeholder text for demos."""
    card = Image.new('RGB', (400, 250), color='lightblue')
    canvas = ImageDraw.Draw(card)

    # (position, text) pairs laid out roughly like a real card face
    labels = [
        ((20, 50), "SAMPLE BANK"),
        ((20, 100), "1234 5678 9012 3456"),
        ((20, 150), "JOHN DOE"),
        ((300, 150), "12/25"),
    ]
    for xy, text in labels:
        canvas.text(xy, text, fill='black')

    return card

def extract_card_info(image_path_or_pil=None):
    """Extract structured information from a card image"""
    
    # Load the model
    print("Loading CardVault+ model...")
    model_id = "sugiv/cardvaultplus" 
    processor = AutoProcessor.from_pretrained(model_id)
    model = AutoModelForVision2Seq.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="auto"
    )
    
    # Load image
    if image_path_or_pil is None:
        print("Creating sample card image...")
        image = create_sample_card()
    elif isinstance(image_path_or_pil, str):
        image = Image.open(image_path_or_pil)
    else:
        image = image_path_or_pil
    
    # Prepare extraction prompt
    prompt = "<image>Extract structured information from this card/document in JSON format."
    
    # Process the image and prompt
    inputs = processor(text=prompt, images=image, return_tensors="pt")
    
    # Move to GPU if available
    device = next(model.parameters()).device
    inputs = {k: v.to(device) if hasattr(v, 'to') else v for k, v in inputs.items()}
    
    # Generate extraction
    print("Extracting information...")
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=150,
            do_sample=False,
            pad_token_id=processor.tokenizer.eos_token_id
        )
    
    # Decode response
    response = processor.decode(outputs[0], skip_special_tokens=True)
    
    # Extract JSON if present
    extracted_json = None
    if '{' in response and '}' in response:
        try:
            json_start = response.find('{')
            json_end = response.rfind('}') + 1
            json_str = response[json_start:json_end]
            extracted_json = json.loads(json_str)
        except:
            pass
    
    return {
        'full_response': response,
        'extracted_json': extracted_json,
        'success': extracted_json is not None
    }

if __name__ == "__main__":
    # Example usage
    result = extract_card_info()  # Uses sample card
    
    print("="*50)
    print("CardVault+ Extraction Results")
    print("="*50)
    print(f"Success: {result['success']}")
    print(f"Full Response: {result['full_response']}")
    
    if result['extracted_json']:
        print("Extracted JSON:")
        print(json.dumps(result['extracted_json'], indent=2))
    
    # Example with your own image:
    # result = extract_card_info("path/to/your/card.jpg")