SigLIP + MiniLM Unanswerable-Question Classifier
A trained fusion head that combines a frozen SigLIP image embedding with a frozen MiniLM text embedding to predict whether a question is unanswerable from a given document image (DocVQA-style unanswerable question detection).
Update: the fusion head now uses bidirectional cross-attention instead of plain concatenation. On the project's fixed 2,158-sample test split this raised F1 from 0.695 to 0.821 and MCC from 0.404 to 0.645 (see Results below). The previous concat-based head is superseded by this checkpoint.
Files
- siglip_classifier_head.pt: trained state_dict for CrossAttentionHead(image_dim=1152, text_dim=384, proj_dim=256, num_heads=4, hidden_dims=(256, 128)).
- siglip_embeddings.pt: cached SigLIP image / MiniLM text embeddings used during training (for reproducing splits/evaluation only; not required for inference).
Architecture
Both embeddings are linearly projected into a shared 256-d space and treated as length-1
sequences. A text→image attention head lets the question query the document representation,
and a symmetric image→text head lets the document representation query the question, each
followed by a residual connection and LayerNorm (transformer-block style). The two attended
vectors are concatenated (512-d) and passed through a small MLP classifier.
```python
import torch
from torch import nn


class CrossAttentionHead(nn.Module):
    def __init__(self, image_dim=1152, text_dim=384, proj_dim=256, num_heads=4, hidden_dims=(256, 128)):
        super().__init__()
        # Project both modalities into the shared proj_dim space.
        self.image_proj = nn.Linear(image_dim, proj_dim)
        self.text_proj = nn.Linear(text_dim, proj_dim)
        # One attention layer per direction: text queries image, image queries text.
        self.text_to_image_attn = nn.MultiheadAttention(proj_dim, num_heads, batch_first=True)
        self.image_to_text_attn = nn.MultiheadAttention(proj_dim, num_heads, batch_first=True)
        self.norm_text = nn.LayerNorm(proj_dim)
        self.norm_image = nn.LayerNorm(proj_dim)
        h1, h2 = hidden_dims
        # MLP classifier over the concatenated (2 * proj_dim) fused vector.
        self.classifier = nn.Sequential(
            nn.Linear(proj_dim * 2, h1),
            nn.ReLU(),
            nn.Linear(h1, h2),
            nn.ReLU(),
            nn.Linear(h2, 1),
        )

    def forward(self, image_embed, text_embed):
        # (batch, dim) -> (batch, 1, proj_dim): each embedding becomes a length-1 sequence.
        img = self.image_proj(image_embed).unsqueeze(1)
        txt = self.text_proj(text_embed).unsqueeze(1)
        text_attended, _ = self.text_to_image_attn(query=txt, key=img, value=img)
        image_attended, _ = self.image_to_text_attn(query=img, key=txt, value=txt)
        # Residual connection + LayerNorm, then drop the sequence dimension.
        text_fused = self.norm_text(txt + text_attended).squeeze(1)
        image_fused = self.norm_image(img + image_attended).squeeze(1)
        fused = torch.cat([text_fused, image_fused], dim=-1)
        # One logit per example, shape (batch,).
        return self.classifier(fused).squeeze(-1)
```
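To make the length-1 sequence handling concrete, here is a minimal sketch that runs a single nn.MultiheadAttention layer on random stand-in tensors of the projected size (256-d, 4 heads). The batch size and the random inputs are assumptions for illustration only, not values from training. It prints the tensor shapes involved and checks that, with a single key per query, the softmax attention weight is exactly 1.

```python
import torch
from torch import nn

torch.manual_seed(0)

batch_size, proj_dim, num_heads = 2, 256, 4  # batch size is an arbitrary choice

# Random stand-ins for the projected embeddings, each a length-1 sequence.
txt = torch.randn(batch_size, 1, proj_dim)
img = torch.randn(batch_size, 1, proj_dim)

attn = nn.MultiheadAttention(proj_dim, num_heads, batch_first=True)
attended, weights = attn(query=txt, key=img, value=img)

print(attended.shape)  # (batch, 1, proj_dim)
print(weights.shape)   # (batch, 1, 1): one query attending to one key
# Softmax over a single key always yields a weight of 1.
print(torch.allclose(weights, torch.ones_like(weights)))
```

In this length-1 setting, each attention layer therefore reduces to a learned value and output projection of the other modality, which is then mixed with the querying modality through the residual connection and LayerNorm.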
Usage
```python
import torch
from huggingface_hub import hf_hub_download

# CrossAttentionHead is the class defined in the Architecture section.
ckpt = hf_hub_download("giacolees/siglip-doc-understanding-classifier", "siglip_classifier_head.pt")
head = CrossAttentionHead()
head.load_state_dict(torch.load(ckpt, map_location="cpu"))
head.eval()

# image_embed: SigLIP pooled image features (1152-d), via
#   AutoModel.from_pretrained("google/siglip-so400m-patch14-384").get_image_features(...)
# text_embed: MiniLM sentence embedding (384-d), via
#   SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2").encode(question)
with torch.no_grad():
    logit = head(image_embed.unsqueeze(0), text_embed.unsqueeze(0))
    prob_unanswerable = torch.sigmoid(logit).item()
```
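For scoring several question/image pairs at once, the single-example call above generalizes to a batch. The sketch below is one possible helper, not part of the released code: predict_unanswerable and DummyHead are hypothetical names, the 0.5 decision threshold is an assumed default rather than a tuned value, and the random tensors stand in for real SigLIP and MiniLM embeddings.

```python
import torch
from torch import nn


@torch.no_grad()
def predict_unanswerable(head, image_embeds, text_embeds, threshold=0.5):
    """Return (probabilities, boolean labels) for a batch of embedding pairs.

    threshold=0.5 is an assumed default, not a value tuned by the authors.
    """
    head.eval()
    logits = head(image_embeds, text_embeds)  # shape: (batch,)
    probs = torch.sigmoid(logits)
    return probs, probs >= threshold


class DummyHead(nn.Module):
    """Hypothetical stand-in so the sketch runs without the checkpoint."""

    def forward(self, image_embed, text_embed):
        return image_embed.mean(dim=-1) + text_embed.mean(dim=-1)


image_embeds = torch.randn(4, 1152)  # random stand-ins for SigLIP features
text_embeds = torch.randn(4, 384)    # random stand-ins for MiniLM embeddings
probs, labels = predict_unanswerable(DummyHead(), image_embeds, text_embeds)
print(probs.shape, labels.dtype)
```

With the real checkpoint, DummyHead() would be replaced by the loaded CrossAttentionHead, and the embeddings by stacked outputs of the two encoders.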
Encoders used to produce the embeddings:
- Image: google/siglip-so400m-patch14-384
- Text: sentence-transformers/all-MiniLM-L6-v2
Results
Held-out test split (n=2,158), trained/evaluated on the identical embedding cache and train/val/test split as the earlier concat-based head:
| Fusion | Acc | Prec | Rec | F1 | MCC | nlp_entity F1 | element F1 | layout F1 |
|---|---|---|---|---|---|---|---|---|
| Concat (superseded) | 0.702 | 0.711 | 0.678 | 0.695 | 0.404 | 0.738 | 0.509 | 0.336 |
| Cross-attention (this checkpoint) | 0.823 | 0.828 | 0.815 | 0.821 | 0.645 | 0.865 | 0.693 | 0.318 |
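As a reading aid for the table above, the following sketch shows the standard definitions of F1 (from precision and recall) and MCC (from confusion-matrix counts). The precision/recall pairs are copied from the table; the confusion counts are hypothetical and only exercise the MCC formula, since the per-class counts are not given here. The function names are illustrative, not from the project's evaluation code.

```python
import math


def f1_from_precision_recall(precision, recall):
    """Harmonic mean of precision and recall."""
    return 2 * precision * recall / (precision + recall)


def mcc_from_counts(tp, fp, fn, tn):
    """Matthews correlation coefficient from binary confusion counts."""
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return (tp * tn - fp * fn) / denom if denom else 0.0


# Precision/recall pairs copied from the table above.
f1_concat = f1_from_precision_recall(0.711, 0.678)
f1_cross = f1_from_precision_recall(0.828, 0.815)
print(round(f1_concat, 3), round(f1_cross, 3))

# Hypothetical confusion counts, only to exercise the MCC formula.
toy_mcc = mcc_from_counts(tp=8, fp=2, fn=2, tn=8)
print(toy_mcc)
```

The recomputed F1 values agree with the F1 column to within the rounding of the three-decimal precision and recall inputs.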
Trained as part of the multimodalDocumentUnderstanding benchmark
(unanswerable question detection from document images).