Spaces:
Sleeping
Sleeping
Commit
·
b1a427a
1
Parent(s):
eb08f58
Upload 5 files
Browse files- app.py +30 -0
- inference.py +39 -0
- loader.py +100 -0
- model.py +74 -0
- train.py +64 -0
app.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from PIL import Image
|
| 2 |
+
import requests
|
| 3 |
+
import gradio as gr
|
| 4 |
+
import torch
|
| 5 |
+
from loader import get_loader
|
| 6 |
+
import torchvision.transforms as transforms
|
| 7 |
+
|
| 8 |
+
# Preprocessing applied to every image before it reaches the model:
# resize, center-crop to 224x224, convert to a tensor, then normalize each
# channel with the mean/std given below.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# The dataset is built here for its vocabulary: it fixes the model's output
# size and is used to decode predicted ids back into words.
train_loader, dataset = get_loader(
    root_folder='FlickrDataset/Images',
    annotation_file='FlickrDataset/Captions/captions.txt',
    transform=transform,
    num_workers=2,
)
filepath = "ImageCaptioningusingLSTM.pth"
from model import CNNtoRNN

model = CNNtoRNN(embed_size=256, hidden_size=256, vocab_size=len(dataset.vocab), num_layers=1)
# The model stays on the CPU in this app, so remap the checkpoint's tensors to
# the CPU while loading. Without map_location, a checkpoint saved from a CUDA
# run cannot be deserialized on a host that has no GPU.
model.load_state_dict(torch.load(filepath, map_location=torch.device("cpu")))
model.eval()
|
| 21 |
+
|
| 22 |
+
def launch(input):
    """Download the image at URL ``input`` and return its generated caption.

    Args:
        input: URL of the image to caption (the text typed into the UI).

    Returns:
        The caption as a single space-separated string, with the
        ``<PAD>``/``<SOS>``/``<EOS>`` marker tokens removed.

    Raises:
        requests.exceptions.RequestException: if the download times out or the
            server answers with an error status.
    """
    import io

    # Bound the wait so an unresponsive host cannot hang the request forever,
    # and fail loudly on HTTP errors instead of handing an error page to PIL.
    response = requests.get(input, timeout=10)
    response.raise_for_status()

    # response.content is the fully decoded body, unlike the raw socket stream.
    im = Image.open(io.BytesIO(response.content))
    image = transform(im.convert('RGB')).unsqueeze(0)

    tokens = model.caption_image(image, dataset.vocab)
    # The interface output is plain text, so return a sentence rather than the
    # raw token list, and keep the vocabulary's marker tokens out of it.
    marker_tokens = {"<PAD>", "<SOS>", "<EOS>"}
    return " ".join(token for token in tokens if token not in marker_tokens)
|
| 27 |
+
|
| 28 |
+
# Wrap `launch` in a text-in / text-out Gradio interface and start serving it.
iface = gr.Interface(fn=launch, inputs="text", outputs="text")
iface.launch()
|
| 30 |
+
|
inference.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import os
|
| 3 |
+
import torchvision.transforms as transforms
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from model import CNNtoRNN
|
| 6 |
+
import pandas as pd
|
| 7 |
+
from loader import get_loader
|
| 8 |
+
|
| 9 |
+
def inference():
    """Caption one image from the local Flickr dataset and print the result.

    Builds the dataset (for its vocabulary), opens and displays the image at
    position ``image_index`` of the image directory listing, loads the trained
    weights from ``ImageCaptioningusingLSTM.pth`` and prints the generated
    caption prefixed with ``Output:``.
    """
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    image_index = 100

    train_loader, dataset = get_loader(
        root_folder='FlickrDataset/Images',
        annotation_file='FlickrDataset/Captions/captions.txt',
        transform=transform,
        num_workers=2,
    )
    imagepath = "FlickrDataset/Images/"
    # NOTE: os.listdir returns entries in arbitrary order, so which file sits
    # at image_index can differ between machines.
    images = os.listdir(imagepath)
    im = Image.open(os.path.join(imagepath, images[image_index]))
    im.show()

    device = torch.device('cuda' if torch.cuda.is_available() else "cpu")

    filepath = "ImageCaptioningusingLSTM.pth"
    model = CNNtoRNN(embed_size=256, hidden_size=256, vocab_size=len(dataset.vocab), num_layers=1).to(device)
    # Remap stored tensors onto the device actually in use, so a checkpoint
    # saved from a CUDA run still loads on a CPU-only machine.
    model.load_state_dict(torch.load(filepath, map_location=device))
    model.eval()

    image = transform(im.convert("RGB")).unsqueeze(0)

    output = model.caption_image(image.to(device), dataset.vocab)
    # Remove marker tokens by value rather than by position: slicing off the
    # first and last entries would discard a real word whenever generation
    # stops at max_length without emitting <EOS>.
    marker_tokens = {"<PAD>", "<SOS>", "<EOS>"}
    print("Output:" + " ".join(token for token in output if token not in marker_tokens))


