"""ViT + T5 image-captioning model with checkpoint download and loading helpers.

Defines ``ViTT5`` (a ViT encoder whose projected hidden states drive a T5
decoder), ``download_checkpoint`` (fetches the fine-tuned weights from the
Hugging Face Hub) and ``load_model`` (builds the model and tokenizer and loads
the checkpoint).
"""
import os

# Hugging Face environment settings.
os.environ.update(
    {
        "HF_HUB_DOWNLOAD_TIMEOUT": "120",  # raise the download timeout to 120 seconds
        "TRANSFORMERS_CACHE": "/tmp/cache",  # cache location for downloaded models
    }
)
# /tmp is usually writable in most environments; make sure the cache directory exists.
os.makedirs("/tmp/cache", exist_ok=True)

import torch
import torch.nn as nn
from transformers import ViTModel, T5ForConditionalGeneration, T5Tokenizer
from transformers.modeling_outputs import BaseModelOutput
import requests



class ViTT5(nn.Module):
    """Captioning model that feeds projected ViT features to a T5 decoder.

    The encoder's ``last_hidden_state`` is passed through a linear layer that
    maps ``vit_encoder.config.hidden_size`` to ``t5_decoder.config.d_model``
    and is then handed to the T5 model as pre-computed encoder outputs.

    Args:
        vit_encoder: Vision encoder. Must expose ``config.hidden_size`` and,
            when called with ``pixel_values=...``, return an object with a
            ``last_hidden_state`` attribute.
        t5_decoder: Seq2seq model. Must expose ``config.d_model``, accept
            ``encoder_outputs`` and ``labels`` when called, and provide a
            ``generate`` method.
    """

    def __init__(self, vit_encoder, t5_decoder):
        super().__init__()
        self.vit_encoder = vit_encoder
        self.t5_decoder = t5_decoder
        # Project ViT's hidden size (768) to T5's d_model (512 for t5-small)
        self.projection = nn.Linear(vit_encoder.config.hidden_size, t5_decoder.config.d_model)

    def _encode(self, pixel_values):
        """Run the ViT encoder and project its last hidden states to T5's width."""
        vit_outputs = self.vit_encoder(pixel_values=pixel_values)
        return self.projection(vit_outputs.last_hidden_state)

    def forward(self, pixel_values, labels=None):
        """Run the decoder on the encoded image.

        Args:
            pixel_values: Image input forwarded to the ViT encoder.
            labels: Optional target passed through to the T5 model unchanged.

        Returns:
            Whatever ``t5_decoder`` returns for the given encoder outputs and labels.
        """
        encoder_hidden_states = self._encode(pixel_values)
        return self.t5_decoder(
            encoder_outputs=(encoder_hidden_states,),  # T5 expects a tuple
            labels=labels,
        )

    def generate(self, pixel_values, **kwargs):
        """Generate caption token ids for the given images.

        Args:
            pixel_values: Image input forwarded to the ViT encoder.
            **kwargs: Extra arguments for ``t5_decoder.generate``. Values given
                here take precedence over the defaults below, so callers may
                override ``no_repeat_ngram_size``, ``repetition_penalty`` and
                ``temperature`` without hitting a duplicate-keyword ``TypeError``.

        Returns:
            The result of ``t5_decoder.generate``.
        """
        # Wrap the hidden states in a BaseModelOutput which has a last_hidden_state attribute
        encoder_outputs = BaseModelOutput(last_hidden_state=self._encode(pixel_values))

        # Defaults only: explicit caller values are left untouched.
        kwargs.setdefault("no_repeat_ngram_size", 2)  # Prevents repeating phrases
        kwargs.setdefault("repetition_penalty", 1.2)  # Penalizes repeated words
        # NOTE(review): temperature presumably only matters when sampling is
        # enabled (do_sample=True) -- confirm against the transformers docs.
        kwargs.setdefault("temperature", 0.9)

        return self.t5_decoder.generate(encoder_outputs=encoder_outputs, **kwargs)
