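# CogniCaption — Gradio demo: image captioning with a frozen ViT-MAE encoder
# and an attention-augmented LSTM decoder, decoded with beam search.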
import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from transformers import ViTModel
from PIL import Image
import pickle
import re
import os
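# Word-level vocabulary. The class definition must be present in this module so
# that pickle.load can resolve the Vocabulary instance stored in vocab-v4.pkl.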
class Vocabulary:
    def __init__(self, freq_threshold=5):
        self.freq_threshold = freq_threshold
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {v: k for k, v in self.itos.items()}
        self.index = 4

    def __len__(self):
        return len(self.itos)

    def tokenizer(self, text):
        text = text.lower()
        tokens = re.findall(r"\w+", text)
        return tokens

    def numericalize(self, text):
        tokens = self.tokenizer(text)
        numericalized = []
        for token in tokens:
            if token in self.stoi:
                numericalized.append(self.stoi[token])
            else:
                numericalized.append(self.stoi["<UNK>"])
        return numericalized
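# Image encoder: a pretrained ViT-MAE backbone whose patch embeddings (the CLS
# token at position 0 is dropped) are projected down to embed_dim per patch.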
class Encoder(nn.Module):
    def __init__(self, embed_dim, freeze=False):
        super().__init__()
        self.vit = ViTModel.from_pretrained("facebook/vit-mae-base")
        if freeze:
            for param in self.vit.parameters():
                param.requires_grad = False
        self.linear = nn.Sequential(
            nn.Linear(self.vit.config.hidden_size, embed_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(embed_dim, embed_dim),
            nn.LayerNorm(embed_dim)
        )

    def forward(self, images):
        outputs = self.vit(pixel_values=images)
        patch_embeddings = outputs.last_hidden_state[:, 1:, :]
        features = self.linear(patch_embeddings)
        return features
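# Single-query multi-head cross-attention: the decoder's last hidden state
# attends over the N patch features and returns one context vector per image.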
class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_dim, encoder_dim, num_heads=4):
        super().__init__()
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.head_dim = hidden_dim // num_heads
        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"
        self.query = nn.Linear(hidden_dim, hidden_dim)
        self.key = nn.Linear(encoder_dim, hidden_dim)
        self.value = nn.Linear(encoder_dim, hidden_dim)
        self.fc_out = nn.Linear(hidden_dim, encoder_dim)

    def forward(self, hidden, encoder_outputs):
        B, N, _ = encoder_outputs.shape
        Q = self.query(hidden).view(B, self.num_heads, self.head_dim)
        K = self.key(encoder_outputs).view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.value(encoder_outputs).view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        scores = torch.matmul(Q.unsqueeze(2), K.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn = torch.softmax(scores, dim=-1)
        context = torch.matmul(attn, V)
        context = context.transpose(1, 2).contiguous().view(B, self.hidden_dim)
        return self.fc_out(context)
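# LSTM decoder. At each step the previous token's embedding is concatenated
# with an attention context over the encoder features before entering the LSTM.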
class Decoder(nn.Module):
    def __init__(self, embed_dim, hidden_dim, vocab_size, encoder_dim=256, num_layers=2, dropout=0.3, num_heads=4):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.lstm = nn.LSTM(embed_dim + encoder_dim, hidden_dim, num_layers, batch_first=True,
                            dropout=dropout if num_layers > 1 else 0)
        self.attention = MultiHeadAttention(hidden_dim, encoder_dim, num_heads=num_heads)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def generate(self, features, max_len=50, start_index=1, end_index=2, beam_size=3, beam_search=True):
        B = features.size(0)
        device = features.device
        states = (torch.zeros(self.lstm.num_layers, B, self.lstm.hidden_size, device=device),
                  torch.zeros(self.lstm.num_layers, B, self.lstm.hidden_size, device=device))
        if not beam_search:
            # Greedy decoding: feed the argmax token back in at every step.
            generated = []
            current_token = torch.LongTensor([start_index]).to(device).unsqueeze(0)
            for _ in range(max_len):
                emb = self.embedding(current_token).squeeze(1)
                context = self.attention(states[0][-1], features)
                lstm_input = torch.cat((emb, context), dim=1).unsqueeze(1)
                out, states = self.lstm(lstm_input, states)
                logits = self.fc(out.squeeze(1))
                predicted = logits.argmax(dim=1).item()
                generated.append(predicted)
                if predicted == end_index:
                    break
                current_token = torch.LongTensor([predicted]).to(device).unsqueeze(0)
            return generated
        else:
            # Start from a single beam. Seeding beam_size identical beams would
            # flood new_beams with duplicates, so every beam would collapse onto
            # the same top-1 continuation and the search would degenerate to greedy.
            beams = [([start_index], 0.0, states)]
            for _ in range(max_len):
                new_beams = []
                for seq, log_prob, (h, c) in beams:
                    # Carry finished hypotheses over unchanged rather than
                    # extending them past <EOS>.
                    if seq[-1] == end_index:
                        new_beams.append((seq, log_prob, (h, c)))
                        continue
                    current_token = torch.LongTensor([seq[-1]]).to(device).unsqueeze(0)
                    emb = self.embedding(current_token).squeeze(1)
                    context = self.attention(h[-1], features)
                    lstm_input = torch.cat((emb, context), dim=1).unsqueeze(1)
                    out, (h_new, c_new) = self.lstm(lstm_input, (h, c))
                    logits = self.fc(out.squeeze(1))
                    log_probs = torch.log_softmax(logits, dim=1)
                    top_log_probs, top_indices = log_probs.topk(beam_size, dim=1)
                    for k in range(beam_size):
                        next_seq = seq + [top_indices[0, k].item()]
                        next_log_prob = log_prob + top_log_probs[0, k].item()
                        new_beams.append((next_seq, next_log_prob, (h_new, c_new)))
                new_beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:beam_size]
                beams = new_beams
                if all(seq[-1] == end_index for seq, _, _ in beams):
                    break
            best_seq = beams[0][0]
            if best_seq[0] == start_index:
                best_seq = best_seq[1:]
            return best_seq
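# Thin wrapper tying the encoder and decoder together for inference.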
class Model(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def generate(self, images, max_len=50):
        features = self.encoder(images)
        captions = self.decoder.generate(features, max_len=max_len, beam_search=True)
        return captions
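# Inference configuration. The vocabulary pickle and checkpoint are expected to
# sit alongside app.py in the Space repository.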
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EMBED_DIM = 256
HIDDEN_DIM = 512
VOCAB_PATH = "vocab-v4.pkl"
MODEL_PATH = "vit_lstm_best-v4.pth"
vocab = None
model = None
inference_transform = None
def load_system():
    global vocab, model, inference_transform
    print("Loading Vocabulary...")
    try:
        with open(VOCAB_PATH, "rb") as f:
            vocab = pickle.load(f)
    except Exception as e:
        return f"Error loading vocab: {e}"
    print("Initializing Model...")
    encoder = Encoder(EMBED_DIM, freeze=True)
    decoder = Decoder(EMBED_DIM, HIDDEN_DIM, len(vocab))
    model = Model(encoder, decoder).to(DEVICE)
    print("Loading Weights...")
    try:
        checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
        else:
            model.load_state_dict(checkpoint)
        model.eval()
    except Exception as e:
        return f"Error loading model weights: {e}"
    inference_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    return "System Loaded"

load_status = load_system()
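# HTML fragments rendered into the gr.HTML output area for each UI state.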
def format_loading_html():
    return """
    <div class="loading-box" style="
        text-align: center;
        padding: 40px;
        border: 2px solid #6B7280;
        border-radius: 15px;
        background-color: #27272A;
    ">
        <h3 style="color: #6B7280; margin-bottom: 5px;">Status</h3>
        <h2 style="color: #F3F4F6; font-size: 24px; margin: 10px 0;">Analyzing Image...</h2>
        <div style="font-size: 48px; font-weight: bold; color: #6B7280;">
            ...
        </div>
        <p style="color: #9CA3AF; font-weight: bold; margin-top: 5px;">Please wait...</p>
    </div>
    """

def format_result_html(caption):
    return f"""
    <div class="result-animation" style="
        text-align: center;
        padding: 30px;
        border: 2px solid #4F46E5;
        border-radius: 15px;
        background-color: #27272A;
        box-shadow: 0 4px 6px -1px rgba(79, 70, 229, 0.1);
    ">
        <h3 style="color: #818CF8; margin-bottom: 10px; text-transform: uppercase; letter-spacing: 2px;">Generated Caption</h3>
        <div style="
            font-size: 28px;
            font-weight: bold;
            color: #F9FAFB;
            margin: 20px 0;
            line-height: 1.4;
        ">
            "{caption}"
        </div>
    </div>
    """

def format_initial_html():
    return """
    <div style="
        text-align: center;
        padding: 40px;
        border: 2px dashed #4B5563;
        border-radius: 15px;
        background-color: #27272A;
        color: #9CA3AF;
    ">
        <h3>Output Area</h3>
        <p>Your generated caption will appear here.</p>
    </div>
    """

def format_error_html(error):
    return f"""
    <div style="text-align: center; padding: 20px; border: 2px solid #EF4444; border-radius: 15px; background-color: #450A0A;">
        <h3 style="color: #F87171;">Error</h3>
        <p style="color: #FECACA;">{error}</p>
    </div>
    """
def predict(image):
    if image is None:
        yield format_error_html("No image uploaded"), "", gr.update(variant="secondary")
        return
    yield format_loading_html(), "", gr.update(variant="secondary")
    try:
        pil_image = image.convert("RGB")
        image_tensor = inference_transform(pil_image).unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            output_indices = model.generate(image_tensor)
        result_words = []
        for idx in output_indices:
            word = vocab.itos.get(idx, "<UNK>")
            if word == "<EOS>":
                break
            if word not in ("<SOS>", "<PAD>"):
                result_words.append(word)
        caption = " ".join(result_words)
        yield format_result_html(caption), caption, gr.update(variant="primary")
    except Exception as e:
        yield format_error_html(str(e)), "", gr.update(variant="secondary")
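# JavaScript injected into <head>: toast notifications, clipboard copy, and
# open/close handlers for the API modal.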
js_head = """
<script>
function showToast(message, type) {
    const toast = document.createElement("div");
    toast.className = "toast";
    toast.innerText = message;
    if (type === 'error') {
        toast.style.backgroundColor = "#EF4444"; // Red
        toast.style.color = "#FFFFFF";
    } else {
        toast.style.backgroundColor = "#10B981"; // Green
        toast.style.color = "#FFFFFF";
    }
    document.body.appendChild(toast);
    setTimeout(() => {
        toast.classList.add('hiding');
        toast.addEventListener('animationend', () => toast.remove());
    }, 2500);
}
function copyToClipboard(text) {
    if (!text) {
        showToast("No caption to copy!", "error");
        return;
    }
    navigator.clipboard.writeText(text).then(function() {
        showToast("Caption Copied!", "success");
    }, function(err) {
        showToast("Failed to copy", "error");
        console.error('Async: Could not copy text: ', err);
    });
}
function openModal() {
    const modal = document.getElementById("custom-api-modal");
    modal.classList.remove("hidden");
}
function closeModal() {
    const modal = document.getElementById("custom-api-modal");
    modal.classList.add("hidden");
}
// Close the modal when the backdrop itself is clicked. The backdrop is the
// #custom-api-modal flex container; nothing in this app carries a
// "modal-container" class or a close-modal aria-label, so match on the id
// and call closeModal() directly.
document.addEventListener("click", function(e) {
    if (e.target.id === "custom-api-modal") {
        closeModal();
    }
});
</script>
"""
custom_css = """
body { background-color: #111827; }
@keyframes fadeInUp {
    from { opacity: 0; transform: translateY(20px); }
    to { opacity: 1; transform: translateY(0); }
}
.container {
    max-width: 800px;
    margin: auto;
    padding-top: 20px;
}
.header { text-align: center; margin-bottom: 30px; }
.header h1 { color: #818CF8; font-size: 2.5rem; }
.header p { color: #9CA3AF; }
@keyframes pulse {
    0%, 100% { opacity: 1; }
    50% { opacity: 0.5; }
}
.loading-box { animation: pulse 1.5s cubic-bezier(0.4, 0, 0.6, 1) infinite; }
button {
    transition: all 0.2s ease-in-out !important;
}
button:hover {
    transform: translateY(-2px);
    box-shadow: 0 10px 15px -3px rgba(0, 0, 0, 0.3);
}
button:active {
    transform: translateY(0);
}
@keyframes popIn {
    0% { opacity: 0; transform: scale(0.9); }
    70% { transform: scale(1.02); }
    100% { opacity: 1; transform: scale(1); }
}
.result-animation {
    animation: popIn 0.5s cubic-bezier(0.175, 0.885, 0.32, 1.275) forwards;
}
@keyframes slideUpFadeIn {
    from { opacity: 0; transform: translate(-50%, 100%); }
    to { opacity: 1; transform: translate(-50%, 0); }
}
@keyframes fadeOutSlideDown {
    from { opacity: 1; transform: translate(-50%, 0); }
    to { opacity: 0; transform: translate(-50%, 100%); }
}
.toast {
    position: fixed;
    bottom: 30px;
    left: 50%;
    transform: translate(-50%, 0);
    padding: 12px 24px;
    border-radius: 8px;
    z-index: 10000;
    box-shadow: 0 4px 12px rgba(0,0,0,0.15);
    font-weight: 500;
    animation: slideUpFadeIn 0.5s ease forwards;
}
.toast.hiding {
    animation: fadeOutSlideDown 0.5s ease forwards;
}
#custom-api-modal {
    position: fixed;
    top: 0;
    left: 0;
    width: 100vw;
    height: 100vh;
    background-color: rgba(0,0,0,0.8);
    z-index: 9999;
    backdrop-filter: blur(5px);
    display: flex;
    justify-content: center;
    align-items: center;
    transition: opacity 0.2s ease-in-out;
}
#custom-api-modal.hidden {
    display: none !important;
    opacity: 0;
    pointer-events: none;
}
.custom-modal-content {
    background-color: #1F2937;
    padding: 30px;
    border: 1px solid #374151;
    border-radius: 12px;
    width: 90%;
    max-width: 600px;
    box-shadow: 0 25px 50px -12px rgba(0, 0, 0, 0.5);
    color: #F3F4F6;
    position: relative;
    animation: popIn 0.3s ease-out forwards;
}
/* Styled HTML buttons to look like Gradio */
.custom-btn {
    padding: 8px 16px;
    border-radius: 8px;
    font-weight: 600;
    cursor: pointer;
    border: none;
    transition: background-color 0.2s, transform 0.1s;
}
.custom-btn:active { transform: scale(0.95); }
.btn-close {
    background: transparent;
    color: #9CA3AF;
    font-size: 1.2rem;
    position: absolute;
    top: 20px;
    right: 20px;
}
.btn-close:hover { color: #F3F4F6; }
.btn-primary {
    background-color: #F97316;
    color: white;
    width: 100%;
    margin-top: 20px;
    padding: 10px;
}
.btn-primary:hover {
    background-color: #EA580C;
}
.input-group {
    display: flex;
    gap: 8px;
    background: #111827;
    padding: 8px;
    border-radius: 8px;
    border: 1px solid #374151;
    align-items: center;
    margin-bottom: 15px;
}
.code-text {
    flex-grow: 1;
    font-family: monospace;
    color: #F472B6;
    background: transparent;
    border: none;
    outline: none;
    overflow-x: auto;
    white-space: nowrap;
}
.btn-copy-small {
    background: #374151;
    color: #E5E7EB;
    padding: 6px 12px;
    font-size: 0.85rem;
}
.btn-copy-small:hover {
    background: #4B5563;
}
code {
    background-color: #111827;
    padding: 2px 5px;
    border-radius: 4px;
    color: #F472B6;
    font-family: monospace;
}
pre {
    background-color: #111827;
    padding: 15px;
    border-radius: 8px;
    overflow-x: auto;
    color: #D1D5DB;
    border: 1px solid #374151;
}
"""
modal_html_content = """
<div id="custom-api-modal" class="hidden">
    <div class="custom-modal-content">
        <button class="custom-btn btn-close" onclick="closeModal()">✕</button>
        <h2 style="margin-top:0; color: #818CF8;">Use CogniCaption as API</h2>
        <hr style="border-color: #374151; margin: 15px 0;">
        <p>You can use this Hugging Face Space as an API via the <code>gradio_client</code>.</p>
        <h4>1. API Endpoint</h4>
        <div class="input-group">
            <div class="code-text">https://huggingface.co/spaces/Briran/CogniCaption</div>
            <button class="custom-btn btn-copy-small" onclick="copyToClipboard('https://huggingface.co/spaces/Briran/CogniCaption')">Copy</button>
        </div>
        <h4>2. How to Request</h4>
        <p style="font-size:0.9rem; color:#9CA3AF;">Send an image to the <code>predict</code> endpoint.</p>
        <pre>
from gradio_client import Client
client = Client("Briran/CogniCaption")
result = client.predict(
    image="{INSERT YOUR IMAGE HERE}",
    api_name="/predict"
)
print(result)</pre>
        <button class="custom-btn btn-primary" onclick="closeModal()">OK</button>
    </div>
</div>
"""
# css and head are gr.Blocks constructor arguments; launch() does not accept them.
with gr.Blocks(title="CogniCaption", css=custom_css, head=js_head) as app:
    gr.HTML(modal_html_content)
    with gr.Column(elem_classes=["container"]):
        gr.HTML("""
        <div class="header">
            <h1>CogniCaption</h1>
        </div>
        """)
        with gr.Column():
            input_image = gr.Image(type="pil", label="Upload Image", elem_id="input_image")
            submit_btn = gr.Button("Generate Caption", variant="primary", size="lg")
        gr.HTML("<hr style='border-color: #374151; margin: 30px 0;'>")
        with gr.Column():
            output_display = gr.HTML(label="Result", value=format_initial_html())
            hidden_caption_storage = gr.Textbox(visible=False, elem_id="hidden_caption_output")
            with gr.Row():
                copy_btn = gr.Button("Copy Caption", size="sm", variant="secondary")
                api_open_btn = gr.Button("Use CogniCaption As API", size="sm", variant="secondary")
    submit_btn.click(
        fn=predict,
        inputs=[input_image],
        outputs=[output_display, hidden_caption_storage, copy_btn]
    )
    copy_btn.click(
        fn=None,
        inputs=[hidden_caption_storage],
        outputs=None,
        js="(text) => copyToClipboard(text)"
    )
    api_open_btn.click(
        fn=None,
        inputs=None,
        outputs=None,
        js="openModal"
    )
if __name__ == "__main__":
    app.launch(ssr_mode=False)
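# To run locally (assuming vocab-v4.pkl and vit_lstm_best-v4.pth sit next to
# this file): `python app.py`. The deployed Space can also be queried
# programmatically with gradio_client, as shown in the API modal above.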