PaliGemma Checkbox Classifier (Head Fine-tune)

This repository contains a lightweight classifier head trained on top of frozen PaliGemma-3B vision embeddings to classify the state of a checkbox in an image:

  • unchecked
  • checked

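Note: The PaliGemma backbone (google/paligemma-3b-mix-224) is loaded from Hugging Face at runtime.
It is a gated model, so accept its license on its Hugging Face model page and authenticate (for example, by running huggingface-cli login) before the first run.
This repo stores only the trained classifier head and configs.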


What’s inside

Model artifacts

  • classifier_head.pt — trained PyTorch classifier head weights
  • model_config.json — inference configuration (backbone + head input dim + labels)
  • training_config.json — training hyperparameters and setup
  • model_info.json — short metadata about the model
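The inference script reads only two fields from model_config.json: backbone (the Hugging Face model ID) and classifier_architecture.input_dim (the head's input size). A minimal sketch for validating a config before loading anything heavy might look like this. The helper check_model_config is hypothetical, and the example values (including input_dim = 2048) are placeholders, not the contents of this repo's file.

```python
import json
from typing import Any


def check_model_config(cfg: dict[str, Any]) -> list[str]:
    """Hypothetical helper: list the fields the inference script needs but cannot use."""
    problems: list[str] = []
    if not isinstance(cfg.get("backbone"), str) or not cfg["backbone"]:
        problems.append("missing or empty 'backbone' (Hugging Face model ID)")
    arch = cfg.get("classifier_architecture")
    if not isinstance(arch, dict):
        problems.append("missing 'classifier_architecture' section")
    else:
        dim = arch.get("input_dim")
        # bool is a subclass of int, so reject it explicitly
        if not isinstance(dim, int) or isinstance(dim, bool) or dim <= 0:
            problems.append("'classifier_architecture.input_dim' must be a positive integer")
    return problems


# Placeholder config for illustration; load the real one with json.load()
example = json.loads(
    '{"backbone": "google/paligemma-3b-mix-224",'
    ' "classifier_architecture": {"input_dim": 2048}}'
)
print(check_model_config(example))  # prints []
print(check_model_config({"backbone": ""}))  # two problems reported
```

An empty list means the config has everything the script indexes into; anything else points at the field that would otherwise raise a KeyError or build a mis-sized head.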

How to use (inference)

1) Install dependencies

pip install torch transformers pillow
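The inference code imports PaliGemmaModel, the base PaliGemma class without a language-modeling head. To my knowledge this class is only exported by recent transformers releases (older ones ship PaliGemmaForConditionalGeneration instead), so a quick check after installing can save a confusing ImportError. A small sketch, assuming only that the import either succeeds or fails:

```python
def has_paligemma_base_model() -> bool:
    """Return True if the installed transformers exports PaliGemmaModel."""
    try:
        from transformers import PaliGemmaModel  # noqa: F401
    except Exception:
        # Covers transformers not installed, an older release without the class,
        # or a lazy-import failure (e.g. a missing torch dependency)
        return False
    return True


if not has_paligemma_base_model():
    print("PaliGemmaModel not available; try: pip install --upgrade transformers")
```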

2) Run inference

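import json

import torch
import torch.nn as nn
from PIL import Image
from transformers import AutoProcessor, PaliGemmaModel  # needs a recent transformers release

DEVICE = "cpu"
MODEL_DIR = "."  # folder containing classifier_head.pt + model_config.json

# Index -> label mapping used during training (see "Labels" below)
LABELS = ["unchecked", "checked"]

with open(f"{MODEL_DIR}/model_config.json", "r") as f:
    cfg = json.load(f)

# Load the frozen backbone and its processor
processor = AutoProcessor.from_pretrained(cfg["backbone"])
backbone = PaliGemmaModel.from_pretrained(cfg["backbone"]).to(DEVICE)
backbone.eval()

# Rebuild the classifier head; the layer layout must match training exactly,
# otherwise load_state_dict() fails on missing or unexpected keys
classifier = nn.Sequential(
    nn.Linear(cfg["classifier_architecture"]["input_dim"], 256),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(256, len(LABELS)),
).to(DEVICE)

# weights_only=True restricts unpickling to tensors, which is all a state_dict needs
state_dict = torch.load(
    f"{MODEL_DIR}/classifier_head.pt", map_location=DEVICE, weights_only=True
)
classifier.load_state_dict(state_dict)
classifier.eval()  # disables Dropout at inference time

@torch.no_grad()
def extract_features(image: Image.Image) -> torch.Tensor:
    # Only the image is passed to the processor; no text prompt is used
    inputs = processor(images=image, return_tensors="pt").to(DEVICE)
    outputs = backbone(**inputs)
    # Mean-pool the final hidden states over the sequence dimension -> (1, hidden_dim),
    # cast to float32 so the dtype matches the classifier head
    return outputs.last_hidden_state.mean(dim=1).float()

@torch.no_grad()
def predict(image_path: str) -> str:
    image = Image.open(image_path).convert("RGB")
    features = extract_features(image)
    logits = classifier(features)
    pred = logits.argmax(dim=1).item()
    return LABELS[pred]

print(predict("sample.jpg"))  # prints "checked" or "unchecked"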
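predict() returns only the argmax label. If a confidence score is also useful, one option is a softmax over the head's logits. The sketch below shows this on a freshly initialized head with the same layer layout and random stand-in features, so it runs without downloading the backbone. INPUT_DIM = 2048 is a placeholder for cfg["classifier_architecture"]["input_dim"], and predict_with_confidence is a hypothetical helper, not part of this repo.

```python
import torch
import torch.nn as nn

LABELS = ["unchecked", "checked"]  # index order from the Labels section
INPUT_DIM = 2048  # placeholder: use cfg["classifier_architecture"]["input_dim"]

# Same layout as the trained head; weights here are random, for illustration only
head = nn.Sequential(
    nn.Linear(INPUT_DIM, 256),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(256, len(LABELS)),
)
head.eval()  # Dropout is a no-op in eval mode


@torch.no_grad()
def predict_with_confidence(features: torch.Tensor) -> list[tuple[str, float]]:
    """Map a (batch, INPUT_DIM) feature tensor to (label, softmax probability) pairs."""
    probs = torch.softmax(head(features), dim=1)
    conf, idx = probs.max(dim=1)
    return [(LABELS[i], c) for i, c in zip(idx.tolist(), conf.tolist())]


# Random stand-ins for the features of a batch of 3 images
torch.manual_seed(0)
batch = torch.randn(3, INPUT_DIM)
for label, p in predict_with_confidence(batch):
    print(f"{label}: {p:.3f}")
```

With the real model, features for several images can be combined into one batch with torch.cat([...], dim=0) over extract_features() outputs. Softmax scores are not guaranteed to be calibrated probabilities, so treat them as a relative signal.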

Labels

  • 0: unchecked
  • 1: checked
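If the mapping is needed in both directions, for instance to turn annotation strings into integer targets that line up with the head's output indices, a pair of plain dicts is enough. A minimal sketch; the names ID2LABEL and LABEL2ID are illustrative:

```python
LABELS = ["unchecked", "checked"]  # index i corresponds to head output i

ID2LABEL = dict(enumerate(LABELS))                       # {0: "unchecked", 1: "checked"}
LABEL2ID = {label: i for i, label in ID2LABEL.items()}   # {"unchecked": 0, "checked": 1}

# Example: convert annotation strings into integer targets
annotations = ["checked", "unchecked", "checked"]
targets = [LABEL2ID[a] for a in annotations]
print(targets)  # prints [1, 0, 1]
```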

Notes

  • Trained with limited compute (Google Colab).
  • Only the classifier head is trained; the PaliGemma backbone remains frozen.
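Training only the head means the backbone's parameters never receive gradient updates. A minimal sketch of that setup, with a small stand-in module in place of PaliGemma so it runs offline: the backbone's parameters have requires_grad disabled, and the optimizer is given only the head's parameters. FEATURE_DIM, AdamW, and the learning rate are placeholders; the actual training hyperparameters are recorded in training_config.json.

```python
import torch
import torch.nn as nn

FEATURE_DIM = 64  # placeholder; the real head uses the backbone's feature size

# Stand-in for the frozen PaliGemma backbone (illustration only)
backbone = nn.Sequential(nn.Linear(32, FEATURE_DIM), nn.Tanh())
for p in backbone.parameters():
    p.requires_grad_(False)  # freeze: backbone weights get no gradients
backbone.eval()

head = nn.Sequential(
    nn.Linear(FEATURE_DIM, 256),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(256, 2),
)
# Only the head's parameters are handed to the optimizer
optimizer = torch.optim.AdamW(head.parameters(), lr=1e-3)  # placeholder lr
loss_fn = nn.CrossEntropyLoss()

# One training step on random stand-in data
torch.manual_seed(0)
x = torch.randn(8, 32)
y = torch.randint(0, 2, (8,))
with torch.no_grad():
    features = backbone(x)  # frozen features; these could be precomputed once
head.train()
loss = loss_fn(head(features), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

trainable = sum(p.numel() for p in head.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in backbone.parameters() if not p.requires_grad)
print(f"trainable={trainable} frozen={frozen}")
```

Because a frozen backbone always maps the same image to the same features, one common way to keep head-only training cheap is to extract features once and cache them; whether this model's training did so is not stated here.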