# CogniCaption / app.py
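"""Gradio app for CogniCaption: an image-captioning demo that pairs a frozen
ViT-MAE encoder with a multi-head-attention LSTM decoder, and serves the
result through a custom-styled Blocks UI with a copy-to-clipboard toast and
an API-usage modal."""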
import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from transformers import ViTModel
from PIL import Image
import pickle
import re
import os
class Vocabulary:
    def __init__(self, freq_threshold=5):
        self.freq_threshold = freq_threshold
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {v: k for k, v in self.itos.items()}
        self.index = 4

    def __len__(self):
        return len(self.itos)

    def tokenizer(self, text):
        text = text.lower()
        tokens = re.findall(r"\w+", text)
        return tokens

    def numericalize(self, text):
        tokens = self.tokenizer(text)
        numericalized = []
        for token in tokens:
            if token in self.stoi:
                numericalized.append(self.stoi[token])
            else:
                numericalized.append(self.stoi["<UNK>"])
        return numericalized
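# A minimal usage sketch (hypothetical sentence; whether "dog" maps to a real
# index depends on the pickled vocabulary):
#
#   v = Vocabulary()
#   v.tokenizer("A dog runs.")      # -> ["a", "dog", "runs"]
#   v.numericalize("A dog runs.")   # -> one index per token, <UNK> (3) for OOV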
class Encoder(nn.Module):
    def __init__(self, embed_dim, freeze=False):
        super().__init__()
        self.vit = ViTModel.from_pretrained("facebook/vit-mae-base")
        if freeze:
            for param in self.vit.parameters():
                param.requires_grad = False
        self.linear = nn.Sequential(
            nn.Linear(self.vit.config.hidden_size, embed_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(embed_dim, embed_dim),
            nn.LayerNorm(embed_dim)
        )

    def forward(self, images):
        outputs = self.vit(pixel_values=images)
        patch_embeddings = outputs.last_hidden_state[:, 1:, :]
        features = self.linear(patch_embeddings)
        return features
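# Note: the slice `[:, 1:, :]` drops the [CLS] token and keeps one feature per
# image patch, so for 224x224 inputs the encoder returns (B, 196, embed_dim),
# assuming the model's default 16x16 patch size (14 * 14 = 196 patches).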
class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_dim, encoder_dim, num_heads=4):
        super().__init__()
        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.head_dim = hidden_dim // num_heads
        self.query = nn.Linear(hidden_dim, hidden_dim)
        self.key = nn.Linear(encoder_dim, hidden_dim)
        self.value = nn.Linear(encoder_dim, hidden_dim)
        self.fc_out = nn.Linear(hidden_dim, encoder_dim)

    def forward(self, hidden, encoder_outputs):
        B, N, _ = encoder_outputs.shape
        Q = self.query(hidden).view(B, self.num_heads, self.head_dim)
        K = self.key(encoder_outputs).view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.value(encoder_outputs).view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        scores = torch.matmul(Q.unsqueeze(2), K.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn = torch.softmax(scores, dim=-1)
        context = torch.matmul(attn, V)
        context = context.transpose(1, 2).contiguous().view(B, self.hidden_dim)
        return self.fc_out(context)
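# Shape walk-through of the single-query attention above (one decoder step):
#   hidden:          (B, hidden_dim)      -> Q:       (B, heads, 1, head_dim)
#   encoder_outputs: (B, N, encoder_dim)  -> K, V:    (B, heads, N, head_dim)
#   scores / attn:   (B, heads, 1, N)     -> context: (B, hidden_dim)
# The final fc_out projects the pooled context back to encoder_dim so it can
# be concatenated with the word embedding as LSTM input.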
class Decoder(nn.Module):
    def __init__(self, embed_dim, hidden_dim, vocab_size, encoder_dim=256, num_layers=2, dropout=0.3, num_heads=4):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.lstm = nn.LSTM(embed_dim + encoder_dim, hidden_dim, num_layers, batch_first=True,
                            dropout=dropout if num_layers > 1 else 0)
        self.attention = MultiHeadAttention(hidden_dim, encoder_dim, num_heads=num_heads)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def generate(self, features, max_len=50, start_index=1, end_index=2, beam_size=3, beam_search=True):
        # Decoding below assumes a batch of one image (the app always sends B == 1).
        B = features.size(0)
        device = features.device
        states = (torch.zeros(self.lstm.num_layers, B, self.lstm.hidden_size, device=device),
                  torch.zeros(self.lstm.num_layers, B, self.lstm.hidden_size, device=device))
        if not beam_search:
            # Greedy decoding: always take the argmax token.
            generated = []
            current_token = torch.LongTensor([start_index]).to(device).unsqueeze(0)
            for _ in range(max_len):
                emb = self.embedding(current_token).squeeze(1)
                context = self.attention(states[0][-1], features)
                lstm_input = torch.cat((emb, context), dim=1).unsqueeze(1)
                out, states = self.lstm(lstm_input, states)
                logits = self.fc(out.squeeze(1))
                predicted = logits.argmax(dim=1).item()
                generated.append(predicted)
                if predicted == end_index:
                    break
                current_token = torch.LongTensor([predicted]).to(device).unsqueeze(0)
            return generated
        else:
            # Start from a single beam; seeding beam_size identical copies would
            # produce duplicate candidates on every step and degenerate to greedy.
            beams = [([start_index], 0.0, states)]
            for _ in range(max_len):
                new_beams = []
                for seq, log_prob, (h, c) in beams:
                    # Carry finished beams forward unchanged instead of
                    # extending them past <EOS>.
                    if seq[-1] == end_index:
                        new_beams.append((seq, log_prob, (h, c)))
                        continue
                    current_token = torch.LongTensor([seq[-1]]).to(device).unsqueeze(0)
                    emb = self.embedding(current_token).squeeze(1)
                    context = self.attention(h[-1], features)
                    lstm_input = torch.cat((emb, context), dim=1).unsqueeze(1)
                    out, (h_new, c_new) = self.lstm(lstm_input, (h, c))
                    logits = self.fc(out.squeeze(1))
                    log_probs = torch.log_softmax(logits, dim=1)
                    top_log_probs, top_indices = log_probs.topk(beam_size, dim=1)
                    for k in range(beam_size):
                        next_seq = seq + [top_indices[0, k].item()]
                        next_log_prob = log_prob + top_log_probs[0, k].item()
                        new_beams.append((next_seq, next_log_prob, (h_new, c_new)))
                new_beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:beam_size]
                beams = new_beams
                if all(seq[-1] == end_index for seq, _, _ in beams):
                    break
            best_seq = beams[0][0]
            if best_seq[0] == start_index:
                best_seq = best_seq[1:]
            return best_seq
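# The beam score above is a raw sum of log-probs, which biases the search
# toward shorter captions; a common refinement (not applied here) is length
# normalization, e.g. sorting candidates on log_prob / len(seq) instead.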
class Model(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def generate(self, images, max_len=50):
        features = self.encoder(images)
        captions = self.decoder.generate(features, max_len=max_len, beam_search=True)
        return captions
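# Minimal inference sketch (assumes a loaded model and a preprocessed
# (1, 3, 224, 224) image tensor; indices map back to words via vocab.itos):
#
#   with torch.no_grad():
#       indices = model.generate(image_tensor)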
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EMBED_DIM = 256
HIDDEN_DIM = 512
VOCAB_PATH = "vocab-v4.pkl"
MODEL_PATH = "vit_lstm_best-v4.pth"
vocab = None
model = None
inference_transform = None
def load_system():
    global vocab, model, inference_transform
    print("Loading Vocabulary...")
    try:
        with open(VOCAB_PATH, "rb") as f:
            vocab = pickle.load(f)
    except Exception as e:
        return f"Error loading vocab: {e}"
    print("Initializing Model...")
    encoder = Encoder(EMBED_DIM, freeze=True)
    decoder = Decoder(EMBED_DIM, HIDDEN_DIM, len(vocab))
    model = Model(encoder, decoder).to(DEVICE)
    print("Loading Weights...")
    try:
        checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
        else:
            model.load_state_dict(checkpoint)
        model.eval()
    except Exception as e:
        return f"Error loading model weights: {e}"
    inference_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    return "System Loaded"
load_status = load_system()
print(load_status)
def format_loading_html():
    return """
    <div class="loading-box" style="
        text-align: center;
        padding: 40px;
        border: 2px solid #6B7280;
        border-radius: 15px;
        background-color: #27272A;
    ">
        <h3 style="color: #6B7280; margin-bottom: 5px;">Status</h3>
        <h2 style="color: #F3F4F6; font-size: 24px; margin: 10px 0;">Analyzing Image...</h2>
        <div style="font-size: 48px; font-weight: bold; color: #6B7280;">
            ...
        </div>
        <p style="color: #9CA3AF; font-weight: bold; margin-top: 5px;">Please wait...</p>
    </div>
    """
def format_result_html(caption):
    return f"""
    <div class="result-animation" style="
        text-align: center;
        padding: 30px;
        border: 2px solid #4F46E5;
        border-radius: 15px;
        background-color: #27272A;
        box-shadow: 0 4px 6px -1px rgba(79, 70, 229, 0.1);
    ">
        <h3 style="color: #818CF8; margin-bottom: 10px; text-transform: uppercase; letter-spacing: 2px;">Generated Caption</h3>
        <div style="
            font-size: 28px;
            font-weight: bold;
            color: #F9FAFB;
            margin: 20px 0;
            line-height: 1.4;
        ">
            "{caption}"
        </div>
    </div>
    """
def format_initial_html():
    return """
    <div style="
        text-align: center;
        padding: 40px;
        border: 2px dashed #4B5563;
        border-radius: 15px;
        background-color: #27272A;
        color: #9CA3AF;
    ">
        <h3>Output Area</h3>
        <p>Your generated caption will appear here.</p>
    </div>
    """
def format_error_html(error):
    return f"""
    <div style="text-align: center; padding: 20px; border: 2px solid #EF4444; border-radius: 15px; background-color: #450A0A;">
        <h3 style="color: #F87171;">Error</h3>
        <p style="color: #FECACA;">{error}</p>
    </div>
    """
def predict(image):
    if image is None:
        yield format_error_html("No image uploaded"), "", gr.update(variant="secondary")
        return
    yield format_loading_html(), "", gr.update(variant="secondary")
    try:
        pil_image = image.convert("RGB")
        image_tensor = inference_transform(pil_image).unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            output_indices = model.generate(image_tensor)
        result_words = []
        for idx in output_indices:
            word = vocab.itos.get(idx, "<UNK>")
            if word == "<EOS>":
                break
            if word not in ("<SOS>", "<PAD>"):
                result_words.append(word)
        caption = " ".join(result_words)
        yield format_result_html(caption), caption, gr.update(variant="primary")
    except Exception as e:
        yield format_error_html(str(e)), "", gr.update(variant="secondary")
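# predict() is a generator on purpose: the first yield pushes the pulsing
# "Analyzing Image..." card to the UI immediately, and the second yield
# replaces it with the final caption (or an error card) once decoding
# finishes, so Gradio streams both states to the browser.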
js_head = """
<script>
function showToast(message, type) {
    const toast = document.createElement("div");
    toast.className = "toast";
    toast.innerText = message;
    if (type === 'error') {
        toast.style.backgroundColor = "#EF4444"; // Red
        toast.style.color = "#FFFFFF";
    } else {
        toast.style.backgroundColor = "#10B981"; // Green
        toast.style.color = "#FFFFFF";
    }
    document.body.appendChild(toast);
    setTimeout(() => {
        toast.classList.add('hiding');
        toast.addEventListener('animationend', () => toast.remove());
    }, 2500);
}
function copyToClipboard(text) {
    if (!text) {
        showToast("No caption to copy!", "error");
        return;
    }
    navigator.clipboard.writeText(text).then(function() {
        showToast("Caption Copied!", "success");
    }, function(err) {
        showToast("Failed to copy", "error");
        console.error('Async: Could not copy text: ', err);
    });
}
function openModal() {
    const modal = document.getElementById("custom-api-modal");
    modal.classList.remove("hidden");
}
function closeModal() {
    const modal = document.getElementById("custom-api-modal");
    modal.classList.add("hidden");
}
// Close the modal when the darkened backdrop (the #custom-api-modal element
// itself, not its content) is clicked.
document.addEventListener("click", function(e) {
    if (e.target.id === "custom-api-modal") {
        closeModal();
    }
});
</script>
"""
custom_css = """
body { background-color: #111827; }
@keyframes fadeInUp {
    from { opacity: 0; transform: translateY(20px); }
    to { opacity: 1; transform: translateY(0); }
}
.container {
    max-width: 800px;
    margin: auto;
    padding-top: 20px;
}
.header { text-align: center; margin-bottom: 30px; }
.header h1 { color: #818CF8; font-size: 2.5rem; }
.header p { color: #9CA3AF; }
@keyframes pulse {
    0%, 100% { opacity: 1; }
    50% { opacity: 0.5; }
}
.loading-box { animation: pulse 1.5s cubic-bezier(0.4, 0, 0.6, 1) infinite; }
button {
    transition: all 0.2s ease-in-out !important;
}
button:hover {
    transform: translateY(-2px);
    box-shadow: 0 10px 15px -3px rgba(0, 0, 0, 0.3);
}
button:active {
    transform: translateY(0);
}
@keyframes popIn {
    0% { opacity: 0; transform: scale(0.9); }
    70% { transform: scale(1.02); }
    100% { opacity: 1; transform: scale(1); }
}
.result-animation {
    animation: popIn 0.5s cubic-bezier(0.175, 0.885, 0.32, 1.275) forwards;
}
@keyframes slideUpFadeIn {
    from { opacity: 0; transform: translate(-50%, 100%); }
    to { opacity: 1; transform: translate(-50%, 0); }
}
@keyframes fadeOutSlideDown {
    from { opacity: 1; transform: translate(-50%, 0); }
    to { opacity: 0; transform: translate(-50%, 100%); }
}
.toast {
    position: fixed;
    bottom: 30px;
    left: 50%;
    transform: translate(-50%, 0);
    padding: 12px 24px;
    border-radius: 8px;
    z-index: 10000;
    box-shadow: 0 4px 12px rgba(0,0,0,0.15);
    font-weight: 500;
    animation: slideUpFadeIn 0.5s ease forwards;
}
.toast.hiding {
    animation: fadeOutSlideDown 0.5s ease forwards;
}
#custom-api-modal {
    position: fixed;
    top: 0;
    left: 0;
    width: 100vw;
    height: 100vh;
    background-color: rgba(0,0,0,0.8);
    z-index: 9999;
    backdrop-filter: blur(5px);
    display: flex;
    justify-content: center;
    align-items: center;
    transition: opacity 0.2s ease-in-out;
}
#custom-api-modal.hidden {
    display: none !important;
    opacity: 0;
    pointer-events: none;
}
.custom-modal-content {
    background-color: #1F2937;
    padding: 30px;
    border: 1px solid #374151;
    border-radius: 12px;
    width: 90%;
    max-width: 600px;
    box-shadow: 0 25px 50px -12px rgba(0, 0, 0, 0.5);
    color: #F3F4F6;
    position: relative;
    animation: popIn 0.3s ease-out forwards;
}
/* Plain HTML buttons styled to match Gradio's */
.custom-btn {
    padding: 8px 16px;
    border-radius: 8px;
    font-weight: 600;
    cursor: pointer;
    border: none;
    transition: background-color 0.2s, transform 0.1s;
}
.custom-btn:active { transform: scale(0.95); }
.btn-close {
    background: transparent;
    color: #9CA3AF;
    font-size: 1.2rem;
    position: absolute;
    top: 20px;
    right: 20px;
}
.btn-close:hover { color: #F3F4F6; }
.btn-primary {
    background-color: #F97316;
    color: white;
    width: 100%;
    margin-top: 20px;
    padding: 10px;
}
.btn-primary:hover {
    background-color: #EA580C;
}
.input-group {
    display: flex;
    gap: 8px;
    background: #111827;
    padding: 8px;
    border-radius: 8px;
    border: 1px solid #374151;
    align-items: center;
    margin-bottom: 15px;
}
.code-text {
    flex-grow: 1;
    font-family: monospace;
    color: #F472B6;
    background: transparent;
    border: none;
    outline: none;
    overflow-x: auto;
    white-space: nowrap;
}
.btn-copy-small {
    background: #374151;
    color: #E5E7EB;
    padding: 6px 12px;
    font-size: 0.85rem;
}
.btn-copy-small:hover {
    background: #4B5563;
}
code {
    background-color: #111827;
    padding: 2px 5px;
    border-radius: 4px;
    color: #F472B6;
    font-family: monospace;
}
pre {
    background-color: #111827;
    padding: 15px;
    border-radius: 8px;
    overflow-x: auto;
    color: #D1D5DB;
    border: 1px solid #374151;
}
"""
modal_html_content = """
<div id="custom-api-modal" class="hidden">
    <div class="custom-modal-content">
        <button class="custom-btn btn-close" onclick="closeModal()">✕</button>
        <h2 style="margin-top:0; color: #818CF8;">Use CogniCaption as API</h2>
        <hr style="border-color: #374151; margin: 15px 0;">
        <p>You can use this Hugging Face Space as an API via the <code>gradio_client</code> Python package.</p>
        <h4>1. API Endpoint</h4>
        <div class="input-group">
            <div class="code-text">https://huggingface.co/spaces/Briran/CogniCaption</div>
            <button class="custom-btn btn-copy-small" onclick="copyToClipboard('https://huggingface.co/spaces/Briran/CogniCaption')">Copy</button>
        </div>
        <h4>2. How to Request</h4>
        <p style="font-size:0.9rem; color:#9CA3AF;">Send an image to the <code>predict</code> endpoint.</p>
        <pre>
from gradio_client import Client
client = Client("Briran/CogniCaption")
result = client.predict(
    image="{INSERT YOUR IMAGE HERE}",
    api_name="/predict"
)
print(result)</pre>
        <button class="custom-btn btn-primary" onclick="closeModal()">OK</button>
    </div>
</div>
"""
# css and head are gr.Blocks() arguments, not launch() arguments.
with gr.Blocks(title="CogniCaption", css=custom_css, head=js_head) as app:
    gr.HTML(modal_html_content)
    with gr.Column(elem_classes=["container"]):
        gr.HTML("""
        <div class="header">
            <h1>CogniCaption</h1>
        </div>
        """)
        with gr.Column():
            input_image = gr.Image(type="pil", label="Upload Image", elem_id="input_image")
            submit_btn = gr.Button("Generate Caption", variant="primary", size="lg")
        gr.HTML("<hr style='border-color: #374151; margin: 30px 0;'>")
        with gr.Column():
            output_display = gr.HTML(label="Result", value=format_initial_html())
            hidden_caption_storage = gr.Textbox(visible=False, elem_id="hidden_caption_output")
            with gr.Row():
                copy_btn = gr.Button("Copy Caption", size="sm", variant="secondary")
                api_open_btn = gr.Button("Use CogniCaption As API", size="sm", variant="secondary")

    submit_btn.click(
        fn=predict,
        inputs=[input_image],
        outputs=[output_display, hidden_caption_storage, copy_btn]
    )
    copy_btn.click(
        fn=None,
        inputs=[hidden_caption_storage],
        outputs=None,
        js="(text) => copyToClipboard(text)"
    )
    api_open_btn.click(
        fn=None,
        inputs=None,
        outputs=None,
        js="openModal"
    )

if __name__ == "__main__":
    app.launch(ssr_mode=False)
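# To try this locally (assuming torch, torchvision, transformers, gradio, and
# Pillow are installed, and vocab-v4.pkl / vit_lstm_best-v4.pth sit next to
# this file):
#
#   python app.py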