import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import gradio as gr
from torchvision import transforms
from PIL import Image
from transformers import CLIPTokenizer
from huggingface_hub import hf_hub_download

# ============================================================================
# CONFIGURATION - HARDCODED FOR I3-CLIP ARCHITECTURE
# ============================================================================
D_MODEL = 768
N_RWKV = 12
N_ATTN = 4
N_HEADS = 12
FFN_MULT = 4
MAX_LEN = 77

# ============================================================================
# 1. RWKV CORE (JIT OPTIMIZED)
# ============================================================================
@torch.jit.script
def rwkv_linear_attention(B: int, T: int, C: int,
                          r: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                          w: torch.Tensor, u: torch.Tensor,
                          state_init: torch.Tensor):
    # Sequential WKV recurrence with the running maximum kept in state_pp
    # for numerical stability of the exponential sums.
    y = torch.zeros_like(v)
    state_aa = torch.zeros(B, C, dtype=torch.float32, device=r.device)
    state_bb = torch.zeros(B, C, dtype=torch.float32, device=r.device)
    state_pp = state_init.clone()
    for t in range(T):
        rt, kt, vt = r[:, t], k[:, t], v[:, t]
        ww = u + state_pp
        p = torch.maximum(ww, kt)
        e1 = torch.exp(ww - p)
        e2 = torch.exp(kt - p)
        wkv = (state_aa * e1 + vt * e2) / (state_bb * e1 + e2 + 1e-6)
        y[:, t] = wkv
        # Decay the state and absorb the current token.
        ww = w + state_pp
        p = torch.maximum(ww, kt)
        e1 = torch.exp(ww - p)
        e2 = torch.exp(kt - p)
        state_aa = state_aa * e1 + vt * e2
        state_bb = state_bb * e1 + e2
        state_pp = p
    return y


class RWKVTimeMix(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        # Per-channel decay and first-token bonus.
        self.time_decay = nn.Parameter(torch.ones(d_model))
        self.time_first = nn.Parameter(torch.ones(d_model))
        # Token-shift interpolation coefficients.
        self.time_mix_k = nn.Parameter(torch.ones(1, 1, d_model))
        self.time_mix_v = nn.Parameter(torch.ones(1, 1, d_model))
        self.time_mix_r = nn.Parameter(torch.ones(1, 1, d_model))
        self.key = nn.Linear(d_model, d_model, bias=False)
        self.value = nn.Linear(d_model, d_model, bias=False)
        self.receptance = nn.Linear(d_model, d_model, bias=False)
        self.output = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x):
        B, T, C = x.size()
        # Token shift: blend each position with the previous one.
        xx = torch.cat([torch.zeros((B, 1, C), device=x.device), x[:, :-1]], dim=1)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        k, v = self.key(xk), self.value(xv)
        r = torch.sigmoid(self.receptance(xr))
        w, u = -torch.exp(self.time_decay), self.time_first
        state_init = torch.full((B, C), -1e30, dtype=torch.float32, device=x.device)
        rwkv = rwkv_linear_attention(B, T, C, r, k, v, w, u, state_init)
        return self.output(r * rwkv)


class RWKVChannelMix(nn.Module):
    def __init__(self, d_model, ffn_mult=4):
        super().__init__()
        self.time_mix_k = nn.Parameter(torch.ones(1, 1, d_model))
        self.time_mix_r = nn.Parameter(torch.ones(1, 1, d_model))
        hidden_sz = d_model * ffn_mult
        self.key = nn.Linear(d_model, hidden_sz, bias=False)
        self.receptance = nn.Linear(d_model, d_model, bias=False)
        self.value = nn.Linear(hidden_sz, d_model, bias=False)

    def forward(self, x):
        B, T, C = x.size()
        xx = torch.cat([torch.zeros((B, 1, C), device=x.device), x[:, :-1]], dim=1)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        # Squared-ReLU feed-forward, gated by the receptance.
        k = torch.square(torch.relu(self.key(xk)))
        return torch.sigmoid(self.receptance(xr)) * self.value(k)


class RWKVBlock(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.att = RWKVTimeMix(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = RWKVChannelMix(d_model)

    def forward(self, x):
        # Pre-norm residual: time-mix then channel-mix.
        x = x + self.att(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x

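# ----------------------------------------------------------------------------
# Illustrative shape check (commented out, not part of the app): a single
# RWKVBlock maps (B, T, C) -> (B, T, C), so blocks can be stacked freely.
# ----------------------------------------------------------------------------
# _blk = RWKVBlock(D_MODEL)
# _out = _blk(torch.randn(2, 8, D_MODEL))
# assert _out.shape == (2, 8, D_MODEL)
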
# ============================================================================
# 2. VISION ENCODER
# ============================================================================
class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion * planes)
        self.downsample = nn.Identity()
        if stride != 1 or in_planes != self.expansion * planes:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.downsample(x)
        out = F.relu(out)
        return out


class VisionEncoderLarge(nn.Module):
    def __init__(self, d_model=768):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2, 1)
        )
        self.layer1 = self._make_layer(64, 64, 3)
        self.layer2 = self._make_layer(256, 128, 4, stride=2)
        self.layer3 = self._make_layer(512, 256, 6, stride=2)
        self.layer4 = self._make_layer(1024, 512, 3, stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(2048, d_model)

    def _make_layer(self, in_planes, planes, blocks, stride=1):
        layers = [Bottleneck(in_planes, planes, stride)]
        for _ in range(1, blocks):
            layers.append(Bottleneck(planes * 4, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return self.fc(self.avgpool(x).flatten(1))


# ============================================================================
# 3. TEXT ENCODER
# ============================================================================
class HybridTextEncoderLarge(nn.Module):
    def __init__(self, vocab_size, d_model=768, n_rwkv=12, n_attn=4, max_len=77):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Parameter(torch.zeros(1, max_len, d_model))
        self.rwkv_layers = nn.ModuleList([RWKVBlock(d_model) for _ in range(n_rwkv)])
        self.attn_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=d_model, nhead=N_HEADS,
                                       dim_feedforward=d_model * 4,
                                       batch_first=True, activation="gelu")
            for _ in range(n_attn)
        ])
        self.ln_final = nn.LayerNorm(d_model)

    def forward(self, x):
        x = self.token_embed(x) + self.pos_embed[:, :x.size(1), :]
        for layer in self.rwkv_layers:
            x = layer(x)
        for layer in self.attn_layers:
            x = layer(x)
        return self.ln_final(x[:, -1, :])

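# ----------------------------------------------------------------------------
# Illustrative shape check (commented out, not part of the app): the text
# encoder takes integer token IDs of shape (B, T) and returns (B, D_MODEL).
# The vocab size 49408 below is the CLIP BPE tokenizer's, used here only as
# an example value.
# ----------------------------------------------------------------------------
# _enc = HybridTextEncoderLarge(vocab_size=49408, d_model=D_MODEL)
# _ids = torch.randint(0, 49408, (2, MAX_LEN))
# assert _enc(_ids).shape == (2, D_MODEL)
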
# ============================================================================
# 4. WRAPPER
# ============================================================================
class i3CLIPHybridLarge(nn.Module):
    def __init__(self, vocab_size, d_model=768):
        super().__init__()
        self.visual = VisionEncoderLarge(d_model=d_model)
        self.textual = HybridTextEncoderLarge(vocab_size=vocab_size, d_model=d_model)
        # Learnable temperature, initialized to 1/0.07 as in CLIP.
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def forward(self, images, texts):
        img_f = F.normalize(self.visual(images), dim=-1)
        txt_f = F.normalize(self.textual(texts), dim=-1)
        scale = self.logit_scale.exp()
        logits = scale * img_f @ txt_f.t()
        return logits, logits.t()


# ============================================================================
# 5. INFERENCE LOGIC
# ============================================================================
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
model = i3CLIPHybridLarge(tokenizer.vocab_size).to(device)

# Download and load the pretrained i3-CLIP checkpoint from the Hugging Face Hub.
print("Downloading and loading model weights...")
checkpoint_path = hf_hub_download(repo_id="i3-lab/i3-CLIP", filename="pytorch_model.bin")
state_dict = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(state_dict, strict=False)
model.eval()

preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.48, 0.45, 0.40), (0.26, 0.26, 0.27))
])


def predict(image, labels_text):
    if image is None:
        return None
    # Split the comma-separated labels, dropping empty entries.
    labels = [l.strip() for l in labels_text.split(",") if l.strip()]

    # Process image
    img_tensor = preprocess(image).unsqueeze(0).to(device)

    # Process text
    txt_tokens = tokenizer(
        labels,
        padding='max_length',
        truncation=True,
        max_length=MAX_LEN,
        return_tensors="pt"
    ).input_ids.to(device)

    with torch.no_grad():
        img_features = F.normalize(model.visual(img_tensor), dim=-1)
        txt_features = F.normalize(model.textual(txt_tokens), dim=-1)
        logits = (img_features @ txt_features.t()) * model.logit_scale.exp()
        probs = F.softmax(logits, dim=-1).cpu().numpy()[0]

    return {labels[i]: float(probs[i]) for i in range(len(labels))}


# ============================================================================
# 6. GRADIO INTERFACE
# ============================================================================
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Textbox(label="Candidate Labels (comma separated)",
                   value="a photo of a cat, a photo of a dog, a landscape")
    ],
    outputs=gr.Label(num_top_classes=5),
    title="i3-CLIP Hybrid RWKV-Transformer Large",
    description="This space uses the i3-CLIP architecture: a ResNet-like Bottleneck Vision Encoder "
                "and a Hybrid RWKV-Transformer Text Encoder. Weights are loaded from FlameF0X/i3-CLIP."
)

if __name__ == "__main__":
    demo.launch()
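
# ----------------------------------------------------------------------------
# Optional local sanity check, separate from the Gradio app: a minimal sketch
# of calling `predict` directly. "example.jpg" is a placeholder filename, not
# a file shipped with this Space. Uncomment to try it.
# ----------------------------------------------------------------------------
# img = Image.open("example.jpg").convert("RGB")
# scores = predict(img, "a photo of a cat, a photo of a dog, a landscape")
# print(scores)  # prints a {label: probability} dict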