Pretrain like Your Inference: Masked Tuning Improves Zero-Shot Composed Image Retrieval

This repository contains the official pre-trained and tuned model weights (CLIP-ViT-L/14 backbone) for PLI (Pretrain like Your Inference), accepted at ICME 2025.



📌 Introduction

Zero-Shot Composed Image Retrieval (ZS-CIR) aims to retrieve a target image based on a query image and a modifying text description, without using labeled triplet data during training.

Existing methods typically rely on vision-language models (e.g., CLIP) that are pre-trained on standard image-text pairs. This creates a gap between pre-training, which matches an image to its caption, and inference, which matches a composed query (a reference image plus a modifying text) to a target image.

PLI (Pretrain like Your Inference) bridges this gap by reformulating contrastive learning as a CIR task using a self-supervised Masked Tuning approach. By randomly masking patches of the input image, we generate triplets of $\langle \text{masked image}, \text{modifying text}, \text{original image} \rangle$, forcing the model to learn fine-grained text-guided modifications during pre-training.
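As a rough illustration of the triplet construction described above, the sketch below masks a random subset of non-overlapping square patches in a toy grayscale image and pairs the result with the caption and the unmasked original. It is not the PLI training code. The nested-list image format, the zero-filling of masked pixels, and the default patch_size and mask_ratio values are assumptions chosen for readability, and the names mask_patches and make_triplet are hypothetical.

```python
import random
from typing import List, Tuple

Image = List[List[float]]  # hypothetical toy format: H x W grayscale pixel values


def mask_patches(image: Image, patch_size: int, mask_ratio: float, rng: random.Random) -> Image:
    """Return a copy of `image` with a random subset of non-overlapping square patches set to 0."""
    height, width = len(image), len(image[0])
    if height % patch_size or width % patch_size:
        raise ValueError("image height and width must be divisible by patch_size")
    cols = width // patch_size
    num_patches = (height // patch_size) * cols
    num_masked = round(mask_ratio * num_patches)
    # Choose which patches to hide, uniformly at random without replacement.
    masked_ids = rng.sample(range(num_patches), num_masked)
    masked = [row[:] for row in image]  # copy rows so the original image is left untouched
    for pid in masked_ids:
        top, left = (pid // cols) * patch_size, (pid % cols) * patch_size
        for r in range(top, top + patch_size):
            for c in range(left, left + patch_size):
                masked[r][c] = 0.0
    return masked


def make_triplet(
    image: Image, caption: str, patch_size: int = 2, mask_ratio: float = 0.5, seed: int = 0
) -> Tuple[Image, str, Image]:
    """Build a <masked image, modifying text, original image> triplet; the caption acts as the modifying text."""
    rng = random.Random(seed)
    return mask_patches(image, patch_size, mask_ratio, rng), caption, image


# Toy 4x4 image with non-zero pixels, split into four 2x2 patches (assumed sizes).
img = [[float(r * 4 + c + 1) for c in range(4)] for r in range(4)]
masked, text, target = make_triplet(img, "a red car parked on the street")
for row in masked:
    print(row)
```

With mask_ratio=0.5, two of the four patches are hidden, so the model only sees part of the image and must rely on the caption to recover a representation close to the original.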


🚀 Quick Start & Usage

You can download the checkpoint with the huggingface_hub library and load it into the base CLIP model with PyTorch.

1. Installation

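Ensure you have the necessary libraries installed. The clip package on PyPI is an unrelated project, so install OpenAI CLIP from its GitHub repository:

pip install torch torchvision huggingface_hub ftfy regex tqdm
pip install git+https://github.com/openai/CLIP.git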

2. Loading Weights in Python

Here is an example of how to programmatically download the weight file and load it into your model:

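import torch
import clip  # OpenAI CLIP, installed from https://github.com/openai/CLIP
from huggingface_hub import hf_hub_download

# 1. Download the tuned checkpoint from the Hugging Face Hub (cached locally after the first call)
checkpoint_path = hf_hub_download(
    repo_id="jayong/PLI-CLIP-VIT-L-14",
    filename="best.pth",
)

# 2. Initialize the base CLIP model (ViT-L/14) and its matching image preprocessing transform
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-L/14", device=device)

# 3. Load the tuned weights into the base model
checkpoint = torch.load(checkpoint_path, map_location=device)
# Assumption: if the file wraps the weights in a dict under "state_dict", unwrap it first
state_dict = checkpoint.get("state_dict", checkpoint) if isinstance(checkpoint, dict) else checkpoint
model.load_state_dict(state_dict)
model.eval()  # switch to inference mode before encoding queries or gallery images

print("PLI model weights loaded successfully!")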
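Once the weights are loaded, retrieval amounts to embedding the composed query and ranking gallery images by cosine similarity. The sketch below works on plain feature vectors, which in practice would come from the CLIP calls model.encode_image(preprocess(image).unsqueeze(0).to(device)) and model.encode_text(clip.tokenize([text]).to(device)). This card does not describe how PLI fuses the reference-image and text features at inference, so compose_query uses a simple sum of normalized embeddings purely as a hypothetical stand-in; consult the GitHub repository for the actual composition. All function names here are hypothetical.

```python
import math
from typing import List, Sequence


def l2_normalize(vec: Sequence[float]) -> List[float]:
    """Scale a vector to unit length (zero vectors are returned unchanged)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else list(vec)


def compose_query(image_feat: Sequence[float], text_feat: Sequence[float]) -> List[float]:
    """Hypothetical fusion: sum of the normalized image and text embeddings, re-normalized."""
    image_unit, text_unit = l2_normalize(image_feat), l2_normalize(text_feat)
    return l2_normalize([a + b for a, b in zip(image_unit, text_unit)])


def rank_gallery(query: Sequence[float], gallery: Sequence[Sequence[float]]) -> List[int]:
    """Return gallery indices sorted by descending cosine similarity to the query."""
    scores = [sum(q * g for q, g in zip(query, l2_normalize(item))) for item in gallery]
    return sorted(range(len(gallery)), key=lambda i: scores[i], reverse=True)


# Toy 3-d features standing in for CLIP ViT-L/14 embeddings.
query = compose_query([1.0, 0.0, 0.0], [0.0, 1.0, 0.0])
gallery = [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0], [1.0, 1.0, 0.0]]
print(rank_gallery(query, gallery))  # prints [2, 0, 1]
```

The composed query points between the image and text directions, so the gallery item that reflects both ([1, 1, 0]) ranks above the unmodified reference ([1, 0, 0]), and the orthogonal item ranks last.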

πŸ› οΈ Method Overview

[ Original Image ]  ──( Patch Masking )──>  [ Masked Image ]
       │                                           │
       │                                   ( Text-guided Modification )
       ▼                                           ▼
[ Target Representation ] <──( Contrastive )── [ Predicted Representation ]
  1. Patch Masking: Randomly mask patches of the source image.
  2. Text Query: Treat the text description of the image as the "modifying text".
  3. Contrastive Objective: Align the composition of (masked image + text) with the representation of the original image.
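The contrastive objective in step 3 can be illustrated with a standard InfoNCE loss over a batch: the composed (masked image + text) embedding at index i should score highest against the original-image embedding at index i, with the other originals in the batch acting as negatives. This is a generic, one-directional sketch under assumptions. The embeddings are taken as already L2-normalized, the composition step is left out, the temperature of 0.07 is an assumed value, and the function name info_nce is hypothetical; PLI's exact loss formulation may differ.

```python
import math
from typing import Sequence


def info_nce(
    queries: Sequence[Sequence[float]],
    targets: Sequence[Sequence[float]],
    temperature: float = 0.07,  # assumed value, not taken from the PLI paper
) -> float:
    """Mean cross-entropy where query i's positive is target i.

    Assumes every row is already L2-normalized, so dot products are cosine similarities.
    """
    if len(queries) != len(targets):
        raise ValueError("queries and targets must have the same batch size")
    total = 0.0
    for i, query in enumerate(queries):
        # Similarity of this composed query to every original image in the batch.
        logits = [sum(q * t for q, t in zip(query, target)) / temperature for target in targets]
        # Numerically stable log-sum-exp.
        peak = max(logits)
        log_sum_exp = peak + math.log(sum(math.exp(x - peak) for x in logits))
        # Cross-entropy with the matching original image as the correct class.
        total += log_sum_exp - logits[i]
    return total / len(queries)


# Toy batch: query i equals target i in the aligned case, so diagonal pairs are positives.
aligned = info_nce([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]], temperature=1.0)
shuffled = info_nce([[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]], temperature=1.0)
print(f"aligned={aligned:.4f} shuffled={shuffled:.4f}")
```

Correctly matched pairs give a lower loss than shuffled ones, which is the signal that pulls the composed representation toward the original image and away from the other images in the batch.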

✍️ Citation

If you find our work or weights useful in your research, please consider citing our paper:

@inproceedings{chen2025pretrain,
  title={Pretrain like your inference: Masked tuning improves zero-shot composed image retrieval},
  author={Chen, Junyang and Lai, Hanjiang},
  booktitle={2025 IEEE International Conference on Multimedia and Expo (ICME)},
  pages={1--6},
  year={2025},
  organization={IEEE}
}

📭 Contact / Feedback

For questions or feedback, please open an issue in our GitHub repository.
