# model_loader.py — defines the image-captioning and action-recognition
# model classes and helpers that load their pickled configs and trained weights.
import torch
import torch.nn as nn
from torchvision import models
import pickle
from pathlib import Path
import sys
import logging
# Module-level logger, named after this module per the stdlib convention.
logger = logging.getLogger(__name__)
class Vocabulary:
    """Bidirectional word <-> index mapping with special tokens.

    Indices 0-3 are always <PAD>, <SOS>, <EOS>, <UNK> (in that order);
    unknown words map to the <UNK> index.
    """

    def __init__(self, freq_threshold=5):
        self.freq_threshold = freq_threshold
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0
        # Special tokens
        self.pad_token = "<PAD>"
        self.start_token = "<SOS>"
        self.end_token = "<EOS>"
        self.unk_token = "<UNK>"
        # Register special tokens first so they occupy the lowest indices.
        for special in (self.pad_token, self.start_token,
                        self.end_token, self.unk_token):
            self.add_word(special)

    def add_word(self, word):
        """Add a word to the vocabulary (no-op if already present)."""
        if word in self.word2idx:
            return
        self.word2idx[word] = self.idx
        self.idx2word[self.idx] = word
        self.idx += 1

    def __len__(self):
        return len(self.word2idx)

    def __call__(self, word):
        """Return the index of *word*, or the <UNK> index if unseen."""
        fallback = self.word2idx[self.unk_token]
        return self.word2idx.get(word, fallback)

    def decode(self, indices):
        """Map a sequence of indices back to words, skipping unknown ids."""
        lookup = self.idx2word
        return [lookup[i] for i in indices if i in lookup]
# Register Vocabulary on __main__ so pickle can resolve it: vocab.pkl was
# presumably dumped from a script where the class lived in __main__, so
# unpickling here needs __main__.Vocabulary to exist — TODO confirm against
# the training script that produced vocab.pkl.
import __main__
setattr(__main__, "Vocabulary", Vocabulary)
class EncoderCNN(nn.Module):
    """ResNet-50 image encoder projecting pooled features to embed_size."""

    def __init__(self, embed_size):
        super(EncoderCNN, self).__init__()
        backbone = models.resnet50(pretrained=False)
        # Keep everything up to (and including) global pooling; drop the
        # final classification layer.
        self.resnet = nn.Sequential(*list(backbone.children())[:-1])
        self.fc = nn.Linear(backbone.fc.in_features, embed_size)
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)

    def forward(self, images):
        """Return a (batch, embed_size) feature tensor for *images*."""
        pooled = self.resnet(images)
        flattened = pooled.view(pooled.size(0), -1)
        return self.bn(self.fc(flattened))
class DecoderLSTM(nn.Module):
    """LSTM caption decoder conditioned on an image feature vector."""

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, dropout=0.5):
        super(DecoderLSTM, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        # Inter-layer dropout is only meaningful for stacked LSTMs.
        lstm_dropout = dropout if num_layers > 1 else 0
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers,
                            batch_first=True, dropout=lstm_dropout)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, features, captions):
        """Teacher-forced pass: the image feature is prepended as timestep 0."""
        token_embeds = self.embed(captions)
        sequence = torch.cat((features.unsqueeze(1), token_embeds), dim=1)
        hiddens, _ = self.lstm(sequence)
        return self.fc(hiddens)

    def sample(self, features, max_length=50):
        """Greedy decoding: feed each argmax token back as the next input.

        Returns a (batch, max_length) tensor of token ids; generation does
        not stop early at an end token.
        """
        generated = []
        states = None
        step_input = features.unsqueeze(1)
        for _ in range(max_length):
            hiddens, states = self.lstm(step_input, states)
            logits = self.fc(hiddens.squeeze(1))
            best = logits.argmax(dim=1)
            generated.append(best)
            step_input = self.embed(best).unsqueeze(1)
        return torch.stack(generated, dim=1)
