File size: 4,574 Bytes
bb8f662
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import os
import torch
from PIL import Image
from transformers import GPT2Tokenizer
from model import VQAModel
from model_spatial import VQAModelWithSpatialAdapter
from train import Vocab
def load_spatial_model(checkpoint_path, device='cuda', hidden_size=512, num_layers=2, num_heads=8, dropout=0.3):
    """Load the fine-tuned spatial adapter model from a checkpoint.

    Steps:

    * Rebuild the answer ``Vocab`` from the checkpoint.
    * Create a distilgpt2 tokenizer, adding a ``[PAD]`` token if it has none.
    * Construct the base ``VQAModel`` and wrap it in ``VQAModelWithSpatialAdapter``.
    * Load the saved weights non-strictly and switch the model to eval mode.

    Args:
        checkpoint_path: Path to the checkpoint file. It must contain ``vocab``,
            ``word2idx``, ``idx2word``, ``pad_token_id``, ``bos_token_id``,
            ``eos_token_id``, ``unk_token_id`` and ``model_state_dict``.
            ``question_max_len`` / ``answer_max_len`` are optional (defaults
            20 / 12).
        device: Device used for ``torch.load(map_location=...)`` and ``.to()``.
        hidden_size: Passed to both the base model and the adapter so the two
            always agree; must match the value used during training.
        num_layers: Passed as ``num_layers`` to ``VQAModel``.
        num_heads: Passed as ``num_heads`` to the adapter.
        dropout: Passed as ``dropout`` to the adapter.

    Returns:
        ``(model, vocab, tokenizer)``: the adapter model in eval mode, the
        rebuilt answer vocabulary, and the question tokenizer.

    Raises:
        KeyError: If a required checkpoint entry is missing.
    """
    # NOTE(review): on newer PyTorch versions torch.load defaults to
    # weights_only=True; confirm the checkpoint only holds plain containers/tensors.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    vocab_size = len(checkpoint['vocab'])

    # Rebuild the answer vocabulary from the values saved at training time.
    vocab = Vocab()
    vocab.vocab = checkpoint['vocab']
    vocab.vocab_size = vocab_size
    vocab.word2idx = checkpoint['word2idx']
    vocab.idx2word = checkpoint['idx2word']
    vocab.pad_token_id = checkpoint['pad_token_id']
    vocab.bos_token_id = checkpoint['bos_token_id']
    vocab.eos_token_id = checkpoint['eos_token_id']
    vocab.unk_token_id = checkpoint['unk_token_id']

    tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})

    base_model = VQAModel(
        vocab_size=vocab_size,
        device=device,
        question_max_len=checkpoint.get('question_max_len', 20),
        answer_max_len=checkpoint.get('answer_max_len', 12),
        pad_token_id=checkpoint['pad_token_id'],
        bos_token_id=checkpoint['bos_token_id'],
        eos_token_id=checkpoint['eos_token_id'],
        unk_token_id=checkpoint['unk_token_id'],
        hidden_size=hidden_size,
        num_layers=num_layers
    ).to(device)
    # Grow the GPT-2 embedding table to include the added [PAD] token;
    # presumably mirrors training so embedding shapes match the checkpoint.
    base_model.gpt2_model.resize_token_embeddings(len(tokenizer))

    model = VQAModelWithSpatialAdapter(
        base_model=base_model,
        hidden_size=hidden_size,
        num_heads=num_heads,
        dropout=dropout
    ).to(device)

    # strict=False tolerates key mismatches. Missing parameters keep their
    # freshly-constructed values, so report mismatches instead of ignoring them.
    incompatible = model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    if incompatible.missing_keys:
        print(f"⚠️  {len(incompatible.missing_keys)} parameter(s) missing from checkpoint, "
              f"e.g. {incompatible.missing_keys[:5]}")
    if incompatible.unexpected_keys:
        print(f"⚠️  {len(incompatible.unexpected_keys)} unexpected key(s) in checkpoint ignored, "
              f"e.g. {incompatible.unexpected_keys[:5]}")

    model.eval()
    return model, vocab, tokenizer
