from transformers import (AutoProcessor, RobertaConfig, BertTokenizerFast,
                          RobertaTokenizerFast, RobertaModel, BlipForQuestionAnswering)
from huggingface_hub import hf_hub_download
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os

# Load environment variables (optional for local dev; Spaces use web UI for env vars)
if os.path.exists('.env'):
    from dotenv import load_dotenv
    load_dotenv()

ATTRIBUTES_LIST = ['sleeve', 'type', 'pattern', 'material', 'neck',
                   'color', 'style', 'brand', 'gender']
HF_CACHE_DIR = "./hf_cache"


def get_device():
    return "cuda" if torch.cuda.is_available() else "cpu"


def get_tokenizers():
    bert_tokenizer = BertTokenizerFast.from_pretrained(
        "google-bert/bert-base-uncased", cache_dir=HF_CACHE_DIR)
    roberta_tokenizer = RobertaTokenizerFast.from_pretrained(
        "FacebookAI/roberta-base", cache_dir=HF_CACHE_DIR)
    bert_tokenizer.add_special_tokens({'bos_token': '[DEC]'})
    return bert_tokenizer, roberta_tokenizer


def get_image_processor():
    return AutoProcessor.from_pretrained(
        "Salesforce/blip-image-captioning-base", cache_dir=HF_CACHE_DIR)


class AttentionModalityMerger(nn.Module):
    def __init__(self, text_dim, image_dim):
        super().__init__()
        self.text_layer_norm = nn.LayerNorm(text_dim)
        self.image_layer_norm = nn.LayerNorm(image_dim)
        self.linear = nn.Linear(in_features=image_dim + text_dim, out_features=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, text_embedds, image_features, attention_mask):
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(text_embedds.size()).float()
        text_embedds = input_mask_expanded * text_embedds
        text_embedds = text_embedds.sum(dim=1)
        text_embedds_norm = self.text_layer_norm(text_embedds)

        image_features = image_features.sum(dim=1)
        image_features_norm = self.image_layer_norm(image_features)

        text_image_embedds = torch.cat([text_embedds_norm, image_features_norm], dim=-1)
        gate_output = self.linear(text_image_embedds)
        p_txt = self.sigmoid(gate_output)
        p_img = 1 - p_txt

        scaled_text = p_txt * text_embedds_norm
        scaled_image = p_img * image_features_norm
        final_output = torch.cat([scaled_text, scaled_image], dim=-1)
        return final_output, p_txt, p_img


class RobertaTokenClassificationWithCRF(nn.Module):
    def __init__(self, vocab_size, device, roberta_token=None):
        if roberta_token is None:
            roberta_token = os.getenv("ROBERTA_TOKEN")
        super().__init__()
        self.vocab_size = vocab_size
        self.config = RobertaConfig()
        self.roberta = RobertaModel.from_pretrained(
            "FacebookAI/roberta-base", output_hidden_states=True, cache_dir=HF_CACHE_DIR)
        self.freeze_layers()
        self._loadTextWeights(device, roberta_token)

    def _loadTextWeights(self, device, roberta_token):
        repo_id = "LomaaZakaria/Roberta_Attribute_Value_Extraction_Model"
        weights_file_name = "RobertaCRFWithNOAnswerClassifier_OnFashionGenData_2epochs.pth"
        weights_file_path = hf_hub_download(
            repo_id=repo_id, filename=weights_file_name,
            token=roberta_token, cache_dir=HF_CACHE_DIR)
        state_dict = torch.load(weights_file_path, weights_only=True, map_location=device)
        text_model_state_dict = self.roberta.state_dict()
        filtered_state_dict = {
            k: v for k, v in state_dict.items()
            if k in text_model_state_dict and v.shape == text_model_state_dict[k].shape
        }
        self.roberta.load_state_dict(filtered_state_dict, strict=False)

    def freeze_layers(self):
        self.roberta.embeddings.requires_grad_(False)
        for layer in self.roberta.encoder.layer[:8]:
            for p in layer.parameters():
                p.requires_grad = False

    def forward(self, token_ids, attention_mask):
        outputs = self.roberta(input_ids=token_ids,
                               attention_mask=attention_mask)
        last_hidden_state = outputs.hidden_states[-1]
        return last_hidden_state


class ImageModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.vision_model = BlipForQuestionAnswering.from_pretrained(
            "Salesforce/blip-vqa-base", cache_dir=HF_CACHE_DIR).vision_model
        self._freezeLayers()

    def _freezeLayers(self):
        self.vision_model.embeddings.requires_grad_(False)
        for layer in self.vision_model.encoder.layers[:8]:
            for p in layer.parameters():
                p.requires_grad = False

    def forward(self, x):
        return self.vision_model(x).last_hidden_state


class MergerModel(nn.Module):
    def __init__(self, vocab_size, device, roberta_token=None):
        if roberta_token is None:
            roberta_token = os.getenv("ROBERTA_TOKEN")
        super().__init__()
        self.text_decoder = BlipForQuestionAnswering.from_pretrained(
            "Salesforce/blip-vqa-base", cache_dir=HF_CACHE_DIR).text_decoder
        self.text_encoder = RobertaTokenClassificationWithCRF(
            vocab_size, device, roberta_token)
        self.vision_model = ImageModel()
        text_dim, image_dim = self.text_encoder.config.hidden_size, 768
        self.attention_merger = AttentionModalityMerger(text_dim, image_dim)
        self.linear = nn.Linear(in_features=text_dim + image_dim, out_features=text_dim)

    def forward(self, **inputs):
        text_encoder = self.text_encoder(
            token_ids=inputs['encoder_token_ids'],
            attention_mask=inputs['encoder_attention_mask'])
        vision_encoder = self.vision_model(x=inputs['image'])
        merger_output, p_txt, p_img = self.attention_merger(
            text_encoder, vision_encoder,
            attention_mask=inputs['encoder_attention_mask'])

        # The fused vector is treated as a single-step encoder sequence for the BLIP decoder.
        merger_output = merger_output.unsqueeze(1)
        batch_size = vision_encoder.shape[0]
        merger_output_mask = torch.ones(
            (batch_size, 1), dtype=torch.long, device=vision_encoder.device)
        merger_output_linear = self.linear(merger_output)

        # return_logits=True makes the BLIP text decoder return the logits tensor directly.
        decoder_output = self.text_decoder(
            input_ids=inputs['decoder_input_token_ids'],
            attention_mask=inputs['decoder_input_attention_mask'],
            encoder_hidden_states=merger_output_linear,
            encoder_attention_mask=merger_output_mask,
            return_dict=True,
            return_logits=True
        )
        logits = decoder_output
        return logits, p_txt, p_img


def load_merger_model(bert_tokenizer, device, model_token=None):
    if model_token is None:
        model_token = os.getenv("MERGER_MODEL_TOKEN")
    print("MERGER_MODEL_TOKEN is set:", model_token is not None)

    vocab_size = len(bert_tokenizer)
    model = MergerModel(vocab_size, device)

    repo_id = "MohamedMosilhy/AttentionMergerModality"
    weights_file_name = "Freezing_More_NewViTBlipAttentionMergerModality_4epochs_2e_5_withwarmup.pth"
    weights_file_path = hf_hub_download(
        repo_id=repo_id, filename=weights_file_name,
        token=model_token, cache_dir=HF_CACHE_DIR)
    model.load_state_dict(torch.load(
        weights_file_path, weights_only=True, map_location=device))
    model.to(device)
    model.eval()
    return model


def model_generate(model, data, text_tokenizer, device, labels=None,
                   max_generated_length=50, testing=False, return_confidence=False):
    # Greedy token-by-token decoding: the running text in `labels` is re-encoded
    # every step and fed back to the decoder.
    if labels is None:
        labels = '[DEC]'
    token_labels = text_tokenizer.convert_tokens_to_ids([labels])

    model.eval()
    confidences = []
    for index in range(max_generated_length):
        decoder_inputs = text_tokenizer(
            text=labels, max_length=65, padding='max_length',
            add_special_tokens=False, return_tensors="pt")
        decoder_data = {
            "decoder_input_token_ids": decoder_inputs['input_ids'],
            "decoder_input_attention_mask": decoder_inputs['attention_mask']
        }
        inputs = {
            "image": data['image'].unsqueeze(0).to(device),
            "encoder_token_ids":
                data['encoder_token_ids'].unsqueeze(0).to(device),
            "encoder_attention_mask": data['encoder_attention_mask'].unsqueeze(0).to(device),
            "decoder_input_token_ids": decoder_data['decoder_input_token_ids'].to(device),
            "decoder_input_attention_mask": decoder_data['decoder_input_attention_mask'].to(device)
        }
        with torch.no_grad():
            logits, _, _ = model(**inputs)
        probs = F.softmax(logits, dim=-1)
        predicated_label = torch.argmax(probs[:, index, :], dim=-1).cpu().numpy()

        # Get confidence for this token
        confidence = float(probs[0, index, predicated_label[0]].cpu().item())
        confidences.append(confidence)

        token_labels.append(predicated_label[0])
        predicted_tokens = text_tokenizer.convert_ids_to_tokens(predicated_label)
        labels = text_tokenizer.decode(token_labels)
        # Stop once the decoder emits the [SEP] token.
        if predicted_tokens[0] == text_tokenizer.sep_token:
            break

    predicated_attribute_value = text_tokenizer.decode(token_labels)

    if testing:
        # Return raw token ids with the [DEC] start token stripped out.
        token_labels = np.array(token_labels)
        dec_token_id = text_tokenizer.bos_token_id
        token_labels = token_labels[token_labels != dec_token_id]
        return token_labels

    if return_confidence:
        # Use the minimum confidence across the generated tokens as the attribute confidence
        return predicated_attribute_value, (min(confidences) if confidences else 0.0)
    return predicated_attribute_value


# Define which attributes are relevant for each category
CATEGORY_ATTRIBUTES = {
    "clothing": ['sleeve', 'type', 'pattern', 'material', 'neck', 'color', 'style', 'brand', 'gender'],
    "bags": ['type', 'pattern', 'material', 'color', 'style', 'brand', 'gender'],
    "shoes": ['type', 'pattern', 'material', 'color', 'style', 'brand', 'gender'],
    "accessories": ['type', 'pattern', 'material', 'color', 'style', 'brand', 'gender'],
}


def get_predicated_values(
    model, category, img, desc, image_processor,
    bert_tokenizer, roberta_tokenizer, device, max_seq_length=256
):
    results = []

    def _combined_with_CategoriesAttributes(desc, category, attribute):
        # The "question" fed to the text encoder is the category plus the attribute name.
        return category + ' ' + attribute

    def imageProcesser(img):
        return image_processor(img)

    def _tokenizeText(image, desc, category, attribute):
        combined_desc = _combined_with_CategoriesAttributes(desc, category, attribute)
        image_inputs = imageProcesser(image)
        text_encoder_inputs = roberta_tokenizer(
            combined_desc, desc,
            max_length=max_seq_length, padding='max_length', return_tensors='np'
        )
        return image_inputs, text_encoder_inputs

    # Normalize category to lower-case and pick attributes
    category_key = str(category).strip().lower()
    attributes = CATEGORY_ATTRIBUTES.get(category_key, CATEGORY_ATTRIBUTES["clothing"])

    image = img
    for attribute in attributes:
        image_inputs, text_encoder_inputs = _tokenizeText(image, desc, category, attribute)

        image_data = torch.from_numpy(np.array(image_inputs['pixel_values']))
        encoder_token_ids = torch.from_numpy(np.array(text_encoder_inputs['input_ids']))
        encoder_attn_mask = torch.from_numpy(np.array(text_encoder_inputs['attention_mask']))

        inputs = {
            "image": image_data.squeeze(0),
            "encoder_token_ids": encoder_token_ids.squeeze(0),
            "encoder_attention_mask": encoder_attn_mask.squeeze(0),
        }
        predicated_value, confidence = model_generate(
            model, inputs, text_tokenizer=bert_tokenizer,
            device=device, return_confidence=True
        )

        # Remove [DEC] and [SEP] tokens and strip whitespace
        clean_value = predicated_value.replace('[DEC]', '').replace('[SEP]', '').strip()
        if clean_value != 'not specified':
            results.append(
                {"name": attribute, "value": clean_value, "confidence": float(confidence)}
            )
    return results
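

# Minimal usage sketch, assuming a local product image and a free-text description
# are available; "example_product.jpg", the category, and the description below are
# placeholder values, and access tokens are read from the environment as in the
# loaders above.
if __name__ == "__main__":
    from PIL import Image

    device = get_device()
    bert_tokenizer, roberta_tokenizer = get_tokenizers()
    image_processor = get_image_processor()
    model = load_merger_model(bert_tokenizer, device)

    image = Image.open("example_product.jpg").convert("RGB")  # placeholder path
    attributes = get_predicated_values(
        model,
        category="clothing",
        img=image,
        desc="red cotton t-shirt with short sleeves",
        image_processor=image_processor,
        bert_tokenizer=bert_tokenizer,
        roberta_tokenizer=roberta_tokenizer,
        device=device,
    )
    print(attributes)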