Spaces:
Sleeping
Sleeping
File size: 5,183 Bytes
e064477 a846416 87cea5c 737f1a6 a846416 e064477 72ca0c5 e064477 a846416 e064477 72ca0c5 e064477 72ca0c5 e064477 f90ca7e e064477 72ca0c5 e064477 72ca0c5 e064477 72ca0c5 e064477 72ca0c5 e064477 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
# Environment setup: these variables must be set BEFORE the transformers
# import below, because transformers reads them at import time.
import os
os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "120" # Increase timeout to 120 seconds
# Set the cache directory to /tmp/cache, which is usually writable in most environments
# NOTE(review): TRANSFORMERS_CACHE is deprecated in recent transformers
# releases in favor of HF_HOME — confirm the installed version still honors it.
os.environ["TRANSFORMERS_CACHE"] = "/tmp/cache"
os.makedirs("/tmp/cache", exist_ok=True)
import torch
import torch.nn as nn
from transformers import ViTModel, T5ForConditionalGeneration, T5Tokenizer
# BaseModelOutput is used to wrap projected encoder states for T5's generate().
from transformers.modeling_outputs import BaseModelOutput
# requests is used to stream-download the fine-tuned checkpoint.
import requests
class ViTT5(nn.Module):
    """Image-captioning model bridging a ViT encoder to a T5 decoder.

    A single linear layer projects the ViT feature width
    (``vit_encoder.config.hidden_size``, 768 for vit-base) onto the T5
    model width (``t5_decoder.config.d_model``, 512 for t5-small) so the
    decoder's cross-attention can consume the visual features.
    """

    def __init__(self, vit_encoder, t5_decoder):
        super().__init__()
        self.vit_encoder = vit_encoder
        self.t5_decoder = t5_decoder
        # Bridge the encoder's feature width to the decoder's expected width.
        self.projection = nn.Linear(
            vit_encoder.config.hidden_size, t5_decoder.config.d_model
        )

    def _encode(self, pixel_values):
        """Run the ViT encoder and project its features to T5's d_model."""
        features = self.vit_encoder(pixel_values=pixel_values).last_hidden_state
        return self.projection(features)

    def forward(self, pixel_values, labels=None):
        """Training/eval forward pass; returns the T5 output (loss when labels given)."""
        projected = self._encode(pixel_values)
        # T5 accepts precomputed encoder outputs as a tuple whose first
        # element is the hidden-state tensor.
        return self.t5_decoder(encoder_outputs=(projected,), labels=labels)

    def generate(self, pixel_values, **kwargs):
        """Generate caption token ids for a batch of images."""
        projected = self._encode(pixel_values)
        # generate() requires an object exposing .last_hidden_state, so wrap
        # the projected states in a BaseModelOutput.
        wrapped = BaseModelOutput(last_hidden_state=projected)
        # NOTE(review): temperature only takes effect under sampling; with the
        # default greedy/beam decoding it is ignored — confirm do_sample is
        # intentionally left off.
        return self.t5_decoder.generate(
            encoder_outputs=wrapped,
            no_repeat_ngram_size=2,  # prevents repeating phrases
            repetition_penalty=1.2,  # penalizes repeated words
            temperature=0.9,         # more diverse outputs (sampling only)
            **kwargs,
        )
def download_checkpoint(checkpoint_path):
    """Download the fine-tuned checkpoint from the Hugging Face Model Hub.

    Streams the file to ``checkpoint_path`` in 8 KiB chunks so the whole
    checkpoint never has to fit in memory.

    Args:
        checkpoint_path: Local filesystem path to write the checkpoint to.

    Raises:
        RuntimeError: If the server responds with a non-200 status code.
        requests.exceptions.RequestException: On connection or timeout errors.
    """
    model_url = "https://huggingface.co/Rishabh2234/image-caption-generator/resolve/main/checkpoint.pth"
    print("Checkpoint not found locally. Downloading from Hugging Face Model Hub...")
    # timeout guards against the request hanging forever (matches the 120 s
    # HF_HUB_DOWNLOAD_TIMEOUT set at module import); it bounds the
    # connect/read waits, not the total transfer time.
    response = requests.get(model_url, stream=True, timeout=120)
    if response.status_code == 200:
        with open(checkpoint_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:  # skip keep-alive chunks
                    f.write(chunk)
        print("Download complete!")
    else:
        raise RuntimeError(f"Error downloading model, status code: {response.status_code}")
def load_model(checkpoint_path, device):
    """Build the ViTT5 model and its tokenizer, loading fine-tuned weights.

    Instantiates a pre-trained ViT encoder and T5 decoder, combines them via
    ViTT5, configures generation defaults on the decoder, and loads the
    fine-tuned checkpoint — downloading it first if it is not present locally.

    Args:
        checkpoint_path: Path to ``checkpoint.pth`` (downloaded if absent).
        device: torch.device to place the model on.

    Returns:
        Tuple ``(model, tokenizer)`` ready for caption generation.
    """
    # Pre-trained backbones: ViT for vision features, T5 for text generation.
    encoder = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
    decoder = T5ForConditionalGeneration.from_pretrained("t5-small")
    # Initialize the combined model (single .to(device); the original called
    # it twice redundantly).
    model = ViTT5(encoder, decoder).to(device)
    # Load the tokenizer
    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    # Generation defaults. T5 uses the pad token as the decoder start token.
    model.t5_decoder.config.decoder_start_token_id = tokenizer.pad_token_id
    model.t5_decoder.config.eos_token_id = tokenizer.eos_token_id
    model.t5_decoder.config.pad_token_id = tokenizer.pad_token_id
    model.t5_decoder.config.max_length = 40
    model.t5_decoder.config.num_beams = 6
    model.t5_decoder.config.repetition_penalty = 1.2
    model.t5_decoder.config.no_repeat_ngram_size = 2
    model.t5_decoder.config.temperature = 0.9
    # Fetch the fine-tuned weights if they are not cached locally.
    if not os.path.exists(checkpoint_path):
        download_checkpoint(checkpoint_path)
    # Load checkpoint if available
    if os.path.exists(checkpoint_path):
        # NOTE(review): torch.load unpickles arbitrary objects; consider
        # weights_only=True if the checkpoint is a plain state dict — confirm
        # against the checkpoint's actual contents before changing.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        # strict=False tolerates minor key mismatches between checkpoint and model.
        missing_keys, unexpected_keys = model.load_state_dict(
            checkpoint['model_state_dict'], strict=False
        )
        print("Checkpoint loaded from", checkpoint_path)
        if missing_keys:
            print("Missing keys:", missing_keys)
        if unexpected_keys:
            print("Unexpected keys:", unexpected_keys)
    else:
        print("No checkpoint found even after download. Using base pre-trained model.")
    return model, tokenizer
# Script entry point: build the model once and report success.
if __name__ == "__main__":
    # Prefer GPU when available; otherwise fall back to CPU.
    run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ckpt_path = "checkpoint.pth"  # saved in the current working directory
    model, tokenizer = load_model(ckpt_path, run_device)
    print("Model and tokenizer loaded successfully!")
|