# AVE / models.py
from transformers import (AutoProcessor,
RobertaConfig,
BertTokenizerFast,
RobertaTokenizerFast,
RobertaModel,
BlipForQuestionAnswering)
from huggingface_hub import hf_hub_download
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
# Load environment variables from a local .env file (optional for local dev; on Hugging Face Spaces, secrets are configured via the web UI)
if os.path.exists('.env'):
from dotenv import load_dotenv
load_dotenv()
ATTRIBUTES_LIST = ['sleeve', 'type', 'pattern', 'material',
'neck', 'color', 'style', 'brand', 'gender']
HF_CACHE_DIR = "./hf_cache"
def get_device():
return "cuda" if torch.cuda.is_available() else "cpu"
def get_tokenizers():
bert_tokenizer = BertTokenizerFast.from_pretrained(
"google-bert/bert-base-uncased", cache_dir=HF_CACHE_DIR)
roberta_tokenizer = RobertaTokenizerFast.from_pretrained(
"FacebookAI/roberta-base", cache_dir=HF_CACHE_DIR)
bert_tokenizer.add_special_tokens({'bos_token': '[DEC]'})
return bert_tokenizer, roberta_tokenizer
def get_image_processor():
return AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base", cache_dir=HF_CACHE_DIR)
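# Note on the pairing above (as used by the models below): the BERT tokenizer, extended with a
# '[DEC]' decoder-start token, drives the BLIP text decoder; the RoBERTa tokenizer feeds the
# RoBERTa text encoder; and the BLIP processor produces the pixel_values for the vision tower.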
class AttentionModalityMerger(nn.Module):
def __init__(self, text_dim, image_dim):
super().__init__()
self.text_layer_norm = nn.LayerNorm(text_dim)
self.image_layer_norm = nn.LayerNorm(image_dim)
self.linear = nn.Linear(
in_features=image_dim + text_dim, out_features=1)
self.sigmoid = nn.Sigmoid()
def forward(self, text_embedds, image_features, attention_mask):
input_mask_expanded = attention_mask.unsqueeze(
-1).expand(text_embedds.size()).float()
text_embedds = input_mask_expanded * text_embedds
text_embedds = text_embedds.sum(dim=1)
text_embedds_norm = self.text_layer_norm(text_embedds)
image_features = image_features.sum(dim=1)
image_features_norm = self.image_layer_norm(image_features)
        text_image_embedds = torch.cat(
            [text_embedds_norm, image_features_norm], dim=-1)
gate_output = self.linear(text_image_embedds)
p_txt = self.sigmoid(gate_output)
p_img = 1 - p_txt
scaled_text = p_txt * text_embedds_norm
scaled_image = p_img * image_features_norm
final_output = torch.cat([scaled_text, scaled_image], dim=-1)
return final_output, p_txt, p_img
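# Illustrative shape walk-through for AttentionModalityMerger (assuming hidden size 768 for both
# modalities, batch size B, text length T and P image patches):
#   text_embedds   (B, T, 768) -> masked sum over T -> (B, 768)
#   image_features (B, P, 768) -> sum over P        -> (B, 768)
#   A sigmoid gate p_txt in (0, 1) weights the text half and (1 - p_txt) the image half,
#   so the concatenated output has shape (B, 1536).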
class RobertaTokenClassificationWithCRF(nn.Module):
def __init__(self, vocab_size, device, roberta_token=None):
if roberta_token is None:
roberta_token = os.getenv("ROBERTA_TOKEN")
super().__init__()
self.vocab_size = vocab_size
self.config = RobertaConfig()
self.roberta = RobertaModel.from_pretrained(
"FacebookAI/roberta-base", output_hidden_states=True, cache_dir=HF_CACHE_DIR)
self.freeze_layers()
self._loadTextWeights(device, roberta_token)
def _loadTextWeights(self, device, roberta_token):
repo_id = "LomaaZakaria/Roberta_Attribute_Value_Extraction_Model"
weights_file_name = "RobertaCRFWithNOAnswerClassifier_OnFashionGenData_2epochs.pth"
weights_file_path = hf_hub_download(
repo_id=repo_id, filename=weights_file_name, token=roberta_token, cache_dir=HF_CACHE_DIR)
state_dict = torch.load(
weights_file_path, weights_only=True, map_location=device)
        text_model_state_dict = self.roberta.state_dict()
        # Keep only checkpoint tensors whose names and shapes match the RoBERTa backbone;
        # anything else (e.g. CRF or classifier heads from fine-tuning) is ignored.
        filtered_state_dict = {
            k: v for k, v in state_dict.items()
            if k in text_model_state_dict and v.shape == text_model_state_dict[k].shape
        }
        self.roberta.load_state_dict(filtered_state_dict, strict=False)
    def freeze_layers(self):
        # Freeze the embeddings and the first 8 encoder layers; only the top layers stay trainable.
        self.roberta.embeddings.requires_grad_(False)
        for layer in self.roberta.encoder.layer[:8]:
            for p in layer.parameters():
                p.requires_grad = False
def forward(self, token_ids, attention_mask):
outputs = self.roberta(input_ids=token_ids,
attention_mask=attention_mask)
last_hidden_state = outputs.hidden_states[-1]
return last_hidden_state
class ImageModel(nn.Module):
def __init__(self):
super(ImageModel, self).__init__()
self.vision_model = BlipForQuestionAnswering.from_pretrained(
"Salesforce/blip-vqa-base", cache_dir=HF_CACHE_DIR).vision_model
self._freezeLayers()
def _freezeLayers(self):
self.vision_model.embeddings.requires_grad_(False)
for layer in self.vision_model.encoder.layers[:8]:
for p in layer.parameters():
p.requires_grad = False
def forward(self, x):
return self.vision_model(x).last_hidden_state
class MergerModel(nn.Module):
def __init__(self, vocab_size, device, roberta_token=None):
if roberta_token is None:
roberta_token = os.getenv("ROBERTA_TOKEN")
super().__init__()
self.text_decoder = BlipForQuestionAnswering.from_pretrained(
"Salesforce/blip-vqa-base", cache_dir=HF_CACHE_DIR).text_decoder
self.text_encoder = RobertaTokenClassificationWithCRF(
vocab_size, device, roberta_token)
self.vision_model = ImageModel()
text_dim, image_dim = self.text_encoder.config.hidden_size, 768
self.attention_merger = AttentionModalityMerger(text_dim, image_dim)
self.linear = nn.Linear(in_features=text_dim +
image_dim, out_features=text_dim)
def forward(self, **inputs):
text_encoder = self.text_encoder(
token_ids=inputs['encoder_token_ids'], attention_mask=inputs['encoder_attention_mask'])
vision_encoder = self.vision_model(x=inputs['image'])
merger_output, p_txt, p_img = self.attention_merger(
text_encoder, vision_encoder, attention_mask=inputs['encoder_attention_mask'])
merger_output = merger_output.unsqueeze(1)
batch_size = vision_encoder.shape[0]
merger_output_mask = torch.ones(
(batch_size, 1), dtype=torch.long, device=vision_encoder.device)
merger_output_linear = self.linear(merger_output)
decoder_output = self.text_decoder(
input_ids=inputs['decoder_input_token_ids'],
attention_mask=inputs['decoder_input_attention_mask'],
encoder_hidden_states=merger_output_linear,
encoder_attention_mask=merger_output_mask,
return_dict=True,
return_logits=True
)
logits = decoder_output
return logits, p_txt, p_img
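# Note on the fusion step above: the gated text/image vector is projected back to the text hidden
# size and exposed to the BLIP text decoder as a single-position encoder sequence, so the decoder
# cross-attends to one fused vector rather than to the full token/patch feature maps. With
# return_logits=True the decoder call is expected to return raw vocabulary logits, which
# model_generate consumes directly.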
def load_merger_model(bert_tokenizer, device, model_token=None):
if model_token is None:
model_token = os.getenv("MERGER_MODEL_TOKEN")
print("MERGER_MODEL_TOKEN is set:", model_token is not None)
vocab_size = len(bert_tokenizer)
model = MergerModel(vocab_size, device)
repo_id = "MohamedMosilhy/AttentionMergerModality"
weights_file_name = "Freezing_More_NewViTBlipAttentionMergerModality_4epochs_2e_5_withwarmup.pth"
weights_file_path = hf_hub_download(
repo_id=repo_id, filename=weights_file_name, token=model_token, cache_dir=HF_CACHE_DIR)
model.load_state_dict(torch.load(
weights_file_path, weights_only=True, map_location=device))
model.to(device)
model.eval()
return model
def model_generate(model, data, text_tokenizer, device, labels=None, max_generated_length=50, testing=False, return_confidence=False):
    if labels is None:
        labels = '[DEC]'
    # Seed the output with the ids of the starting label; generated token ids are appended below.
    token_labels = text_tokenizer.convert_tokens_to_ids([labels])
model.eval()
confidences = []
for index in range(max_generated_length):
decoder_inputs = text_tokenizer(
text=labels, max_length=65, padding='max_length', add_special_tokens=False, return_tensors="pt")
decoder_data = {
"decoder_input_token_ids": decoder_inputs['input_ids'],
"decoder_input_attention_mask": decoder_inputs['attention_mask']
}
inputs = {
"image": data['image'].unsqueeze(0).to(device),
"encoder_token_ids": data['encoder_token_ids'].unsqueeze(0).to(device),
"encoder_attention_mask": data['encoder_attention_mask'].unsqueeze(0).to(device),
"decoder_input_token_ids": decoder_data['decoder_input_token_ids'].to(device),
"decoder_input_attention_mask": decoder_data['decoder_input_attention_mask'].to(device)
}
with torch.no_grad():
logits, _, _ = model(**inputs)
probs = F.softmax(logits, dim=-1)
predicated_label = torch.argmax(
probs[:, index, :], dim=-1).cpu().numpy()
# Get confidence for this token
confidence = float(
probs[0, index, predicated_label[0]].cpu().item())
confidences.append(confidence)
token_labels.append(predicated_label[0])
predicted_tokens = text_tokenizer.convert_ids_to_tokens(
predicated_label)
labels = text_tokenizer.decode(token_labels)
if predicted_tokens[0] == text_tokenizer.sep_token:
break
predicated_attribute_value = text_tokenizer.decode(token_labels)
if testing:
token_labels = np.array(token_labels)
dec_token_id = text_tokenizer.bos_token_id
token_labels = token_labels[token_labels != dec_token_id]
return token_labels
if return_confidence:
# Use the minimum confidence across the generated tokens as the attribute confidence
return predicated_attribute_value, min(confidences) if confidences else 0.0
return predicated_attribute_value
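# model_generate performs greedy decoding: at each step the running decoded string is re-tokenized,
# the most likely next token is appended, and generation stops at the tokenizer's [SEP] token or
# after max_generated_length steps. The reported confidence is the minimum per-token probability
# over the generated sequence.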
# Define which attributes are relevant for each category
CATEGORY_ATTRIBUTES = {
"clothing": ['sleeve', 'type', 'pattern', 'material', 'neck', 'color', 'style', 'brand', 'gender'],
"bags": ['type', 'pattern', 'material', 'color', 'style', 'brand', 'gender'],
"shoes": ['type', 'pattern', 'material', 'color', 'style', 'brand', 'gender'],
"accessories": ['type', 'pattern', 'material', 'color', 'style', 'brand', 'gender'],
}
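# Categories not listed above fall back to the full clothing attribute set (see get_predicated_values).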
def get_predicated_values(
model, category, img, desc, image_processor, bert_tokenizer, roberta_tokenizer, device, max_seq_length=256
):
results = []
    def _combined_with_CategoriesAttributes(desc, category, attribute):
        # desc is not used here; the description is passed to the tokenizer separately as the second segment.
        return category + ' ' + attribute
def imageProcesser(img):
return image_processor(img)
def _tokenizeText(image, desc, category, attribute):
combined_desc = _combined_with_CategoriesAttributes(
desc, category, attribute)
image_inputs = imageProcesser(image)
        # "category attribute" forms the query segment; the raw description is the second segment.
        text_encoder_inputs = roberta_tokenizer(
            combined_desc,
            desc,
            max_length=max_seq_length,
            padding='max_length',
            truncation=True,
            return_tensors='np'
        )
return image_inputs, text_encoder_inputs
# Normalize category to lower-case and pick attributes
category_key = str(category).strip().lower()
attributes = CATEGORY_ATTRIBUTES.get(category_key, CATEGORY_ATTRIBUTES["clothing"])
image = img
for attribute in attributes:
image_inputs, text_encoder_inputs = _tokenizeText(
image, desc, category, attribute)
image_data = torch.from_numpy(np.array(image_inputs['pixel_values']))
encoder_token_ids = torch.from_numpy(
np.array(text_encoder_inputs['input_ids']))
encoder_attn_mask = torch.from_numpy(
np.array(text_encoder_inputs['attention_mask']))
inputs = {
"image": image_data.squeeze(0),
"encoder_token_ids": encoder_token_ids.squeeze(0),
"encoder_attention_mask": encoder_attn_mask.squeeze(0),
}
predicated_value, confidence = model_generate(
model, inputs, text_tokenizer=bert_tokenizer, device=device, return_confidence=True
)
# Remove [DEC] and [SEP] tokens and strip whitespace
clean_value = predicated_value.replace('[DEC]', '').replace('[SEP]', '').strip()
if clean_value != 'not specified':
results.append(
{"name": attribute, "value": clean_value,
"confidence": float(confidence)}
)
return results
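# Minimal usage sketch (illustrative only): assumes the required HF tokens are available via the
# environment and that "example.jpg" is a hypothetical local product image.
if __name__ == "__main__":
    from PIL import Image

    device = get_device()
    bert_tokenizer, roberta_tokenizer = get_tokenizers()
    image_processor = get_image_processor()
    merger_model = load_merger_model(bert_tokenizer, device)

    example_image = Image.open("example.jpg").convert("RGB")  # hypothetical input image
    example_desc = "red cotton t-shirt with short sleeves"    # hypothetical description
    predictions = get_predicated_values(
        merger_model, "clothing", example_image, example_desc,
        image_processor, bert_tokenizer, roberta_tokenizer, device)
    print(predictions)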