def download_checkpoint(checkpoint_path, timeout=120):
    """
    Downloads the checkpoint from Hugging Face Model Hub to ``checkpoint_path``.

    The data is streamed into ``<checkpoint_path>.part`` and moved into place
    only after the whole body has been written, so an interrupted download
    never leaves a truncated file at ``checkpoint_path``.

    Args:
        checkpoint_path: Destination file path for the checkpoint.
        timeout: Timeout in seconds handed to ``requests.get``; without one the
            request could block indefinitely.

    Raises:
        RuntimeError: If the server answers with a status code other than 200.
    """
    model_url = "https://huggingface.co/Rishabh2234/image-caption-generator/resolve/main/checkpoint.pth"
    print("Checkpoint not found locally. Downloading from Hugging Face Model Hub...")
    partial_path = f"{checkpoint_path}.part"
    # The context manager closes the streamed response on every exit path.
    with requests.get(model_url, stream=True, timeout=timeout) as response:
        if response.status_code != 200:
            raise RuntimeError(f"Error downloading model, status code: {response.status_code}")
        try:
            with open(partial_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
            os.replace(partial_path, checkpoint_path)
        finally:
            # Only still present when the download or the move failed.
            if os.path.exists(partial_path):
                os.remove(partial_path)
    print("Download complete!")


def _apply_generation_defaults(decoder, tokenizer):
    """Write the captioning generation defaults onto the decoder's config objects."""
    settings = {
        "decoder_start_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
        "pad_token_id": tokenizer.pad_token_id,
        "max_length": 40,
        "num_beams": 6,
        "repetition_penalty": 1.2,
        "no_repeat_ngram_size": 2,
        "temperature": 0.9,
    }
    targets = [decoder.config]
    # NOTE(review): newer transformers releases appear to read generation
    # defaults from `generation_config` rather than the model config, so the
    # values are mirrored there when the attribute exists -- confirm against
    # the installed transformers version.
    generation_config = getattr(decoder, "generation_config", None)
    if generation_config is not None:
        targets.append(generation_config)
    for target in targets:
        for name, value in settings.items():
            setattr(target, name, value)


def load_model(checkpoint_path, device):
    """
    Loads the ViTT5 model along with the T5 tokenizer.

    Downloads the checkpoint with ``download_checkpoint`` if ``checkpoint_path``
    does not exist, then loads its weights non-strictly into the model.

    Args:
        checkpoint_path: Path of the checkpoint file. Either a dict holding the
            weights under ``'model_state_dict'`` or a bare state dict is accepted.
        device: Device the model is moved to and the checkpoint is mapped onto.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # Load pre-trained vision encoder, T5 decoder and the matching tokenizer
    encoder = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
    decoder = T5ForConditionalGeneration.from_pretrained("t5-small")
    tokenizer = T5Tokenizer.from_pretrained("t5-small")

    # Configure decoder settings
    _apply_generation_defaults(decoder, tokenizer)

    # Initialize the combined model (moved to the device once)
    model = ViTT5(encoder, decoder).to(device)

    # Check if checkpoint exists; if not, download it
    if not os.path.exists(checkpoint_path):
        download_checkpoint(checkpoint_path)

    if not os.path.exists(checkpoint_path):
        print("No checkpoint found even after download. Using base pre-trained model.")
        return model, tokenizer

    # NOTE(security): torch.load unpickles the file, and the checkpoint may
    # have been fetched from the network. Consider `weights_only=True` if the
    # installed torch version supports it and the checkpoint holds only tensors.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    state_dict = checkpoint
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]

    # Use strict=False to allow minor key mismatches (if any)
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    print("Checkpoint loaded from", checkpoint_path)
    if missing_keys:
        print("Missing keys:", missing_keys)
    if unexpected_keys:
        print("Unexpected keys:", unexpected_keys)

    return model, tokenizer

# Script entry point (deployment): load the model and tokenizer once.
if __name__ == "__main__":
    # Prefer the GPU when one is available, otherwise stay on the CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    # Relative path, so the checkpoint is kept in the current directory.
    checkpoint_path = "checkpoint.pth"
    model, tokenizer = load_model(checkpoint_path, device)
    print("Model and tokenizer loaded successfully!")