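# NOTE: the lines below are web file-viewer page chrome that was captured along
# with this file; they are kept as comments so the module parses as Python.
#   AliHamza852's picture
#   Upload 4 files
#   9f93006 verified
#   Raw
#   History Blame Contribute Delete
#   4.07 kB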
import gradio as gr
import torch
import torch.nn as nn
import pickle
from torchvision import models, transforms
from PIL import Image
class Config:
    """Hyperparameters used to build the captioning network."""

    # Dimension of each word embedding fed to the decoder LSTM.
    embed_size: int = 300
    # Hidden-state size of the decoder LSTM; also the encoder's output width.
    hidden_size: int = 512
    # Number of stacked LSTM layers in the decoder.
    num_layers: int = 1
    # Length of the flattened image feature vector handed to the encoder.
    feature_dim: int = 2048
class Encoder(nn.Module):
    """Project image feature vectors into the decoder's hidden space.

    The pipeline is Linear -> BatchNorm1d -> ReLU -> Dropout(p=0.5).
    """

    def __init__(self, input_dim, hidden_dim):
        """Create the projection layers.

        Args:
            input_dim: Size of each incoming feature vector.
            hidden_dim: Size of each projected output vector.
        """
        super().__init__()
        # The attribute names are also the state-dict key prefixes, so they
        # must stay as they are for saved checkpoints to keep loading.
        self.linear = nn.Linear(input_dim, hidden_dim)
        self.bn = nn.BatchNorm1d(hidden_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

    def forward(self, images):
        """Return the projected, normalised and activated features.

        Args:
            images: Feature tensor whose last dimension is ``input_dim``
                (the caller in this file passes a ``(1, input_dim)`` tensor).

        Returns:
            Tensor with ``hidden_dim`` as its last dimension.
        """
        projected = self.linear(images)
        normalised = self.bn(projected)
        activated = self.relu(normalised)
        return self.dropout(activated)
class Decoder(nn.Module):
    """Hold the caption-decoder layers: embedding, LSTM and vocab projection.

    ``forward`` is only a placeholder here; ``generate_caption`` in this file
    calls ``embed``, ``lstm`` and ``linear`` directly, one token at a time.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        """Create the decoder layers.

        Args:
            embed_size: Dimension of each word embedding.
            hidden_size: Hidden-state size of the LSTM.
            vocab_size: Number of tokens in the vocabulary.
            num_layers: Number of stacked LSTM layers.
        """
        super().__init__()
        # Attribute names are state-dict key prefixes; keep them unchanged.
        self.embed = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_size,
        )
        self.lstm = nn.LSTM(
            input_size=embed_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.linear = nn.Linear(
            in_features=hidden_size,
            out_features=vocab_size,
        )

    def forward(self, features, captions):
        """Placeholder that ignores both arguments and returns ``None``."""
        return None
class Seq2Seq(nn.Module):
    """Bundle the feature ``Encoder`` and the caption ``Decoder`` in one module.

    No ``forward`` is defined on this class; the two sub-modules are reached
    through the ``encoder`` and ``decoder`` attributes.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, feature_dim):
        """Build both halves of the captioning network.

        Args:
            embed_size: Dimension of each word embedding in the decoder.
            hidden_size: Encoder output width and decoder LSTM hidden size.
            vocab_size: Number of tokens in the vocabulary.
            num_layers: Number of stacked LSTM layers in the decoder.
            feature_dim: Size of the image feature vector given to the encoder.
        """
        super().__init__()
        # Attribute names are state-dict key prefixes; keep them unchanged.
        self.encoder = Encoder(input_dim=feature_dim, hidden_dim=hidden_size)
        self.decoder = Decoder(
            embed_size=embed_size,
            hidden_size=hidden_size,
            vocab_size=vocab_size,
            num_layers=num_layers,
        )
# Everything below runs once at import time and prepares the objects that
# generate_caption() reads as module globals.

# Inference is pinned to the CPU.
device = torch.device("cpu")

# Vocabulary: 'itos' maps index -> token, 'stoi' maps token -> index.
# NOTE(security): pickle.load can execute arbitrary code embedded in the file.
# Only ship a vocab_safe.pkl that this project produced itself; never point
# this at user-supplied data.
with open('vocab_safe.pkl', 'rb') as f:
    vocab_data = pickle.load(f)
itos = vocab_data['itos']
stoi = vocab_data['stoi']
vocab_size = len(itos)

# Captioning network with the trained weights restored.
model = Seq2Seq(
    Config.embed_size,
    Config.hidden_size,
    vocab_size,
    Config.num_layers,
    Config.feature_dim,
)
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint cannot run code while it is being loaded.
model.load_state_dict(
    torch.load('best_model.pth', map_location=device, weights_only=True)
)
# Keep the model on the same device as the feature extractor and the inputs.
model.to(device)
# Eval mode turns dropout off and makes BatchNorm use its running statistics.
model.eval()

# Image feature extractor: a pretrained ResNet-50 with its last child module
# dropped, so it emits pooled features instead of class scores.
resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
resnet = nn.Sequential(*list(resnet.children())[:-1]).to(device)
resnet.eval()

# Preprocessing applied to every uploaded image before feature extraction:
# fixed 224x224 resize, tensor conversion, then per-channel normalisation.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
def generate_caption(image, max_length=20):
    """Generate a caption for ``image`` with greedy decoding.

    The image is preprocessed with the module-level ``transform``, turned into
    a feature vector by ``resnet``, projected by ``model.encoder`` and then
    decoded one token at a time, always taking the highest-scoring token.

    Args:
        image: PIL image from the UI, or ``None`` when nothing was uploaded.
        max_length: Maximum number of decoding steps. Defaults to 20, the
            limit that used to be hard-coded.

    Returns:
        The caption, capitalised and ending with a period; an empty string if
        decoding produced no words; a prompt asking for an upload when
        ``image`` is ``None``; or ``"Error: <message>"`` if anything raises.
    """
    if image is None:
        return "Please upload an image first."

    try:
        rgb_image = image.convert('RGB')
        img_tensor = transform(rgb_image).unsqueeze(0).to(device)

        # Look the special tokens up once instead of on every decoding step.
        start_idx = stoi['<start>']
        end_idx = stoi['<end>']

        caption = []
        with torch.no_grad():
            # Flatten the extractor output to a single (1, feature_dim) row.
            features = resnet(img_tensor).view(1, -1)

            # The encoded image seeds both the hidden and the cell state.
            # NOTE: the (1, 1, hidden) shape only fits a one-layer LSTM.
            enc_out = model.encoder(features).unsqueeze(0)
            h, c = enc_out, enc_out

            word = torch.tensor([start_idx], device=device)
            for _ in range(max_length):
                embed = model.decoder.embed(word).view(1, 1, -1)
                output, (h, c) = model.decoder.lstm(embed, (h, c))
                prediction = model.decoder.linear(output)
                idx = prediction.argmax(2).item()
                if idx == end_idx:
                    break
                caption.append(itos.get(idx, "<unk>"))
                # Feed the predicted token back in as the next input.
                word = torch.tensor([idx], device=device)

        final_caption = " ".join(caption).strip().capitalize()
        if final_caption:
            final_caption += "."
        return final_caption
    except Exception as e:
        # UI boundary: show the failure in the output box instead of letting
        # the request crash.
        return f"Error: {e}"
# Gradio front end: image upload on the left, generated caption on the right.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Page header (same text as before, spelled as adjacent literals).
    gr.Markdown(
        "\n"
        "# 🖼️ Image Captioning Generator\n"
        "Upload an image to generate a descriptive caption.\n"
    )

    with gr.Row():
        # Left column: input image and the trigger button.
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            generate_btn = gr.Button("✨ Generate Caption", variant="primary")
        # Right column: read-only box that receives the caption.
        with gr.Column():
            caption_output = gr.Textbox(
                label="Generated Caption",
                lines=4,
                interactive=False,
            )

    # Clicking the button sends the uploaded image through generate_caption
    # and writes the returned text into the caption box.
    generate_btn.click(
        generate_caption,
        inputs=image_input,
        outputs=caption_output,
    )

# Start the web server only when the file is run as a script.
if __name__ == "__main__":
    demo.launch()