|
|
""" |
|
|
ColorCLIP model for learning color-aligned embeddings. |
|
|
This file contains the ColorCLIP model that learns to encode images and texts |
|
|
in an embedding space specialized for color representation. It includes |
|
|
a ResNet-based image encoder, a text encoder with custom tokenizer, |
|
|
and contrastive loss functions for training. |
|
|
""" |
|
|
|
|
|
import config |
|
|
import os |
|
|
import json |
|
|
import torch |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
from torchvision import transforms, models |
|
|
from PIL import Image |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import pandas as pd |
|
|
from tqdm.auto import tqdm |
|
|
from collections import defaultdict |
|
|
from typing import Optional, List |
|
|
import logging |
|
|
|
|
|
|
|
|
|
|
|
# Configure the root logger once at import time: INFO level with
# timestamped, level-tagged records.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-scoped logger named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
class ColorDataset(Dataset):
    """
    Torch dataset pairing local images with tokenized text descriptions.

    Items are produced lazily: the image is opened from the path stored in the
    dataframe, converted to RGB and passed through the image transform, while
    the row's text is converted to a tensor of token ids by the tokenizer.
    """

    def __init__(self, dataframe, tokenizer, transform=None):
        """
        Store the data source, the tokenizer and the image pipeline.

        Args:
            dataframe: DataFrame containing the image-path column
                (``config.column_local_image_path``) and the text column
                (``config.text_column``). Its index is reset to 0..n-1.
            tokenizer: Callable turning a text string into a list of integer
                token ids.
            transform: Optional image transform. When omitted (or falsy), a
                default pipeline is used: resize to 224x224, convert to a
                tensor, then normalize with mean [0.485, 0.456, 0.406] and
                std [0.229, 0.224, 0.225].
        """
        self.df = dataframe.reset_index(drop=True)
        self.tokenizer = tokenizer
        if transform:
            self.transform = transform
        else:
            default_steps = [
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                ),
            ]
            self.transform = transforms.Compose(default_steps)

    def __len__(self):
        """Return how many rows (samples) the dataset holds."""
        return len(self.df)

    def __getitem__(self, idx):
        """
        Load and preprocess one sample.

        Args:
            idx: Positional index of the sample.

        Returns:
            Tuple ``(image, tokens)`` where ``image`` is the transformed RGB
            image and ``tokens`` is a ``torch.long`` tensor of token ids.
        """
        record = self.df.iloc[idx]

        # Force three channels so grayscale / RGBA files match the pipeline.
        picture = Image.open(record[config.column_local_image_path])
        picture = picture.convert("RGB")
        pixel_tensor = self.transform(picture)

        token_ids = self.tokenizer(record[config.text_column])
        return pixel_tensor, torch.tensor(token_ids, dtype=torch.long)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Tokenizer:
    """
    Keyword tokenizer that maps color and garment words to integer ids.

    Text is lower-cased and split on whitespace; only whitelisted color and
    descriptive words are kept (if none match, the whole lower-cased text is
    used instead). Words are then mapped to integer ids. Id 0 is reserved for
    padding and for words that are not in the vocabulary.
    """

    # Whitelists are frozensets built once at class creation: membership tests
    # are O(1) and nothing is re-allocated on every preprocess_text call.
    _COLOR_KEYWORDS = frozenset({
        'red', 'blue', 'green', 'yellow', 'purple', 'pink', 'orange',
        'brown', 'black', 'white', 'gray', 'navy', 'beige', 'aqua', 'lime',
        'violet', 'turquoise', 'teal', 'tan', 'snow', 'silver', 'plum',
        'olive', 'fuchsia', 'gold', 'cream', 'ivory', 'maroon',
    })
    _DESCRIPTIVE_WORDS = frozenset({
        'shirt', 'dress', 'top', 'bottom', 'shoe', 'bag', 'hat',
        'short', 'long', 'sleeve',
    })
    _KEPT_WORDS = _COLOR_KEYWORDS | _DESCRIPTIVE_WORDS

    def __init__(self):
        """
        Create an empty vocabulary.

        ``word2idx`` maps words to ids, ``idx2word`` is the reverse mapping and
        ``counter`` is the next free id (ids start at 1; 0 is reserved for
        padding/unknown tokens).
        """
        self.word2idx = defaultdict(lambda: 0)
        self.idx2word = {}
        self.counter = 1

    def preprocess_text(self, text):
        """
        Reduce a text to its color-related and descriptive keywords.

        Args:
            text: Input text string.

        Returns:
            The whitelisted words of the lower-cased text joined by single
            spaces, in their original order. If no word is whitelisted the
            full lower-cased text is returned unchanged.
        """
        lowered = text.lower()
        kept = [word for word in lowered.split() if word in self._KEPT_WORDS]
        return ' '.join(kept) if kept else lowered

    def fit(self, texts):
        """
        Extend the vocabulary with the words found in ``texts``.

        Args:
            texts: Iterable of text strings. Each new preprocessed word gets
                the next free id; existing words keep their id.
        """
        for text in texts:
            for word in self.preprocess_text(text).split():
                if word not in self.word2idx:
                    self.word2idx[word] = self.counter
                    self.idx2word[self.counter] = word
                    self.counter += 1

    def __call__(self, text):
        """
        Convert a text string into a list of integer token ids.

        Args:
            text: Input text string.

        Returns:
            List of token ids, one per preprocessed word; unknown words map
            to 0.
        """
        # Use .get() rather than indexing: indexing a defaultdict inserts the
        # missing key, which would silently add every unknown word to the
        # vocabulary with id 0 (polluting saved vocab files and preventing a
        # later fit() from assigning the word a real id).
        lookup = self.word2idx.get
        return [lookup(word, 0) for word in self.preprocess_text(text).split()]

    def load_vocab(self, word2idx_dict):
        """
        Replace the vocabulary with one loaded from a word-to-id mapping.

        Args:
            word2idx_dict: Mapping of words to ids; values may be ints or
                numeric strings (e.g. as read back from JSON).
        """
        parsed = {word: int(idx) for word, idx in word2idx_dict.items()}
        self.word2idx = defaultdict(lambda: 0, parsed)
        # Id 0 is the padding/unknown slot, so it has no reverse entry.
        self.idx2word = {idx: word for word, idx in parsed.items() if idx > 0}
        self.counter = max(parsed.values(), default=0) + 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ImageEncoder(nn.Module):
    """
    ResNet18-based encoder producing unit-length image embeddings.

    The torchvision ResNet18 is created with its default pretrained weights,
    and its final fully connected layer is swapped for a dropout + linear
    head that projects to the requested embedding size.
    """

    def __init__(self, embedding_dim=config.color_emb_dim):
        """
        Build the backbone and its projection head.

        Args:
            embedding_dim: Size of the output embedding
                (default: ``config.color_emb_dim``).
        """
        super().__init__()
        resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        # Replace the stock classifier with a small projection head.
        projection_head = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(resnet.fc.in_features, embedding_dim),
        )
        resnet.fc = projection_head
        self.backbone = resnet

    def forward(self, x):
        """
        Encode a batch of images.

        Args:
            x: Image tensor [batch_size, channels, height, width].

        Returns:
            L2-normalized image embeddings [batch_size, embedding_dim].
        """
        features = self.backbone(x)
        return F.normalize(features, dim=-1)
