import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from transformers import ViTModel
from PIL import Image
import pickle
import re
import os


class Vocabulary:
    def __init__(self, freq_threshold=5):
        self.freq_threshold = freq_threshold
        # Special tokens (names are assumed; the original strings were stripped
        # during extraction). The indices must match start_index=1 / end_index=2
        # used by Decoder.generate and the pickled vocabulary loaded below.
        self.itos = {0: "<pad>", 1: "<start>", 2: "<end>", 3: "<unk>"}
        self.stoi = {v: k for k, v in self.itos.items()}
        self.index = 4

    def __len__(self):
        return len(self.itos)

    def tokenizer(self, text):
        text = text.lower()
        tokens = re.findall(r"\w+", text)
        return tokens

    def numericalize(self, text):
        tokens = self.tokenizer(text)
        numericalized = []
        for token in tokens:
            if token in self.stoi:
                numericalized.append(self.stoi[token])
            else:
                numericalized.append(self.stoi["<unk>"])
        return numericalized


class Encoder(nn.Module):
    def __init__(self, embed_dim, freeze=False):
        super().__init__()
        self.vit = ViTModel.from_pretrained("facebook/vit-mae-base")
        if freeze:
            for param in self.vit.parameters():
                param.requires_grad = False
        self.linear = nn.Sequential(
            nn.Linear(self.vit.config.hidden_size, embed_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(embed_dim, embed_dim),
            nn.LayerNorm(embed_dim),
        )

    def forward(self, images):
        outputs = self.vit(pixel_values=images)
        # Drop the [CLS] token and keep one embedding per image patch.
        patch_embeddings = outputs.last_hidden_state[:, 1:, :]
        features = self.linear(patch_embeddings)
        return features
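

def _demo_encoder_shapes():
    # Hypothetical sanity-check sketch, not called anywhere in the app: with the
    # 224x224 inputs used below, ViT-MAE-base produces 14x14 = 196 patch tokens,
    # so the encoder output should have shape (batch, 196, embed_dim).
    encoder = Encoder(embed_dim=256, freeze=True)
    dummy = torch.randn(1, 3, 224, 224)
    features = encoder(dummy)
    print(features.shape)  # expected: torch.Size([1, 196, 256])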


class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_dim, encoder_dim, num_heads=4):
        super().__init__()
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.head_dim = hidden_dim // num_heads
        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"
        self.query = nn.Linear(hidden_dim, hidden_dim)
        self.key = nn.Linear(encoder_dim, hidden_dim)
        self.value = nn.Linear(encoder_dim, hidden_dim)
        self.fc_out = nn.Linear(hidden_dim, encoder_dim)

    def forward(self, hidden, encoder_outputs):
        # hidden: (B, hidden_dim) -- top LSTM hidden state used as the query.
        # encoder_outputs: (B, N, encoder_dim) -- one vector per image patch.
        B, N, _ = encoder_outputs.shape
        Q = self.query(hidden).view(B, self.num_heads, self.head_dim)
        K = self.key(encoder_outputs).view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.value(encoder_outputs).view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        scores = torch.matmul(Q.unsqueeze(2), K.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn = torch.softmax(scores, dim=-1)
        context = torch.matmul(attn, V)
        context = context.transpose(1, 2).contiguous().view(B, self.hidden_dim)
        return self.fc_out(context)


class Decoder(nn.Module):
    def __init__(self, embed_dim, hidden_dim, vocab_size, encoder_dim=256,
                 num_layers=2, dropout=0.3, num_heads=4):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.lstm = nn.LSTM(embed_dim + encoder_dim, hidden_dim, num_layers,
                            batch_first=True, dropout=dropout if num_layers > 1 else 0)
        self.attention = MultiHeadAttention(hidden_dim, encoder_dim, num_heads=num_heads)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def generate(self, features, max_len=50, start_index=1, end_index=2,
                 beam_size=3, beam_search=True):
        B = features.size(0)
        device = features.device
        states = (torch.zeros(self.lstm.num_layers, B, self.lstm.hidden_size, device=device),
                  torch.zeros(self.lstm.num_layers, B, self.lstm.hidden_size, device=device))

        if not beam_search:
            # Greedy decoding: feed the argmax token back in at every step.
            generated = []
            current_token = torch.LongTensor([start_index]).to(device).unsqueeze(0)
            for _ in range(max_len):
                emb = self.embedding(current_token).squeeze(1)
                context = self.attention(states[0][-1], features)
                lstm_input = torch.cat((emb, context), dim=1).unsqueeze(1)
                out, states = self.lstm(lstm_input, states)
                logits = self.fc(out.squeeze(1))
                predicted = logits.argmax(dim=1).item()
                generated.append(predicted)
                if predicted == end_index:
                    break
                current_token = torch.LongTensor([predicted]).to(device).unsqueeze(0)
            return generated
        else:
            # Beam search. Start from a single beam: seeding beam_size identical
            # beams would make every hypothesis stay identical and collapse the
            # search back to greedy decoding.
            beams = [([start_index], 0.0, states)]
            for _ in range(max_len):
                new_beams = []
                for seq, log_prob, (h, c) in beams:
                    current_token = torch.LongTensor([seq[-1]]).to(device).unsqueeze(0)
                    emb = self.embedding(current_token).squeeze(1)
                    context = self.attention(h[-1], features)
                    lstm_input = torch.cat((emb, context), dim=1).unsqueeze(1)
                    out, (h_new, c_new) = self.lstm(lstm_input, (h, c))
                    logits = self.fc(out.squeeze(1))
                    log_probs = torch.log_softmax(logits, dim=1)
                    top_log_probs, top_indices = log_probs.topk(beam_size, dim=1)
                    for k in range(beam_size):
                        next_seq = seq + [top_indices[0, k].item()]
                        next_log_prob = log_prob + top_log_probs[0, k].item()
                        new_beams.append((next_seq, next_log_prob, (h_new, c_new)))
                new_beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:beam_size]
                beams = new_beams
                if all(seq[-1] == end_index for seq, _, _ in beams):
                    break
            best_seq = beams[0][0]
            if best_seq[0] == start_index:
                best_seq = best_seq[1:]
            return best_seq


class Model(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def generate(self, images, max_len=50):
        features = self.encoder(images)
        captions = self.decoder.generate(features, max_len=max_len, beam_search=True)
        return captions


DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EMBED_DIM = 256
HIDDEN_DIM = 512
VOCAB_PATH = "vocab-v4.pkl"
MODEL_PATH = "vit_lstm_best-v4.pth"

vocab = None
model = None
inference_transform = None


def load_system():
    global vocab, model, inference_transform
    print("Loading Vocabulary...")
    try:
        with open(VOCAB_PATH, "rb") as f:
            vocab = pickle.load(f)
    except Exception as e:
        return f"Error loading vocab: {e}"

    print("Initializing Model...")
    encoder = Encoder(EMBED_DIM, freeze=True)
    decoder = Decoder(EMBED_DIM, HIDDEN_DIM, len(vocab))
    model = Model(encoder, decoder).to(DEVICE)

    print("Loading Weights...")
    try:
        checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
        else:
            model.load_state_dict(checkpoint)
        model.eval()
    except Exception as e:
        return f"Error loading model weights: {e}"

    inference_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    return "System Loaded"


load_status = load_system()
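

def _demo_compare_decoding(image_tensor):
    # Hypothetical helper, not called by the app: decode the same image with the
    # two strategies exposed by Decoder.generate. Assumes load_system() succeeded
    # and image_tensor has shape (1, 3, 224, 224).
    with torch.no_grad():
        features = model.encoder(image_tensor)
        greedy_ids = model.decoder.generate(features, beam_search=False)
        beam_ids = model.decoder.generate(features, beam_search=True, beam_size=3)
    return greedy_ids, beam_ids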


def format_loading_html():
    # NOTE: only the visible text in these format_* helpers is original; the
    # surrounding markup is a minimal reconstruction using the palette and
    # classes defined in custom_css further down.
    return """
    <div class="loading-box" style="background-color: #1F2937; border: 1px solid #374151; border-radius: 12px; padding: 30px; text-align: center;">
        <p style="color: #9CA3AF; font-size: 0.9rem;">Status</p>
        <h3 style="color: #818CF8; margin: 10px 0;">Analyzing Image...</h3>
        <p style="color: #F3F4F6; font-size: 1.5rem;">...</p>
        <p style="color: #9CA3AF;">Please wait...</p>
    </div>
    """


def format_result_html(caption):
    return f"""

    <div class="result-animation" style="background-color: #1F2937; border: 1px solid #374151; border-radius: 12px; padding: 30px; text-align: center;">
        <p style="color: #9CA3AF; font-size: 0.9rem;">Generated Caption</p>
        <h3 style="color: #F3F4F6; margin-top: 10px;">"{caption}"</h3>
    </div>
    """


def format_initial_html():
    return """

    <div style="background-color: #1F2937; border: 1px solid #374151; border-radius: 12px; padding: 30px; text-align: center;">
        <p style="color: #9CA3AF; font-size: 0.9rem;">Output Area</p>
        <p style="color: #F3F4F6; margin-top: 10px;">Your generated caption will appear here.</p>
    </div>
    """


def format_error_html(error):
    return f"""

    <div style="background-color: #1F2937; border: 1px solid #7F1D1D; border-radius: 12px; padding: 30px; text-align: center;">
        <p style="color: #F87171; font-size: 0.9rem; font-weight: 600;">Error</p>
        <p style="color: #F3F4F6; margin-top: 10px;">{error}</p>
    </div>
    """


def predict(image):
    if image is None:
        yield format_error_html("No image uploaded"), "", gr.update(variant="secondary")
        return

    yield format_loading_html(), "", gr.update(variant="secondary")
    try:
        pil_image = image.convert("RGB")
        image_tensor = inference_transform(pil_image).unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            output_indices = model.generate(image_tensor)

        # Stop at the end token and drop the other special tokens defined in Vocabulary.
        result_words = []
        for idx in output_indices:
            word = vocab.itos.get(idx, "<unk>")
            if word == "<end>":
                break
            if word not in ("<start>", "<pad>"):
                result_words.append(word)
        caption = " ".join(result_words)

        yield format_result_html(caption), caption, gr.update(variant="primary")
    except Exception as e:
        yield format_error_html(str(e)), "", gr.update(variant="secondary")


# The original inline script for js_head was not recoverable; the helpers below
# are minimal assumed implementations of the two functions referenced by the
# buttons at the bottom of the page (copyToClipboard and openModal).
js_head = """
<script>
function copyToClipboard(text) {
    if (text) { navigator.clipboard.writeText(text); }
}
function openModal() {
    document.getElementById("custom-api-modal")?.classList.remove("hidden");
}
function closeModal() {
    document.getElementById("custom-api-modal")?.classList.add("hidden");
}
</script>
"""

custom_css = """
body { background-color: #111827; }
@keyframes fadeInUp { from { opacity: 0; transform: translateY(20px); } to { opacity: 1; transform: translateY(0); } }
.container { max-width: 800px; margin: auto; padding-top: 20px; }
.header { text-align: center; margin-bottom: 30px; }
.header h1 { color: #818CF8; font-size: 2.5rem; }
.header p { color: #9CA3AF; }
@keyframes pulse { 0%, 100% { opacity: 1; } 50% { opacity: 0.5; } }
.loading-box { animation: pulse 1.5s cubic-bezier(0.4, 0, 0.6, 1) infinite; }
button { transition: all 0.2s ease-in-out !important; }
button:hover { transform: translateY(-2px); box-shadow: 0 10px 15px -3px rgba(0, 0, 0, 0.3); }
button:active { transform: translateY(0); }
@keyframes popIn { 0% { opacity: 0; transform: scale(0.9); } 70% { transform: scale(1.02); } 100% { opacity: 1; transform: scale(1); } }
.result-animation { animation: popIn 0.5s cubic-bezier(0.175, 0.885, 0.32, 1.275) forwards; }
@keyframes slideUpFadeIn { from { opacity: 0; transform: translate(-50%, 100%); } to { opacity: 1; transform: translate(-50%, 0); } }
@keyframes fadeOutSlideDown { from { opacity: 1; transform: translate(-50%, 0); } to { opacity: 0; transform: translate(-50%, 100%); } }
.toast { position: fixed; bottom: 30px; left: 50%; transform: translate(-50%, 0); padding: 12px 24px; border-radius: 8px; z-index: 10000; box-shadow: 0 4px 12px rgba(0,0,0,0.15); font-weight: 500; animation: slideUpFadeIn 0.5s ease forwards; }
.toast.hiding { animation: fadeOutSlideDown 0.5s ease forwards; }
#custom-api-modal { position: fixed; top: 0; left: 0; width: 100vw; height: 100vh; background-color: rgba(0,0,0,0.8); z-index: 9999; backdrop-filter: blur(5px); display: flex; justify-content: center; align-items: center; transition: opacity 0.2s ease-in-out; }
#custom-api-modal.hidden { display: none !important; opacity: 0; pointer-events: none; }
.custom-modal-content { background-color: #1F2937; padding: 30px; border: 1px solid #374151; border-radius: 12px; width: 90%; max-width: 600px; box-shadow: 0 25px 50px -12px rgba(0, 0, 0, 0.5); color: #F3F4F6; position: relative; animation: popIn 0.3s ease-out forwards; }
/* Styled HTML Buttons to look like Gradio */
.custom-btn { padding: 8px 16px; border-radius: 8px; font-weight: 600; cursor: pointer; border: none; transition: background-color 0.2s, transform 0.1s; }
.custom-btn:active { transform: scale(0.95); }
.btn-close { background: transparent; color: #9CA3AF; font-size: 1.2rem; position: absolute; top: 20px; right: 20px; }
.btn-close:hover { color: #F3F4F6; }
.btn-primary { background-color: #F97316; color: white; width: 100%; margin-top: 20px; padding: 10px; }
.btn-primary:hover { background-color: #EA580C; }
.input-group { display: flex; gap: 8px; background: #111827; padding: 8px; border-radius: 8px; border: 1px solid #374151; align-items: center; margin-bottom: 15px; }
.code-text { flex-grow: 1; font-family: monospace; color: #F472B6; background: transparent; border: none; outline: none; overflow-x: auto; white-space: nowrap; }
.btn-copy-small { background: #374151; color: #E5E7EB; padding: 6px 12px; font-size: 0.85rem; }
.btn-copy-small:hover { background: #4B5563; }
code { background-color: #111827; padding: 2px 5px; border-radius: 4px; color: #F472B6; font-family: monospace; }
pre { background-color: #111827; padding: 15px; border-radius: 8px; overflow-x: auto; color: #D1D5DB; border: 1px solid #374151; }
"""

# The original modal markup was also lost; this minimal placeholder keeps the
# open/close behaviour working and points at the standard gradio_client pattern.
modal_html_content = """
<div id="custom-api-modal" class="hidden">
    <div class="custom-modal-content">
        <button class="custom-btn btn-close" onclick="closeModal()">&times;</button>
        <h2>Use CogniCaption As API</h2>
        <p>Call this app programmatically with the <code>gradio_client</code> package, e.g.
        <code>Client("YOUR-APP-URL").predict(...)</code>.</p>
        <button class="custom-btn btn-primary" onclick="closeModal()">Close</button>
    </div>
</div>
"""
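

def _demo_caption_from_file(image_path):
    # Hypothetical helper, not wired into the UI: caption an image from disk by
    # driving the predict() generator defined above and keeping its final yield.
    final_html, caption, _ = list(predict(Image.open(image_path)))[-1]
    return caption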

with gr.Blocks(title="CogniCaption", css=custom_css, head=js_head) as app:
    gr.HTML(modal_html_content)

    with gr.Column(elem_classes=["container"]):
        gr.HTML("""
        <div class="header">
            <h1>CogniCaption</h1>
        </div>
        """)

        with gr.Column():
            input_image = gr.Image(type="pil", label="Upload Image", elem_id="input_image")
            submit_btn = gr.Button("Generate Caption", variant="primary", size="lg")

        # The original inline HTML snippet here was not recoverable; a simple spacer is assumed.
        gr.HTML("<br>")

        with gr.Column():
            output_display = gr.HTML(label="Result", value=format_initial_html())
            hidden_caption_storage = gr.Textbox(visible=False, elem_id="hidden_caption_output")
            with gr.Row():
                copy_btn = gr.Button("Copy Caption", size="sm", variant="secondary")
                api_open_btn = gr.Button("Use CogniCaption As API", size="sm", variant="secondary")

    submit_btn.click(
        fn=predict,
        inputs=[input_image],
        outputs=[output_display, hidden_caption_storage, copy_btn],
    )
    copy_btn.click(
        fn=None,
        inputs=[hidden_caption_storage],
        outputs=None,
        js="(text) => copyToClipboard(text)",
    )
    api_open_btn.click(
        fn=None,
        inputs=None,
        outputs=None,
        js="openModal",
    )

if __name__ == "__main__":
    # css and head are passed to gr.Blocks above; launch() does not accept them.
    app.launch(ssr_mode=False)