SigLIP + MiniLM Unanswerable-Question Classifier

A trained fusion head that combines a frozen SigLIP image embedding with a frozen MiniLM text embedding to predict whether a question is unanswerable from a given document image (DocVQA-style unanswerable question detection).

Update: the fusion head now uses bidirectional cross-attention instead of plain concatenation. On the project's fixed 2,158-sample test split this raised F1 from 0.695 to 0.821 and MCC from 0.404 to 0.645 (see Results below). The previous concat-based head is superseded by this checkpoint.

Files

  • siglip_classifier_head.pt — trained state_dict for CrossAttentionHead (image_dim=1152, text_dim=384, proj_dim=256, num_heads=4, hidden_dims=(256, 128)).
  • siglip_embeddings.pt — cached SigLIP image / MiniLM text embeddings used during training (for reproducing splits/evaluation only; not required for inference).

Architecture

Both embeddings are linearly projected into a shared 256-d space and treated as length-1 sequences. A text→image multi-head attention layer (4 heads) lets the question query the document representation, and a symmetric image→text layer lets the document representation query the question. Each attention output is added back to its query through a residual connection and normalized with LayerNorm (transformer-block style). The two resulting vectors are concatenated (512-d) and passed through a small MLP classifier that outputs a single logit.

import torch
from torch import nn

class CrossAttentionHead(nn.Module):
    def __init__(self, image_dim=1152, text_dim=384, proj_dim=256, num_heads=4, hidden_dims=(256, 128)):
        super().__init__()
        self.image_proj = nn.Linear(image_dim, proj_dim)
        self.text_proj = nn.Linear(text_dim, proj_dim)
        self.text_to_image_attn = nn.MultiheadAttention(proj_dim, num_heads, batch_first=True)
        self.image_to_text_attn = nn.MultiheadAttention(proj_dim, num_heads, batch_first=True)
        self.norm_text = nn.LayerNorm(proj_dim)
        self.norm_image = nn.LayerNorm(proj_dim)
        h1, h2 = hidden_dims
        self.classifier = nn.Sequential(
            nn.Linear(proj_dim * 2, h1),
            nn.ReLU(),
            nn.Linear(h1, h2),
            nn.ReLU(),
            nn.Linear(h2, 1),
        )

    def forward(self, image_embed, text_embed):
        img = self.image_proj(image_embed).unsqueeze(1)
        txt = self.text_proj(text_embed).unsqueeze(1)

        text_attended, _ = self.text_to_image_attn(query=txt, key=img, value=img)
        image_attended, _ = self.image_to_text_attn(query=img, key=txt, value=txt)

        text_fused = self.norm_text(txt + text_attended).squeeze(1)
        image_fused = self.norm_image(img + image_attended).squeeze(1)

        fused = torch.cat([text_fused, image_fused], dim=-1)
        return self.classifier(fused).squeeze(-1)
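
One consequence of the length-1 sequences is worth keeping in mind when reading the forward pass: with a single key/value token, the softmax over keys has only one entry, so every attention weight is 1 and the attended vector does not depend on the query. The following minimal sketch checks this on an nn.MultiheadAttention layer configured like the ones above. The tensors are random stand-ins, not real embeddings, and the variable names are illustrative only.

```python
import torch
from torch import nn

torch.manual_seed(0)

proj_dim, num_heads, batch = 256, 4, 3
attn = nn.MultiheadAttention(proj_dim, num_heads, batch_first=True).eval()

# Random stand-ins for the projected embeddings, shaped (batch, seq_len=1, proj_dim).
img = torch.randn(batch, 1, proj_dim)
txt_a = torch.randn(batch, 1, proj_dim)
txt_b = torch.randn(batch, 1, proj_dim)  # a different "question"

with torch.no_grad():
    out_a, weights_a = attn(query=txt_a, key=img, value=img)
    out_b, _ = attn(query=txt_b, key=img, value=img)

# Softmax over a single key: every weight is 1 (weights are averaged over heads).
weights_all_one = torch.allclose(weights_a, torch.ones_like(weights_a))
# The attended output is therefore the same whichever query is used.
query_independent = torch.allclose(out_a, out_b, atol=1e-6)

print(out_a.shape, weights_a.shape)
print(weights_all_one, query_independent)
```

In this configuration each attention branch therefore acts as a learned linear map of the other modality's projection, which the residual connection then adds to the query vector. Query-dependent weighting would only appear with longer sequences, for example patch-level image tokens.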

Usage

import torch
from huggingface_hub import hf_hub_download

ckpt = hf_hub_download("giacolees/siglip-doc-understanding-classifier", "siglip_classifier_head.pt")
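# CrossAttentionHead is the class defined in the Architecture section above.
head = CrossAttentionHead()
# weights_only=True restricts unpickling to tensors and plain containers, which is all a state_dict needs.
head.load_state_dict(torch.load(ckpt, map_location="cpu", weights_only=True))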
head.eval()

# image_embed: 1-D SigLIP pooled image feature of shape (1152,), e.g. row 0 of
#   AutoModel.from_pretrained("google/siglip-so400m-patch14-384").get_image_features(...)
# text_embed: 1-D MiniLM sentence embedding of shape (384,), via
#   SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2").encode(question, convert_to_tensor=True)
#   (encode returns a NumPy array unless convert_to_tensor=True is passed)
with torch.no_grad():
    logit = head(image_embed.unsqueeze(0), text_embed.unsqueeze(0))
    prob_unanswerable = torch.sigmoid(logit).item()
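
The probability still has to be turned into a decision. The sketch below applies a cut-off to a small batch of logits. The 0.5 threshold and the helper name decide_unanswerable are assumptions for illustration, since this card does not state which operating point produced the reported metrics.

```python
import torch


def decide_unanswerable(logits: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
    """Map raw head logits to boolean 'unanswerable' decisions.

    The default threshold is a hypothetical choice; tune it on a validation split.
    """
    probs = torch.sigmoid(logits)
    return probs >= threshold


# Made-up logits, shaped like the head's output for three (image, question) pairs.
example_logits = torch.tensor([-2.0, 0.3, 3.0])

decisions = decide_unanswerable(example_logits)
strict_decisions = decide_unanswerable(example_logits, threshold=0.9)

print(decisions.tolist())
print(strict_decisions.tolist())
```

Raising the threshold trades recall on unanswerable questions for precision, which may matter if wrongly refusing an answerable question is the costlier error.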

Encoders used to produce the embeddings:

  • Image: google/siglip-so400m-patch14-384
  • Text: sentence-transformers/all-MiniLM-L6-v2
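
A possible way to produce the two input vectors with these encoders is sketched below. The helper names embed_image and embed_question are hypothetical. The sketch also assumes the training cache held the raw pooled SigLIP output and the default sentence-transformers output with no extra preprocessing, which this card does not confirm.

```python
import torch

SIGLIP_ID = "google/siglip-so400m-patch14-384"
MINILM_ID = "sentence-transformers/all-MiniLM-L6-v2"


def embed_image(image_path: str) -> torch.Tensor:
    """Return a 1-D SigLIP image embedding (expected 1152-d) for one image file."""
    # Lazy imports: the heavy dependencies are only needed when this is called.
    from PIL import Image
    from transformers import AutoImageProcessor, AutoModel

    # Loaded on every call for brevity; cache these objects in real use.
    processor = AutoImageProcessor.from_pretrained(SIGLIP_ID)
    model = AutoModel.from_pretrained(SIGLIP_ID).eval()

    image = Image.open(image_path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        out = model.get_image_features(pixel_values=inputs["pixel_values"])
    # Accept either a plain tensor or an output object exposing pooler_output.
    feats = getattr(out, "pooler_output", out)
    return feats[0].float().cpu()  # drop the batch dimension


def embed_question(question: str) -> torch.Tensor:
    """Return a 1-D MiniLM sentence embedding (expected 384-d) for one question."""
    from sentence_transformers import SentenceTransformer

    encoder = SentenceTransformer(MINILM_ID)
    # convert_to_tensor=True yields a torch.Tensor instead of a NumPy array.
    return encoder.encode(question, convert_to_tensor=True).float().cpu()
```

Both helpers return unbatched vectors, so they can be passed to the head as head(embed_image(path).unsqueeze(0), embed_question(question).unsqueeze(0)).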

Results

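Held-out test split (n=2,158). Both heads were trained and evaluated on the same embedding cache and the same train/val/test split, so the rows below are directly comparable. Cross-attention improves every reported metric except layout F1, which dips slightly (0.336 to 0.318):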

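| Fusion                            | Acc   | Prec  | Rec   | F1    | MCC   | nlp_entity F1 | element F1 | layout F1 |
|-----------------------------------|-------|-------|-------|-------|-------|---------------|------------|-----------|
| Concat (superseded)               | 0.702 | 0.711 | 0.678 | 0.695 | 0.404 | 0.738         | 0.509      | 0.336     |
| Cross-attention (this checkpoint) | 0.823 | 0.828 | 0.815 | 0.821 | 0.645 | 0.865         | 0.693      | 0.318     |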
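
For readers less familiar with these metrics, the sketch below shows how Acc, Prec, Rec, F1, and MCC follow from binary confusion counts. It assumes "unanswerable" is the positive class and that the reported scores are plain binary (not macro-averaged) metrics, neither of which this card states. The counts are toy values chosen for easy arithmetic, not the project's confusion matrix, and binary_metrics is a hypothetical helper name.

```python
import math


def binary_metrics(tp: int, fp: int, fn: int, tn: int) -> dict:
    """Compute standard binary classification metrics from confusion counts.

    No zero-division guards: every class and prediction must occur at least once.
    """
    total = tp + fp + fn + tn
    acc = (tp + tn) / total
    prec = tp / (tp + fp)
    rec = tp / (tp + fn)
    f1 = 2 * prec * rec / (prec + rec)
    # Matthews correlation coefficient: uses all four cells of the matrix.
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    mcc = (tp * tn - fp * fn) / denom
    return {"acc": acc, "prec": prec, "rec": rec, "f1": f1, "mcc": mcc}


# Toy counts (NOT taken from this model's evaluation).
toy = binary_metrics(tp=8, fp=2, fn=2, tn=8)
print({name: round(value, 3) for name, value in toy.items()})
```

MCC is reported alongside F1 because it also accounts for true negatives, so it tends to be the more conservative of the two.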

Trained as part of the multimodalDocumentUnderstanding benchmark (unanswerable question detection from document images).