if __name__ == "__main__":
    inference()
|
loader.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import spacy
|
| 4 |
+
import torch
|
| 5 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 6 |
+
from torch.utils.data import DataLoader,Dataset
|
| 7 |
+
from PIL import Image
|
| 8 |
+
import torchvision.transforms as transforms
|
| 9 |
+
|
| 10 |
+
# Load the spaCy "en_core_web_sm" pipeline once at import time;
# Vocabulary.tokenizer_eng uses its tokenizer to split captions.
spacy_eng=spacy.load("en_core_web_sm")
|
| 11 |
+
class Vocabulary:
    """Two-way mapping between caption tokens and integer ids.

    Ids 0-3 are reserved for the ``<PAD>``, ``<SOS>``, ``<EOS>`` and ``<UNK>``
    markers. Other words receive ids from 4 upwards, in the order in which
    their running count reaches ``freq_threshold``.
    """

    def __init__(self, freq_threshold):
        """Create a vocabulary holding only the four marker tokens."""
        markers = ("<PAD>", "<SOS>", "<EOS>", "<UNK>")
        self.itos = dict(enumerate(markers))
        self.stoi = {token: index for index, token in enumerate(markers)}
        self.freq_threshold = freq_threshold

    def __len__(self):
        """Return the number of ids in the vocabulary, markers included."""
        return len(self.itos)

    def tokenizer_eng(self, text):
        """Split ``text`` with the spaCy tokenizer and lowercase each token."""
        doc = spacy_eng.tokenizer(text)
        return [piece.text.lower() for piece in doc]

    def build_vocabulary(self, sentence_list):
        """Assign an id to every word seen ``freq_threshold`` times.

        A word is registered at the moment its count equals the threshold, so
        ids follow the order in which words reach it.
        """
        counts = {}
        next_index = 4

        for sentence in sentence_list:
            for word in self.tokenizer_eng(sentence):
                counts[word] = counts.get(word, 0) + 1
                if counts[word] != self.freq_threshold:
                    continue
                self.stoi[word] = next_index
                self.itos[next_index] = word
                next_index += 1

    def numericalize(self, text):
        """Return the ids of the tokens in ``text``; unknown tokens map to ``<UNK>``."""
        unknown_id = self.stoi["<UNK>"]
        return [self.stoi.get(token, unknown_id) for token in self.tokenizer_eng(text)]
|
| 46 |
+
|
| 47 |
+
class FlickrDataset(Dataset):
    """Image/caption pairs read from a captions CSV and an image directory.

    The CSV must provide ``image`` and ``caption`` columns. A ``Vocabulary``
    is built from all captions when the dataset is constructed.
    """

    def __init__(self, root_dir, captions_file, transform=None, freq_threshold=5):
        """Read the captions file and build the vocabulary from its captions."""
        self.root_dir = root_dir
        self.df = pd.read_csv(captions_file)
        self.transform = transform

        self.imgs = self.df['image']
        self.captions = self.df['caption']

        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocabulary(self.captions.tolist())

    def __len__(self):
        """Return the number of rows in the captions file."""
        return self.df.shape[0]

    def __getitem__(self, index):
        """Return ``(image, caption_ids)`` for row ``index``.

        The image is opened as RGB and passed through ``transform`` when one
        was given; the caption ids are wrapped in ``<SOS>`` ... ``<EOS>``.
        """
        caption_text = self.captions[index]
        image_path = os.path.join(self.root_dir, self.imgs[index])
        img = Image.open(image_path).convert("RGB")
        if self.transform is not None:
            img = self.transform(img)

        vocab = self.vocab
        token_ids = [
            vocab.stoi["<SOS>"],
            *vocab.numericalize(caption_text),
            vocab.stoi["<EOS>"],
        ]
        return img, torch.tensor(token_ids)