|
|
|
|
|
class TextEncoder(nn.Module):
    """
    Bag-of-tokens text encoder producing unit-length text embeddings.

    Token ids are embedded (32-d, with id 0 as the padding index), passed
    through dropout, averaged over the sequence dimension and projected by a
    linear layer to the output embedding size.
    """

    def __init__(self, vocab_size, embedding_dim=config.color_emb_dim):
        """
        Build the embedding table, dropout and projection layers.

        Args:
            vocab_size: Number of rows in the embedding table.
            embedding_dim: Size of the output embedding
                (default: ``config.color_emb_dim``).
        """
        super().__init__()
        token_dim = 32
        self.embedding = nn.Embedding(vocab_size, token_dim, padding_idx=0)
        self.dropout = nn.Dropout(0.1)
        self.fc = nn.Linear(token_dim, embedding_dim)

    def forward(self, x, lengths=None):
        """
        Encode a batch of token sequences.

        Args:
            x: Token tensor [batch_size, sequence_length].
            lengths: Optional tensor [batch_size] with the true sequence
                lengths. When given, the summed token vectors are divided by
                these lengths (clamped to at least 1); otherwise a plain mean
                over all positions, padded ones included, is used.

        Returns:
            L2-normalized text embeddings [batch_size, embedding_dim].
        """
        token_vectors = self.dropout(self.embedding(x))
        if lengths is None:
            pooled = token_vectors.mean(dim=1)
        else:
            # clamp_min(1) guards against division by zero for empty rows.
            divisor = lengths.unsqueeze(1).clamp_min(1)
            pooled = token_vectors.sum(dim=1) / divisor
        projected = self.fc(pooled)
        return F.normalize(projected, dim=-1)