class ImageCaptioningModel(nn.Module):
    """CNN encoder + LSTM decoder pipeline for image captioning."""

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, dropout=0.5):
        super(ImageCaptioningModel, self).__init__()
        self.encoder = EncoderCNN(embed_size)
        self.decoder = DecoderLSTM(embed_size, hidden_size, vocab_size,
                                   num_layers, dropout)

    def forward(self, images, captions):
        """Teacher-forced training pass: encode images, decode with captions."""
        return self.decoder(self.encoder(images), captions)

    def generate_caption(self, images, max_length=50):
        """Greedy inference: encode images and sample a token-id sequence."""
        return self.decoder.sample(self.encoder(images), max_length)
class ActionRecognitionModel(nn.Module):
    """ResNet-50 with a regularised two-layer head for action classification."""

    def __init__(self, num_classes, dropout=0.5):
        super(ActionRecognitionModel, self).__init__()
        self.backbone = models.resnet50(pretrained=False)
        in_features = self.backbone.fc.in_features
        # Swap the stock 1000-way classifier for a dropout/BN-regularised
        # two-layer head sized for this task.
        head_layers = [
            nn.Dropout(dropout),
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(dropout),
            nn.Linear(512, num_classes),
        ]
        self.backbone.fc = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return (batch, num_classes) logits for input images *x*."""
        return self.backbone(x)
def load_caption_model(device, model_dir=None):
    """Load the image-captioning model and its vocabulary.

    Args:
        device: torch device (or device string) the model is moved to.
        model_dir: directory holding 'caption_model_config.pkl', 'vocab.pkl'
            and 'caption_model_final.pth'; defaults to the 'models' folder
            next to this file.

    Returns:
        Tuple of (model, vocab) with the model in eval mode on *device*.

    Raises:
        Propagates any OSError / pickle / torch error from missing or
        corrupt files.
    """
    if model_dir is None:
        model_dir = Path(__file__).parent / 'models'
    else:
        model_dir = Path(model_dir)
    # Load configuration. NOTE: pickle can execute arbitrary code — only
    # load model files from a trusted source.
    with open(model_dir / 'caption_model_config.pkl', 'rb') as f:
        config = pickle.load(f)
    # Load vocabulary (requires the Vocabulary class to be resolvable by
    # pickle; see the __main__ registration earlier in this module).
    try:
        with open(model_dir / 'vocab.pkl', 'rb') as f:
            vocab = pickle.load(f)
        logger.info(f"Vocabulary loaded successfully. Size: {len(vocab)}")
    except Exception:
        # logger.exception records the traceback; bare raise preserves it
        # for the caller (unlike the original `raise e`).
        logger.exception("Failed to load vocabulary")
        raise
    # Rebuild the architecture from the saved hyperparameters.
    model = ImageCaptioningModel(
        embed_size=config['embed_size'],
        hidden_size=config['hidden_size'],
        vocab_size=config['vocab_size'],
        num_layers=config['num_layers'],
        dropout=config['dropout']
    )
    # Restore trained weights directly onto the target device.
    model.load_state_dict(torch.load(model_dir / 'caption_model_final.pth',
                                     map_location=device))
    model = model.to(device)
    model.eval()
    return model, vocab
def load_action_model(device, model_dir=None):
    """Load the action recognition model onto *device* in eval mode.

    *model_dir* defaults to the 'models' folder next to this file and must
    contain 'action_model_config.pkl' and 'action_model_final.pth'.
    """
    if model_dir is None:
        base = Path(__file__).parent / 'models'
    else:
        base = Path(model_dir)
    # Read the saved hyperparameters.
    with open(base / 'action_model_config.pkl', 'rb') as f:
        config = pickle.load(f)
    # Rebuild the architecture, then restore the trained weights.
    model = ActionRecognitionModel(num_classes=config['num_classes'],
                                   dropout=config['dropout'])
    state = torch.load(base / 'action_model_final.pth', map_location=device)
    model.load_state_dict(state)
    model = model.to(device)
    model.eval()
    return model