|
| 75 |
+
|
| 76 |
+
class MyCollate:
    """Collate function that batches images and pads captions to equal length."""

    def __init__(self, pad_idx):
        """Remember the id used to pad shorter captions."""
        self.pad_idx = pad_idx

    def __call__(self, batch):
        """Turn a list of ``(image, caption_ids)`` pairs into two tensors.

        Returns:
            ``(imgs, targets)``: the images stacked along a new leading
            dimension, and the captions padded with ``pad_idx`` and laid out
            sequence-first (``batch_first=False``).
        """
        images, captions = [], []
        for image, caption in batch:
            images.append(image)
            captions.append(caption)

        # Stacking along a new dim 0 equals concatenating unsqueezed tensors.
        imgs = torch.stack(images, dim=0)
        targets = pad_sequence(captions, batch_first=False, padding_value=self.pad_idx)
        return imgs, targets
|
| 87 |
+
|
| 88 |
+
def get_loader(root_folder, annotation_file, transform, batch_size=32, shuffle=True, pin_memory=True, num_workers=8):
    """Build the Flickr dataset and a DataLoader over it.

    Args:
        root_folder: directory containing the images.
        annotation_file: captions CSV passed to ``FlickrDataset``.
        transform: transform applied to each image.
        batch_size, shuffle, pin_memory, num_workers: forwarded to
            ``DataLoader``.

    Returns:
        ``(loader, dataset)``; the loader pads captions with the
        vocabulary's ``<PAD>`` id through ``MyCollate``.
    """
    dataset = FlickrDataset(root_folder, annotation_file, transform=transform)
    collate = MyCollate(pad_idx=dataset.vocab.stoi["<PAD>"])

    loader_options = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "shuffle": shuffle,
        "pin_memory": pin_memory,
        "collate_fn": collate,
    }
    loader = DataLoader(dataset=dataset, **loader_options)

    return loader, dataset
|
model.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torchvision.models as models
|
| 4 |
+
|
| 5 |
+
class EncoderCNN(nn.Module):
    """Image encoder: frozen ResNet-50 trunk followed by a trainable projection.

    Every child module of the pretrained ResNet-50 except the last one is kept
    as the feature extractor; a linear layer plus 1-D batch norm then maps the
    features to ``embed_size``.
    """

    def __init__(self, embed_size):
        """Build the encoder; ``embed_size`` is the size of the output vector."""
        super().__init__()
        backbone = models.resnet50(weights='ResNet50_Weights.DEFAULT')
        # Freeze the pretrained weights; only the layers added below train.
        for weight in backbone.parameters():
            weight.requires_grad_(False)

        trunk = list(backbone.children())[:-1]
        self.resnet = nn.Sequential(*trunk)
        self.embed = nn.Linear(backbone.fc.in_features, embed_size)
        self.batch = nn.BatchNorm1d(embed_size, momentum=0.01)
        # Start the projection with small normal weights and a zero bias.
        self.embed.weight.data.normal_(0., 0.02)
        self.embed.bias.data.fill_(0)

    def forward(self, images):
        """Return one ``embed_size`` vector per image in the batch."""
        pooled = self.resnet(images)
        # Flatten everything after the batch dimension.
        flat = pooled.view(pooled.size(0), -1)
        return self.batch(self.embed(flat))
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class DecoderRNN(nn.Module):
    """LSTM caption decoder fed with the image feature as its first time step."""

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        """Create the token embedding, the LSTM, the output layer and dropout."""
        super().__init__()
        self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_size)
        self.lstm = nn.LSTM(input_size=embed_size, hidden_size=hidden_size, num_layers=num_layers)
        self.linear = nn.Linear(in_features=hidden_size, out_features=vocab_size)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, features, captions):
        """Score the vocabulary at every time step.

        ``features`` is prepended along dim 0 to the dropout-regularized
        caption embeddings, so the output has one more step along dim 0 than
        ``captions``.
        """
        token_embeddings = self.dropout(self.embed(captions))
        image_step = features.unsqueeze(0)
        sequence = torch.cat((image_step, token_embeddings), dim=0)
        hidden_states, _ = self.lstm(sequence)
        return self.linear(hidden_states)
