from transformers import (AutoProcessor,
                          RobertaConfig,
                          BertTokenizerFast,
                          RobertaTokenizerFast,
                          RobertaModel,
                          BlipForQuestionAnswering)
from huggingface_hub import hf_hub_download
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os

# Load environment variables (optional for local dev; Spaces set env vars via the web UI)
if os.path.exists('.env'):
    from dotenv import load_dotenv
    load_dotenv()

ATTRIBUTES_LIST = ['sleeve', 'type', 'pattern', 'material',
                   'neck', 'color', 'style', 'brand', 'gender']

HF_CACHE_DIR = "./hf_cache"


def get_device():
    return "cuda" if torch.cuda.is_available() else "cpu"


def get_tokenizers():
    bert_tokenizer = BertTokenizerFast.from_pretrained(
        "google-bert/bert-base-uncased", cache_dir=HF_CACHE_DIR)
    roberta_tokenizer = RobertaTokenizerFast.from_pretrained(
        "FacebookAI/roberta-base", cache_dir=HF_CACHE_DIR)
    # The BLIP decoder expects a [DEC] token marking the start of generation.
    bert_tokenizer.add_special_tokens({'bos_token': '[DEC]'})
    return bert_tokenizer, roberta_tokenizer


def get_image_processor():
    return AutoProcessor.from_pretrained(
        "Salesforce/blip-image-captioning-base", cache_dir=HF_CACHE_DIR)
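

# Quick hedged check of the tokenizer setup (the sample string is illustrative,
# not data used anywhere in this Space):
def _demo_tokenizers():
    bert_tok, roberta_tok = get_tokenizers()
    assert bert_tok.bos_token == '[DEC]'  # decoder start marker added above
    ids = roberta_tok("red shirt")["input_ids"]
    assert roberta_tok.decode(ids, skip_special_tokens=True).strip() == "red shirt"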


class AttentionModalityMerger(nn.Module):
    """Learned sigmoid gate that weighs pooled text features against pooled
    image features before handing a fused vector to the decoder."""

    def __init__(self, text_dim, image_dim):
        super().__init__()
        self.text_layer_norm = nn.LayerNorm(text_dim)
        self.image_layer_norm = nn.LayerNorm(image_dim)
        self.linear = nn.Linear(
            in_features=image_dim + text_dim, out_features=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, text_embeds, image_features, attention_mask):
        # Zero out padded positions before sum-pooling the text sequence.
        input_mask_expanded = attention_mask.unsqueeze(
            -1).expand(text_embeds.size()).float()
        text_embeds = input_mask_expanded * text_embeds
        text_embeds = text_embeds.sum(dim=1)
        text_embeds_norm = self.text_layer_norm(text_embeds)

        # Sum-pool the image patch features.
        image_features = image_features.sum(dim=1)
        image_features_norm = self.image_layer_norm(image_features)

        # A single sigmoid gate splits the weight between the two modalities.
        text_image_embeds = torch.cat(
            [text_embeds_norm, image_features_norm], dim=-1)
        gate_output = self.linear(text_image_embeds)
        p_txt = self.sigmoid(gate_output)
        p_img = 1 - p_txt
        scaled_text = p_txt * text_embeds_norm
        scaled_image = p_img * image_features_norm
        final_output = torch.cat([scaled_text, scaled_image], dim=-1)
        return final_output, p_txt, p_img
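

# A minimal, hedged sketch of the gate on random tensors (shapes here are
# illustrative assumptions: batch 2, 16 text tokens, 577 image patches, dim 768).
def _demo_attention_merger():
    merger = AttentionModalityMerger(text_dim=768, image_dim=768)
    text = torch.randn(2, 16, 768)
    image = torch.randn(2, 577, 768)
    mask = torch.ones(2, 16, dtype=torch.long)
    merged, p_txt, p_img = merger(text, image, mask)
    assert merged.shape == (2, 1536)  # gated text and image halves, concatenated
    assert torch.allclose(p_txt + p_img, torch.ones_like(p_txt))  # weights sum to 1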


class RobertaTokenClassificationWithCRF(nn.Module):
    """RoBERTa encoder from a CRF-headed fine-tune; only the encoder weights
    are loaded here, since inference uses just the hidden states."""

    def __init__(self, vocab_size, device, roberta_token=None):
        super().__init__()
        if roberta_token is None:
            roberta_token = os.getenv("ROBERTA_TOKEN")
        self.vocab_size = vocab_size
        self.config = RobertaConfig()
        self.roberta = RobertaModel.from_pretrained(
            "FacebookAI/roberta-base", output_hidden_states=True, cache_dir=HF_CACHE_DIR)
        self.freeze_layers()
        self._loadTextWeights(device, roberta_token)

    def _loadTextWeights(self, device, roberta_token):
        repo_id = "LomaaZakaria/Roberta_Attribute_Value_Extraction_Model"
        weights_file_name = "RobertaCRFWithNOAnswerClassifier_OnFashionGenData_2epochs.pth"
        weights_file_path = hf_hub_download(
            repo_id=repo_id, filename=weights_file_name, token=roberta_token, cache_dir=HF_CACHE_DIR)
        state_dict = torch.load(
            weights_file_path, weights_only=True, map_location=device)
        # Keep only weights that belong to the RoBERTa encoder (this drops the
        # CRF/classifier parameters from the fine-tuned checkpoint).
        text_model_state_dict = self.roberta.state_dict()
        filtered_state_dict = {
            k: v for k, v in state_dict.items()
            if k in text_model_state_dict and v.shape == text_model_state_dict[k].shape
        }
        self.roberta.load_state_dict(filtered_state_dict, strict=False)

    def freeze_layers(self):
        # Freeze the embeddings and the first 8 encoder layers.
        self.roberta.embeddings.requires_grad_(False)
        for layer in self.roberta.encoder.layer[:8]:
            for p in layer.parameters():
                p.requires_grad = False

    def forward(self, token_ids, attention_mask):
        outputs = self.roberta(input_ids=token_ids,
                               attention_mask=attention_mask)
        last_hidden_state = outputs.hidden_states[-1]
        return last_hidden_state


class ImageModel(nn.Module):
    """BLIP-VQA vision encoder with the embeddings and first 8 layers frozen."""

    def __init__(self):
        super().__init__()
        self.vision_model = BlipForQuestionAnswering.from_pretrained(
            "Salesforce/blip-vqa-base", cache_dir=HF_CACHE_DIR).vision_model
        self._freezeLayers()

    def _freezeLayers(self):
        self.vision_model.embeddings.requires_grad_(False)
        for layer in self.vision_model.encoder.layers[:8]:
            for p in layer.parameters():
                p.requires_grad = False

    def forward(self, x):
        return self.vision_model(x).last_hidden_state
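

# Hedged sanity-check sketch: downloads Salesforce/blip-vqa-base on first call;
# 577 = 1 [CLS] + 24x24 patches assumes BLIP's default 384x384 input resolution.
def _demo_image_model():
    vision = ImageModel().eval()
    pixels = torch.randn(1, 3, 384, 384)  # dummy pixel values, not a real image
    with torch.no_grad():
        feats = vision(pixels)
    assert feats.shape == (1, 577, 768)  # (batch, patches + [CLS], hidden)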


class MergerModel(nn.Module):
    """Full pipeline: RoBERTa text encoder + BLIP vision encoder, fused by the
    attention gate, then decoded with the BLIP text decoder."""

    def __init__(self, vocab_size, device, roberta_token=None):
        super().__init__()
        if roberta_token is None:
            roberta_token = os.getenv("ROBERTA_TOKEN")
        self.text_decoder = BlipForQuestionAnswering.from_pretrained(
            "Salesforce/blip-vqa-base", cache_dir=HF_CACHE_DIR).text_decoder
        self.text_encoder = RobertaTokenClassificationWithCRF(
            vocab_size, device, roberta_token)
        self.vision_model = ImageModel()
        text_dim, image_dim = self.text_encoder.config.hidden_size, 768
        self.attention_merger = AttentionModalityMerger(text_dim, image_dim)
        # Project the concatenated (text + image) vector back to text_dim so it
        # can serve as the decoder's cross-attention memory.
        self.linear = nn.Linear(in_features=text_dim + image_dim,
                                out_features=text_dim)

    def forward(self, **inputs):
        text_features = self.text_encoder(
            token_ids=inputs['encoder_token_ids'],
            attention_mask=inputs['encoder_attention_mask'])
        image_features = self.vision_model(x=inputs['image'])
        merger_output, p_txt, p_img = self.attention_merger(
            text_features, image_features,
            attention_mask=inputs['encoder_attention_mask'])
        # The fused vector acts as a single-token encoder sequence for the decoder.
        merger_output = merger_output.unsqueeze(1)
        batch_size = image_features.shape[0]
        merger_output_mask = torch.ones(
            (batch_size, 1), dtype=torch.long, device=image_features.device)
        merger_output_linear = self.linear(merger_output)
        logits = self.text_decoder(
            input_ids=inputs['decoder_input_token_ids'],
            attention_mask=inputs['decoder_input_attention_mask'],
            encoder_hidden_states=merger_output_linear,
            encoder_attention_mask=merger_output_mask,
            return_dict=True,
            return_logits=True
        )
        return logits, p_txt, p_img
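

# Expected input shapes for MergerModel.forward, inferred from the helpers
# below (the sequence lengths come from their max_length settings; the 384x384
# image size is BLIP's default and is an assumption here):
#   image:                        (B, 3, 384, 384) float pixel values
#   encoder_token_ids:            (B, 256) RoBERTa ids for the query/description pair
#   encoder_attention_mask:       (B, 256)
#   decoder_input_token_ids:      (B, 65) ids of the running [DEC]... string
#   decoder_input_attention_mask: (B, 65)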


def load_merger_model(bert_tokenizer, device, model_token=None):
    if model_token is None:
        model_token = os.getenv("MERGER_MODEL_TOKEN")
    print("MERGER_MODEL_TOKEN is set:", model_token is not None)
    vocab_size = len(bert_tokenizer)
    model = MergerModel(vocab_size, device)
    repo_id = "MohamedMosilhy/AttentionMergerModality"
    weights_file_name = "Freezing_More_NewViTBlipAttentionMergerModality_4epochs_2e_5_withwarmup.pth"
    weights_file_path = hf_hub_download(
        repo_id=repo_id, filename=weights_file_name, token=model_token, cache_dir=HF_CACHE_DIR)
    model.load_state_dict(torch.load(
        weights_file_path, weights_only=True, map_location=device))
    model.to(device)
    model.eval()
    return model


def model_generate(model, data, text_tokenizer, device, labels=None,
                   max_generated_length=50, testing=False, return_confidence=False):
    # Greedy decoding: start from the [DEC] (BOS) token, re-encode the running
    # string each step, and take the argmax at the current position.
    if labels is None:
        labels = '[DEC]'
    token_labels = text_tokenizer.convert_tokens_to_ids([labels])
    model.eval()
    confidences = []
    for index in range(max_generated_length):
        decoder_inputs = text_tokenizer(
            text=labels, max_length=65, padding='max_length',
            add_special_tokens=False, return_tensors="pt")
        inputs = {
            "image": data['image'].unsqueeze(0).to(device),
            "encoder_token_ids": data['encoder_token_ids'].unsqueeze(0).to(device),
            "encoder_attention_mask": data['encoder_attention_mask'].unsqueeze(0).to(device),
            "decoder_input_token_ids": decoder_inputs['input_ids'].to(device),
            "decoder_input_attention_mask": decoder_inputs['attention_mask'].to(device)
        }
        with torch.no_grad():
            logits, _, _ = model(**inputs)
        probs = F.softmax(logits, dim=-1)
        predicted_label = torch.argmax(
            probs[:, index, :], dim=-1).cpu().numpy()
        # Record the probability of the chosen token as its confidence.
        confidence = float(probs[0, index, predicted_label[0]].cpu().item())
        confidences.append(confidence)
        token_labels.append(predicted_label[0])
        predicted_tokens = text_tokenizer.convert_ids_to_tokens(predicted_label)
        labels = text_tokenizer.decode(token_labels)
        # Stop once the model emits [SEP].
        if predicted_tokens[0] == text_tokenizer.sep_token:
            break
    predicted_attribute_value = text_tokenizer.decode(token_labels)
    if testing:
        # Return raw token ids with the [DEC] marker stripped.
        token_labels = np.array(token_labels)
        dec_token_id = text_tokenizer.bos_token_id
        token_labels = token_labels[token_labels != dec_token_id]
        return token_labels
    if return_confidence:
        # The attribute confidence is the minimum per-token probability
        # across the generated sequence.
        return predicted_attribute_value, min(confidences) if confidences else 0.0
    return predicted_attribute_value
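

# A hedged, self-contained illustration of the confidence rule above: greedy
# decoding keeps each step's argmax probability, and the final score is the
# minimum over steps (the probabilities below are made up for the demo).
def _demo_min_confidence():
    step_probs = torch.tensor([[0.9, 0.1], [0.6, 0.4], [0.2, 0.8]])
    confidences = []
    for step in step_probs:
        token = torch.argmax(step).item()
        confidences.append(step[token].item())
    assert abs(min(confidences) - 0.6) < 1e-6  # the weakest step sets the score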


# Define which attributes are relevant for each category
CATEGORY_ATTRIBUTES = {
    "clothing": ['sleeve', 'type', 'pattern', 'material', 'neck', 'color', 'style', 'brand', 'gender'],
    "bags": ['type', 'pattern', 'material', 'color', 'style', 'brand', 'gender'],
    "shoes": ['type', 'pattern', 'material', 'color', 'style', 'brand', 'gender'],
    "accessories": ['type', 'pattern', 'material', 'color', 'style', 'brand', 'gender'],
}


def get_predicated_values(
    model, category, img, desc, image_processor, bert_tokenizer,
    roberta_tokenizer, device, max_seq_length=256
):
    results = []

    def _build_query(category, attribute):
        # The query is "<category> <attribute>"; the free-text description is
        # passed separately as the tokenizer's second segment below.
        return category + ' ' + attribute

    def _tokenize(image, desc, category, attribute):
        query = _build_query(category, attribute)
        image_inputs = image_processor(image)
        text_encoder_inputs = roberta_tokenizer(
            query,
            desc,
            max_length=max_seq_length,
            padding='max_length',
            return_tensors='np'
        )
        return image_inputs, text_encoder_inputs

    # Normalize the category and pick its attribute set (clothing is the fallback).
    category_key = str(category).strip().lower()
    attributes = CATEGORY_ATTRIBUTES.get(category_key, CATEGORY_ATTRIBUTES["clothing"])

    for attribute in attributes:
        image_inputs, text_encoder_inputs = _tokenize(img, desc, category, attribute)
        image_data = torch.from_numpy(np.array(image_inputs['pixel_values']))
        encoder_token_ids = torch.from_numpy(
            np.array(text_encoder_inputs['input_ids']))
        encoder_attn_mask = torch.from_numpy(
            np.array(text_encoder_inputs['attention_mask']))
        inputs = {
            "image": image_data.squeeze(0),
            "encoder_token_ids": encoder_token_ids.squeeze(0),
            "encoder_attention_mask": encoder_attn_mask.squeeze(0),
        }
        predicted_value, confidence = model_generate(
            model, inputs, text_tokenizer=bert_tokenizer, device=device,
            return_confidence=True
        )
        # Remove [DEC] and [SEP] tokens and strip whitespace.
        clean_value = predicted_value.replace('[DEC]', '').replace('[SEP]', '').strip()
        # Skip attributes the model could not determine.
        if clean_value != 'not specified':
            results.append(
                {"name": attribute, "value": clean_value,
                 "confidence": float(confidence)}
            )
    return results
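

# Hedged end-to-end sketch: assumes network access plus ROBERTA_TOKEN /
# MERGER_MODEL_TOKEN in the environment; "shirt.jpg" and the description are
# placeholders, not files or data shipped with this Space.
if __name__ == "__main__":
    from PIL import Image

    device = get_device()
    bert_tokenizer, roberta_tokenizer = get_tokenizers()
    image_processor = get_image_processor()
    model = load_merger_model(bert_tokenizer, device)

    image = Image.open("shirt.jpg").convert("RGB")
    attributes = get_predicated_values(
        model, category="clothing", img=image,
        desc="red cotton shirt with long sleeves",
        image_processor=image_processor,
        bert_tokenizer=bert_tokenizer,
        roberta_tokenizer=roberta_tokenizer,
        device=device,
    )
    for attr in attributes:
        print(f"{attr['name']}: {attr['value']} ({attr['confidence']:.2f})")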