|
|
--- |
|
|
library_name: transformers |
|
|
tags: |
|
|
- vision |
|
|
- cell-biology |
|
|
- dino |
|
|
pipeline_tag: image-feature-extraction |
|
|
model-index: |
|
|
- name: cellrepDINO |
|
|
results: [] |
|
|
--- |
|
|
|
|
|
# CellrepDINO Model |
|
|
|
|
|
This is a custom DINO model for extracting rich representations of cell microscopy in condensed vector/array form. The forward method of the cellrepDINO model gives embeddings that can be used |
|
|
for relevant downstream tasks like perturbation prediction, mechanism of action (MoA) classification, nuclei size and shape estimation, etc. Simply train a basic linear or logistic regression model using the embeddings. |
|
|
|
|
|
## Model Details |
|
|
- Architecture: DINOv2 |
|
|
- Default ViT Model Size: Large |
|
|
- Patch Size: 14 |
|
|
- Default image size: 1024 |
|
|
- Default resize size: 518 |
|
|
- Default center crop: 518 |
|
|
|
|
|
## Setup |
|
|
|
|
|
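Please create and activate a new environment, then run `pip install torch transformers Pillow numpy pandas torchvision omegaconf tqdm pyarrow` (`tqdm` and `pyarrow` are only needed by the bulk-embedding script below, for the progress bar and for writing feather files). |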
|
|
|
|
|
## Example Usage |
|
|
|
|
|
There are different types of embeddings one can extract; we recommend the mean/median embeddings over the patch tokens or the class token embedding. |
|
|
|
|
|
|
|
|
``` |
|
|
from transformers import AutoModel, AutoProcessor
from PIL import Image
import torch

# Set up device: use the GPU when one is available, otherwise the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load model and processor
model = AutoModel.from_pretrained("LPhilllips/cellrepDINO", trust_remote_code=True, weights_only=True)
processor = AutoProcessor.from_pretrained("LPhilllips/cellrepDINO", trust_remote_code=True)

# Move model to device and switch it to evaluation mode
model = model.to(device)
model.eval()

# For multiple images:
image_paths = ["image1.png", "image2.png"]
images = [Image.open(path).convert('RGB') for path in image_paths]

# Process batch of images
# If you want different resize and center-crop sizes, please specify the
# resize_size and centercrop_size parameters below
batch_inputs = processor.preprocess(images=images, resize_size=518, centercrop_size=518, return_tensors="pt")

# Move image tensors to device (non-tensor entries are left untouched)
batch_inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch_inputs.items()}

# Generate embeddings for batch; no gradients are needed for inference
with torch.no_grad():
    batch_outputs = model(**batch_inputs)
    mean_embeddings = batch_outputs['mean_pooled']
    median_embeddings = batch_outputs['median_pooled']
    cls_embeddings = batch_outputs['cls_token']
|
|
|
|
|
``` |
|
|
|
|
|
|
|
|
Script for generating embeddings en masse (requires a CSV with an `ImagePath` column): |
|
|
|
|
|
``` |
|
|
from transformers import AutoModel, AutoProcessor |
|
|
from PIL import Image |
|
|
import torch |
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
from tqdm import tqdm |
|
|
import warnings |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
def load_model_and_processor(model_name="LPhilllips/cellrepDINO"):
    """Return ``(model, processor, device)`` for ``model_name``, ready for inference.

    The device is CUDA when ``torch.cuda.is_available()`` reports a GPU and
    the CPU otherwise. The model is moved to that device and put into
    evaluation mode before being returned.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # Both objects are fetched with trust_remote_code=True.
    model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
    processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)

    model = model.to(device)
    model.eval()

    return model, processor, device
|
|
|
|
|
def process_batch(image_paths, model, processor, device, batch_size=32, resize_size=518, crop_size=518):
    """Load the images at ``image_paths`` and return their mean patch-token embeddings.

    Images that cannot be opened or decoded are skipped with a warning
    instead of aborting the whole batch.

    Args:
        image_paths: Paths of the images to embed. The whole list is sent
            through the model in a single forward pass.
        model: Object whose ``model.forward_features`` is called on the
            preprocessed pixel values.
        processor: Object whose ``preprocess`` method turns the PIL images
            into a mapping that contains ``"pixel_values"``.
        device: Torch device the input tensors are moved to.
        batch_size: Not used inside this function (every path in
            ``image_paths`` is processed at once); kept so existing callers
            keep working.
        resize_size: Forwarded to ``processor.preprocess``.
        crop_size: Forwarded to ``processor.preprocess``.

    Returns:
        ``(mean_embeddings, valid_indices)``: a NumPy array of embeddings
        (presumably one row per loaded image -- the caller concatenates the
        arrays along axis 0) and the positions in ``image_paths`` of the
        images that were loaded. ``(None, [])`` when no image could be loaded.
    """
    images = []
    valid_indices = []

    for idx, path in enumerate(image_paths):
        try:
            # Image.open is lazy. Converting inside the ``with`` block forces
            # the pixel data to be decoded now, so a corrupt file is caught
            # here rather than later inside the processor, and the file
            # handle is closed right away. RGB conversion matches the
            # single-batch example in this README.
            with Image.open(path) as img:
                rgb_image = img.convert("RGB")
        except Exception as e:
            # Deliberately broad: one unreadable image must not abort the run.
            warnings.warn(f"Could not load image {path}: {str(e)}")
            continue
        images.append(rgb_image)
        valid_indices.append(idx)

    if not images:
        return None, []

    # NOTE(review): the single-batch example in this README passes
    # ``centercrop_size=`` instead of ``crop_size=`` -- confirm which keyword
    # the processor actually accepts.
    batch_inputs = processor.preprocess(
        images=images,
        resize_size=resize_size,
        crop_size=crop_size,
        return_tensors="pt",
    )

    # Move tensors to the target device; non-tensor entries pass through.
    batch_inputs = {
        k: v.to(device) if isinstance(v, torch.Tensor) else v
        for k, v in batch_inputs.items()
    }

    # Inference only, so gradient tracking is disabled.
    with torch.no_grad():
        embeddings = model.model.forward_features(batch_inputs['pixel_values'])
        # Average over dim 1 (presumably the patch-token axis -- TODO confirm).
        mean_embeddings = embeddings["x_norm_patchtokens"].mean(dim=1)
        mean_embeddings = mean_embeddings.cpu().numpy()

    return mean_embeddings, valid_indices
|
|
|
|
|
def process_and_save_embeddings(csv_path, output_path, batch_size=32):
    """Embed every image listed in a CSV and save the results to a feather file.

    The CSV rows whose image could be loaded are written to ``output_path``
    together with one ``embedding_<i>`` column per embedding dimension.
    Rows whose image could not be loaded are dropped.

    Args:
        csv_path: CSV file with an ``ImagePath`` column.
        output_path: Path of the feather file to write. Missing parent
            directories are created.
        batch_size: Number of CSV rows sent to ``process_batch`` at a time.

    Raises:
        KeyError: If the CSV has no ``ImagePath`` column.
    """
    # Read and validate the CSV before the (expensive) model load so that a
    # malformed input fails fast.
    df = pd.read_csv(csv_path)
    if "ImagePath" not in df.columns:
        raise KeyError(f"{csv_path} has no 'ImagePath' column")

    model, processor, device = load_model_and_processor()

    all_embeddings = []
    valid_rows = []

    for start in tqdm(range(0, len(df), batch_size)):
        batch_df = df.iloc[start:start + batch_size]

        embeddings, valid_indices = process_batch(
            batch_df['ImagePath'].tolist(),
            model, processor, device,
            batch_size=batch_size,
        )

        if embeddings is not None:
            # valid_indices are positions within this batch, hence iloc.
            all_embeddings.append(embeddings)
            valid_rows.append(batch_df.iloc[valid_indices])

    if not valid_rows:
        print("No valid images were processed")
        return

    # ignore_index gives the metadata a fresh 0..n-1 index so it lines up
    # with the embedding frame in the column-wise concat below.
    final_df = pd.concat(valid_rows, ignore_index=True)
    final_embeddings = np.concatenate(all_embeddings, axis=0)

    embedding_cols = [f'embedding_{i}' for i in range(final_embeddings.shape[1])]
    embedding_df = pd.DataFrame(final_embeddings, columns=embedding_cols)

    # Combine metadata with embeddings
    final_df = pd.concat([final_df, embedding_df], axis=1)

    # Make sure the destination directory exists before writing.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    final_df.to_feather(output_path)

    print(f"Successfully processed {len(final_df)} images")
    print(f"Results saved to {output_path}")
|
|
|
|
|
|
|
|
# Example usage
if __name__ == "__main__":
    # CSV file that lists the images to embed in an `ImagePath` column
    csv_path = "csv/with/image/path/columns"
    # Must be a file path (not a folder): the results are written with
    # DataFrame.to_feather
    output_path = "/your/output/folder/embeddings.feather"

    process_and_save_embeddings(
        csv_path=csv_path,
        output_path=output_path,
        batch_size=32,  # Adjust based on your GPU memory
    )
|
|
``` |