|
|
|
|
|
class ColorCLIP(nn.Module):
    """
    Color CLIP model for learning color-aligned image-text embeddings.

    Combines an ImageEncoder and a TextEncoder that project into the same
    embedding space, and optionally carries the Tokenizer needed to embed raw
    text strings.
    """

    def __init__(self, vocab_size, embedding_dim=config.color_emb_dim, tokenizer=None):
        """
        Initialize ColorCLIP model.

        Args:
            vocab_size: Size of the vocabulary for text encoding
            embedding_dim: Dimension of the embedding space (default: color_emb_dim)
            tokenizer: Optional Tokenizer instance. It is stored as-is; when it
                is None, get_text_embeddings raises ValueError until a
                tokenizer is assigned to ``self.tokenizer``.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.image_encoder = ImageEncoder(embedding_dim)
        self.text_encoder = TextEncoder(vocab_size, embedding_dim)
        self.tokenizer = tokenizer

    def forward(self, image, text, lengths=None):
        """
        Forward pass through the model.

        Args:
            image: Image tensor [B, C, H, W]
            text: Text token tensor [B, T]
            lengths: Optional sequence lengths tensor [B]

        Returns:
            Tuple of (image_embeddings, text_embeddings)
        """
        return self.image_encoder(image), self.text_encoder(text, lengths)

    def get_text_embeddings(self, texts: List[str]) -> torch.Tensor:
        """
        Get text embeddings for a list of text strings.

        The text encoder is run without gradients and in eval mode (its
        previous train/eval state is restored afterwards), so the result is
        deterministic even if the model is currently being trained.

        Args:
            texts: List of text strings

        Returns:
            Text embeddings tensor [len(texts), embedding_dim]; an empty
            [0, embedding_dim] tensor when ``texts`` is empty.

        Raises:
            ValueError: If no tokenizer has been set on the model.
        """
        if self.tokenizer is None:
            raise ValueError("Tokenizer must be set before calling get_text_embeddings")

        reference_param = next(self.parameters())
        device = reference_param.device

        token_lists = [self.tokenizer(t) for t in texts]
        if not token_lists:
            # An empty batch cannot be shaped as [B, T]; return an empty
            # result instead of failing inside the encoder.
            return torch.empty(0, self.embedding_dim, dtype=reference_param.dtype, device=device)

        # Right-pad every sequence with the padding id 0 to a common length.
        max_len = max(len(toks) for toks in token_lists)
        padded = [toks + [0] * (max_len - len(toks)) for toks in token_lists]
        input_ids = torch.tensor(padded, dtype=torch.long, device=device)
        lengths = torch.tensor([len(toks) for toks in token_lists], dtype=torch.long, device=device)

        # Disable dropout for inference, then restore the previous mode.
        was_training = self.text_encoder.training
        self.text_encoder.eval()
        try:
            with torch.no_grad():
                emb = self.text_encoder(input_ids, lengths)
        finally:
            self.text_encoder.train(was_training)
        return emb

    @classmethod
    def from_pretrained(cls, model_path: str, vocab_path: Optional[str] = None, device: str = "cpu", repo_id: Optional[str] = None):
        """
        Load a pretrained ColorCLIP model from a file path or Hugging Face Hub.

        Args:
            model_path: Path to the model checkpoint (.pt file) or filename if using repo_id
            vocab_path: Optional path to tokenizer vocabulary JSON file or filename if using repo_id
            device: Device to load the model on (default: "cpu")
            repo_id: Optional Hugging Face repository ID (e.g., "username/model-name")
                If provided, model_path and vocab_path should be filenames within the repo

        Returns:
            ColorCLIP model instance in eval mode. Its tokenizer is set only
            when ``vocab_path`` points to an existing file.

        Raises:
            ImportError: If repo_id is given but huggingface_hub is not installed.
            ValueError: If the checkpoint is not a dictionary or the
                vocabulary size cannot be determined from it.

        Example:
            # Load from local file
            model = ColorCLIP.from_pretrained("color_model.pt", "tokenizer_vocab.json")

            # Load from Hugging Face Hub
            from huggingface_hub import hf_hub_download
            model_file = hf_hub_download(repo_id="username/model-name", filename="color_model.pt")
            vocab_file = hf_hub_download(repo_id="username/model-name", filename="tokenizer_vocab.json")
            model = ColorCLIP.from_pretrained(model_file, vocab_file)
        """
        device_obj = torch.device(device)

        if repo_id:
            # Only the import is guarded, so download errors are not
            # misreported as a missing dependency.
            try:
                from huggingface_hub import hf_hub_download
            except ImportError as err:
                raise ImportError("huggingface_hub is required to load models from Hugging Face. Install it with: pip install huggingface-hub") from err
            model_path = hf_hub_download(repo_id=repo_id, filename=model_path)
            if vocab_path:
                vocab_path = hf_hub_download(repo_id=repo_id, filename=vocab_path)

        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        checkpoint = torch.load(model_path, map_location=device_obj)
        if not isinstance(checkpoint, dict):
            raise ValueError("Checkpoint must be a dictionary")

        # Accept both the wrapped format written by save_pretrained and a raw
        # state dict.
        state_dict = checkpoint.get('model_state_dict', checkpoint)

        vocab_size = checkpoint.get('vocab_size', None)
        if vocab_size is None:
            if 'text_encoder.embedding.weight' in state_dict:
                vocab_size = state_dict['text_encoder.embedding.weight'].shape[0]
            else:
                raise ValueError("Could not determine vocab_size from checkpoint")

        embedding_dim = checkpoint.get('embedding_dim', None)
        if embedding_dim is None:
            # Prefer the size stored in the weights over a blind default; 16
            # remains the last-resort fallback used previously.
            projection = state_dict.get('text_encoder.fc.weight')
            embedding_dim = projection.shape[0] if projection is not None else 16

        model = cls(vocab_size=vocab_size, embedding_dim=embedding_dim)
        model.load_state_dict(state_dict)
        model = model.to(device_obj)

        if vocab_path:
            if os.path.exists(vocab_path):
                tokenizer = Tokenizer()
                with open(vocab_path, 'r', encoding='utf-8') as f:
                    vocab_dict = json.load(f)
                tokenizer.load_vocab(vocab_dict)
                model.tokenizer = tokenizer
            else:
                # Loading still succeeds (as before), but the caller is told
                # why get_text_embeddings will not work.
                logging.getLogger(__name__).warning(
                    "Vocabulary file not found: %s; model loaded without a tokenizer",
                    vocab_path,
                )

        model.eval()
        return model

    def save_pretrained(self, save_directory: str, vocab_path: Optional[str] = None):
        """
        Save the model and, if a tokenizer is attached, its vocabulary.

        Args:
            save_directory: Directory to save the model (created if missing)
            vocab_path: Optional path to save tokenizer vocabulary; defaults to
                ``config.tokeniser_path`` inside ``save_directory``

        Returns:
            Tuple ``(model_path, vocab_path)``. ``vocab_path`` is returned
            as passed in (possibly None) when the model has no tokenizer.
        """
        os.makedirs(save_directory, exist_ok=True)

        model_path = os.path.join(save_directory, config.color_model_path)
        checkpoint = {
            'model_state_dict': self.state_dict(),
            'vocab_size': self.vocab_size,
            'embedding_dim': self.embedding_dim
        }
        torch.save(checkpoint, model_path)

        if self.tokenizer is not None:
            # Convert to a plain dict so the defaultdict factory is not needed
            # for JSON serialization.
            vocab_dict = dict(self.tokenizer.word2idx)
            if vocab_path is None:
                vocab_path = os.path.join(save_directory, config.tokeniser_path)
            with open(vocab_path, 'w', encoding='utf-8') as f:
                json.dump(vocab_dict, f)

        return model_path, vocab_path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def clip_loss(image_emb, text_emb, temperature=0.07):
    """
    Symmetric CLIP-style contrastive loss.

    The i-th image and the i-th text are treated as the matching pair; every
    other combination in the batch acts as a negative.

    Args:
        image_emb: Image embeddings [batch_size, embedding_dim]
        text_emb: Text embeddings [batch_size, embedding_dim]
        temperature: Divisor applied to the similarity matrix

    Returns:
        Scalar tensor: mean of the image-to-text and text-to-image
        cross-entropy losses.
    """
    similarity = torch.matmul(image_emb, text_emb.T) / temperature
    # Matching pairs sit on the diagonal of the similarity matrix.
    targets = torch.arange(image_emb.shape[0], device=image_emb.device)
    image_to_text = F.cross_entropy(similarity, targets)
    text_to_image = F.cross_entropy(similarity.T, targets)
    return (image_to_text + text_to_image) / 2
|
|
|
|
|
def collate_batch(batch):
    """
    DataLoader collate function: drop missing samples and pad token sequences.

    Args:
        batch: List whose items are (image, tokens) tuples or None

    Returns:
        Tuple of (images, padded_tokens, lengths), where token sequences are
        right-padded with 0 to the longest sequence in the batch and
        ``lengths`` holds the unpadded lengths; None when no valid sample
        remains.
    """
    valid_samples = [sample for sample in batch if sample is not None]
    if not valid_samples:
        return None

    image_list = [sample[0] for sample in valid_samples]
    token_list = [sample[1] for sample in valid_samples]

    images = torch.stack(image_list, dim=0)
    lengths = torch.tensor([seq.size(0) for seq in token_list], dtype=torch.long)
    padded_tokens = nn.utils.rnn.pad_sequence(token_list, batch_first=True, padding_value=0)
    return images, padded_tokens, lengths
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Training script for the ColorCLIP model.
    # This code only runs when the file is executed directly, not when imported.

    # Hyperparameters.
    batch_size = 16
    lr = 1e-4
    epochs = 50

    tokenizer = Tokenizer()
    df = pd.read_csv(config.local_dataset_path)

    # Keep only the samples labelled with one of the main colors.
    main_colors = ['beige', 'black', 'blue', 'brown', 'green', 'orange', 'pink', 'purple', 'red', 'white', 'yellow']
    df = df[df[config.color_column].isin(main_colors)].copy()
    print(f"📊 Filtered dataset: {len(df)} samples with {len(main_colors)} colors")
    print(f"🎨 Colors: {sorted(df[config.color_column].unique())}")

    # The vocabulary is built from every color-filtered description.
    tokenizer.fit(df[config.text_column].tolist())

    # Keep only rows whose image exists on disk. The isinstance check rejects
    # missing values (NaN), which would otherwise become the string "nan" and
    # then make os.path.isfile raise TypeError. astype(bool) keeps the mask
    # boolean even when the frame is empty.
    path_column = config.column_local_image_path
    has_local_image = df[path_column].apply(
        lambda p: isinstance(p, str) and len(p) > 0 and os.path.isfile(p)
    ).astype(bool)
    df_local = df[has_local_image].reset_index(drop=True)

    # Reproducible shuffle followed by a 90/10 train/test split.
    df_local = df_local.sample(frac=1.0, random_state=42).reset_index(drop=True)
    split_idx = int(0.9 * len(df_local))
    df_train = df_local.iloc[:split_idx].reset_index(drop=True)
    df_test = df_local.iloc[split_idx:].reset_index(drop=True)

    train_dataset = ColorDataset(df_train, tokenizer)
    test_dataset = ColorDataset(df_test, tokenizer)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_batch, num_workers=0)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_batch, num_workers=0)

    device = config.device
    print(f"Using device: {device}")

    # tokenizer.counter is the next free id, i.e. the number of embedding rows
    # needed (ids 0..counter-1).
    model = ColorCLIP(vocab_size=tokenizer.counter, embedding_dim=config.color_emb_dim, tokenizer=tokenizer).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)

    # Persist the vocabulary next to this script before training starts.
    here = os.path.dirname(__file__)
    vocab_out = os.path.join(here, config.tokeniser_path)
    with open(vocab_out, "w") as f:
        json.dump(dict(tokenizer.word2idx), f)
    print(f"Tokenizer vocabulary saved to: {vocab_out}")

    # Checkpoints are written next to this script; the "latest" path is the
    # same for every epoch.
    ckpt_dir = here
    latest_path = os.path.join(ckpt_dir, config.color_model_path)

    for epoch in range(epochs):
        # ---- training pass ----
        model.train()
        epoch_losses = []
        # The context manager closes the progress bar even if a step raises.
        with tqdm(total=len(train_loader), desc=f"Epoch {epoch+1}/{epochs} - train", leave=False) as pbar:
            for batch in train_loader:
                if batch is None:
                    # collate_batch returns None when no valid sample is left.
                    pbar.update(1)
                    continue
                imgs, texts, lengths = batch
                imgs = imgs.to(device)
                texts = texts.to(device)
                lengths = lengths.to(device)
                optimizer.zero_grad()
                img_emb, text_emb = model(imgs, texts, lengths)
                loss = clip_loss(img_emb, text_emb)
                loss.backward()
                optimizer.step()
                epoch_losses.append(loss.item())
                pbar.set_postfix({"loss": f"{loss.item():.4f}", "avg": f"{sum(epoch_losses)/len(epoch_losses):.4f}"})
                pbar.update(1)

        if epoch_losses:
            avg_train_loss = sum(epoch_losses) / len(epoch_losses)
            print(f"[Train] Epoch {epoch+1}/{epochs} - avg loss: {avg_train_loss:.4f}")
        else:
            print(f"[Train] Epoch {epoch+1}/{epochs} - no valid batches")

        # ---- evaluation pass ----
        model.eval()
        test_losses = []
        with torch.no_grad():
            with tqdm(total=len(test_loader), desc=f"Epoch {epoch+1}/{epochs} - test", leave=False) as pbar_t:
                for batch in test_loader:
                    if batch is None:
                        pbar_t.update(1)
                        continue
                    imgs, texts, lengths = batch
                    imgs = imgs.to(device)
                    texts = texts.to(device)
                    lengths = lengths.to(device)
                    img_emb, text_emb = model(imgs, texts, lengths)
                    test_losses.append(clip_loss(img_emb, text_emb).item())
                    pbar_t.update(1)

        if test_losses:
            avg_test_loss = sum(test_losses) / len(test_losses)
            print(f"[Test ] Epoch {epoch+1}/{epochs} - avg loss: {avg_test_loss:.4f}")
        else:
            print(f"[Test ] Epoch {epoch+1}/{epochs} - no valid batches")

        # ---- checkpointing: overwrite "latest" and keep a per-epoch copy ----
        epoch_path = os.path.join(ckpt_dir, f"color_model_epoch_{epoch+1}.pt")
        checkpoint = {
            'model_state_dict': model.state_dict(),
            'vocab_size': model.vocab_size,
            'embedding_dim': model.embedding_dim
        }
        torch.save(checkpoint, latest_path)
        torch.save(checkpoint, epoch_path)
        print(f"[Save ] Saved checkpoints: {latest_path} and {epoch_path}")