def answer_question(model, vocab, tokenizer, image_path, question, device='cuda', use_beam_search=True, beam_width=5):
    """Answer a question about an image using the spatial adapter model.

    Args:
        model: VQA model exposing ``clip_preprocess`` and ``question_max_len``,
            callable as ``model(image, questions)``. It may optionally provide
            ``generate_with_beam_search``.
        vocab: Answer vocabulary whose ``decoder`` turns token ids into text.
        tokenizer: Question tokenizer (HuggingFace-style call interface).
        image_path: Path of the image file to open.
        question: Question text.
        device: Device the image and question tensors are moved to.
        use_beam_search: Use ``generate_with_beam_search`` when the model has it.
        beam_width: Beam width forwarded to beam search.

    Returns:
        Whatever ``vocab.decoder`` returns for the first generated sequence.
    """
    # Image.open is lazy and some formats keep the file open after loading;
    # the context manager guarantees the handle is closed.
    with Image.open(image_path) as img:
        pixel_values = model.clip_preprocess(img.convert('RGB')).unsqueeze(0).to(device)

    # Pad/truncate the question to the length the model was built with.
    question_tokens = tokenizer(
        question,
        padding='max_length',
        truncation=True,
        max_length=model.question_max_len,
        return_tensors='pt'
    )
    questions = {
        'input_ids': question_tokens['input_ids'].to(device),
        'attention_mask': question_tokens['attention_mask'].to(device)
    }

    with torch.no_grad():
        if use_beam_search and hasattr(model, 'generate_with_beam_search'):
            generated = model.generate_with_beam_search(pixel_values, questions, beam_width=beam_width)
        else:
            generated = model(pixel_values, questions)

    # Assumes the output's first dimension is the batch (size 1 here) — TODO confirm.
    answer = vocab.decoder(generated[0].cpu().numpy())
    return answer
SPATIAL_CHECKPOINT = "./output2/spatial_adapter_v2_2/vqa_spatial_checkpoint.pt"
BASE_CHECKPOINT = "./output2/continued_training/vqa_checkpoint.pt"
IMAGE_PATH = r"./im2.jpg"


def _print_answers(model, vocab, tokenizer, questions, title, device):
    """Answer each question about IMAGE_PATH with *model* and print the Q/A pairs under *title*."""
    print("-" * 80)
    print(title)
    print("-" * 80)
    for question in questions:
        answer = answer_question(model, vocab, tokenizer, IMAGE_PATH, question, device, use_beam_search=True, beam_width=5)
        print(f"\nQ: {question}")
        print(f"A: {answer}")
    print("\n" + "=" * 80)


def main():
    """Run the spatial-adapter model, then the base model if available, on the test questions."""
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print("=" * 80)
    print("🧪 TESTING SPATIAL ADAPTER MODEL")
    print("=" * 80)
    spatial_questions = [
        # NOTE(review): this question looks truncated; complete it before trusting the results.
        "Whic"
    ]
    print(f"\n📷 Image: {IMAGE_PATH}\n")
    # Check the image before paying for a model load that would then crash in Image.open.
    if not os.path.exists(IMAGE_PATH):
        print(f"⚠️  Image not found at: {IMAGE_PATH}")
        return

    if os.path.exists(SPATIAL_CHECKPOINT):
        print("🔧 Loading SPATIAL ADAPTER model...")
        spatial_model, vocab, tokenizer = load_spatial_model(SPATIAL_CHECKPOINT, device)
        print("✓ Spatial model loaded!\n")
        _print_answers(spatial_model, vocab, tokenizer, spatial_questions, "SPATIAL ADAPTER MODEL RESULTS:", device)
        # Release the spatial model so it does not share GPU memory with the base model.
        del spatial_model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    else:
        print(f"⚠️  Spatial model not found at: {SPATIAL_CHECKPOINT}")
        print("   Run finetune2.py first to train the spatial adapter model.")

    if os.path.exists(BASE_CHECKPOINT):
        print("\n🔧 Loading BASE model for comparison...")
        # Local project module test.py. NOTE(review): the name shadows the stdlib
        # `test` package; it resolves here only because the script's directory
        # comes first on sys.path.
        from test import load_model
        base_model, vocab, tokenizer = load_model(BASE_CHECKPOINT, device)
        print("✓ Base model loaded!\n")
        _print_answers(base_model, vocab, tokenizer, spatial_questions, "BASE MODEL RESULTS (for comparison):", device)
    else:
        print(f"\n⚠️  Base model not found at: {BASE_CHECKPOINT} (skipping comparison)")


if __name__ == "__main__":
    main()