|
| 41 |
+
|
| 42 |
+
class CNNtoRNN(nn.Module):
    """Captioning model: ``EncoderCNN`` image features drive a ``DecoderRNN``."""

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        """Create the encoder and the decoder with the given sizes."""
        super().__init__()
        self.encoderCNN = EncoderCNN(embed_size)
        self.decoderRNN = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers)

    def forward(self, images, captions):
        """Encode ``images`` and return the decoder's scores for ``captions``."""
        return self.decoderRNN(self.encoderCNN(images), captions)

    def caption_image(self, image, vocabulary, max_length=50):
        """Greedily decode a caption for a single image.

        Starting from the encoded image, the highest-scoring token is picked
        at each step and fed back in, until ``<EOS>`` is produced or
        ``max_length`` tokens have been generated.

        Returns:
            The generated tokens as strings, including ``<EOS>`` when it was
            produced.
        """
        token_ids = []
        with torch.no_grad():
            step_input = self.encoderCNN(image).unsqueeze(0)
            lstm_state = None

            for _ in range(max_length):
                hidden, lstm_state = self.decoderRNN.lstm(step_input, lstm_state)
                scores = self.decoderRNN.linear(hidden.squeeze(0))
                best = scores.argmax(1)
                # .item() requires a single element, i.e. one image at a time.
                token_id = best.item()
                token_ids.append(token_id)

                step_input = self.decoderRNN.embed(best).unsqueeze(0)

                if vocabulary.itos[token_id] == "<EOS>":
                    break

        return [vocabulary.itos[token_id] for token_id in token_ids]
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
|
train.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.optim as optim
|
| 4 |
+
import torchvision.transforms as transforms
|
| 5 |
+
from loader import get_loader
|
| 6 |
+
from model import CNNtoRNN
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
from tqdm import trange
|
| 9 |
+
|
| 10 |
+
def train():
    """Train the captioning model on the Flickr dataset and save its weights.

    Runs ``num_epochs`` passes over the loader with Adam and a cross-entropy
    loss that ignores ``<PAD>`` targets, writing the model's state dict to
    ``ImageCaptioningusingLSTM.pth``.
    """
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    train_loader, dataset = get_loader(
        root_folder='FlickrDataset/Images',
        annotation_file='FlickrDataset/Captions/captions.txt',
        transform=transform,
        num_workers=2,
    )

    torch.backends.cudnn.benchmark = True
    device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
    embed_size = 256
    hidden_size = 256
    vocab_size = len(dataset.vocab)
    num_layers = 1
    learning_rate = 3e-4
    num_epochs = 200

    model = CNNtoRNN(embed_size, hidden_size, vocab_size, num_layers).to(device)
    criterion = nn.CrossEntropyLoss(ignore_index=dataset.vocab.stoi["<PAD>"])
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    filepath = "ImageCaptioningusingLSTM.pth"

    # Nothing in the loop leaves training mode, so set it once up front.
    model.train()
    train_iterator = trange(0, num_epochs)
    for _ in train_iterator:
        pbar = tqdm(train_loader)
        for imgs, captions in pbar:
            imgs = imgs.to(device)
            captions = captions.to(device)

            # The decoder prepends the image feature as an extra first step,
            # so feeding captions[:-1] yields one output per caption position.
            outputs = model(imgs, captions[:-1])
            loss = criterion(outputs.reshape(-1, outputs.shape[2]), captions.reshape(-1))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            pbar.set_postfix(loss=loss.item())

        # Checkpoint after every epoch so an interrupted run keeps its
        # progress; the file left after the last epoch is the final model.
        torch.save(model.state_dict(), filepath)


if __name__ == "__main__":
    train()
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
|