import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoConfig
import json
from torch.utils.data import Dataset, DataLoader

instruct_dataset = './llava_instruct_150k.json'
with open(instruct_dataset, 'r') as f:
    instruct_data = json.load(f)
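# Each entry of llava_instruct_150k.json pairs an image file name with a list
# of human/gpt conversation turns. The shape below is inferred from the field
# accesses in the dataset class; the file name and strings are illustrative:
#
# {
#     "image": "000000442786.jpg",
#     "conversations": [
#         {"from": "human", "value": "<image>\nWhat is shown in the photo?"},
#         {"from": "gpt", "value": "A dog chasing a frisbee in a park."}
#     ]
# }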
class CustomTextDataset(Dataset):
    def __init__(self, json_data, image_embedding_dict, tokenizer, maxContext=512):
        self.image_embedding_dict = image_embedding_dict
        self.tokenizer = tokenizer
        self.json_data = json_data
        self.maxContext = maxContext
        self.entries = []
        for entry in json_data:
            image = entry['image']
            image_embedding = self.getEmbeddingForImage(image)
            if image_embedding is None:
                continue
            conversations = entry['conversations']
            # Stop at len - 1: each human turn is paired with the reply at i + 1.
            for i in range(len(conversations) - 1):
                if conversations[i]['from'] == 'human':
                    # Character-count filter as a cheap proxy for the token
                    # budget enforced in __getitem__.
                    if len(conversations[i]['value'] + conversations[i + 1]['value']) > self.maxContext:
                        continue
                    # Remove the literal '<image>\n' placeholder. str.lstrip
                    # strips a *character set*, not a prefix, and would eat
                    # leading letters such as 'i' or 'a' from the question.
                    q_text = conversations[i]['value'].replace('<image>\n', '', 1)
                    question = 'Question: ' + q_text
                    answer = 'Answer: ' + conversations[i + 1]['value']
                    self.entries.append({
                        'image_name': image,
                        'image_embedding': image_embedding,
                        'Question': question,
                        'Answer': answer,
                        'QnAText': question + answer
                    })
        print(f'number of entries = {len(self.entries)}')
    def getEmbeddingForImage(self, image):
        # Returns the precomputed embedding, or None if the image has none.
        return self.image_embedding_dict.get(image)

    def __len__(self):
        return len(self.entries)
    def __getitem__(self, idx):
        entry = self.entries[idx]
        Q_caption_tokens = self.tokenizer.encode(entry['Question'], add_special_tokens=True)
        QnA_captions_tokens = self.tokenizer.encode(entry['QnAText'], add_special_tokens=True)
        QTokensLength = len(Q_caption_tokens)
        QnA_length = len(QnA_captions_tokens)
        # Truncate to the context window, then right-pad; without the
        # truncation an over-long sequence would skip padding entirely and
        # break batch collation.
        QnA_captions_tokens = QnA_captions_tokens[:self.maxContext]
        QnA_captions_tokens = QnA_captions_tokens + \
            [self.tokenizer.pad_token_id] * (self.maxContext - len(QnA_captions_tokens))
        return {'image_name': entry['image_name'],
                'QText': entry['Question'],
                'AText': entry['Answer'],
                # Moving tensors to CUDA here assumes the default num_workers=0;
                # with worker processes, move to the device in the training loop.
                'image_embedding': entry['image_embedding'].to("cuda"),
                'QnA_tokens': torch.tensor(QnA_captions_tokens),
                'QTokensLength': QTokensLength,
                'QnA_length': QnA_length}
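# The dictionary loaded below maps image file names to precomputed embeddings.
# How it was built is not shown in this file; the sketch that follows is one
# plausible recipe, assuming a CLIP vision encoder whose pooled output serves
# as the per-image embedding (the model name, pooling choice, and image folder
# are all assumptions). It is defined here but never called.
def build_image_embeddings(image_dir="./train2017",
                           clip_name="openai/clip-vit-large-patch14",
                           out_path="img_embeddings_dict.pth"):
    import os
    from PIL import Image
    from transformers import CLIPVisionModel, CLIPImageProcessor

    vision_encoder = CLIPVisionModel.from_pretrained(clip_name).eval().to("cuda")
    processor = CLIPImageProcessor.from_pretrained(clip_name)
    img_embeddings = {}
    with torch.no_grad():
        for fname in os.listdir(image_dir):
            image = Image.open(os.path.join(image_dir, fname)).convert("RGB")
            pixels = processor(images=image, return_tensors="pt").pixel_values.to("cuda")
            # Pooled [CLS] representation: one vector per image, keyed by the
            # bare file name to match the entry['image'] lookup above.
            img_embeddings[fname] = vision_encoder(pixel_values=pixels).pooler_output.squeeze(0).cpu()
    torch.save(img_embeddings, out_path)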
# Precomputed image embeddings, keyed by image file name.
imgEmbDict = torch.load('img_embeddings_dict.pth')

model_name = "microsoft/phi-2"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token  # phi-2's tokenizer defines no pad token

custom_dataset = CustomTextDataset(instruct_data, imgEmbDict, tokenizer)
custom_dataloader = DataLoader(custom_dataset, batch_size=10, shuffle=True)
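# A training step would typically mask the question and the padding out of the
# loss so the model only learns to generate answers; QTokensLength and
# QnA_length carry exactly the offsets needed. A minimal sketch, assuming a
# causal LM trained with the usual -100 ignore index (the model itself and the
# device handling are assumptions, not part of this file):
def demo_masked_labels():
    for batch in custom_dataloader:
        tokens = batch['QnA_tokens']                 # shape (batch, maxContext)
        labels = tokens.clone()
        for b in range(tokens.size(0)):
            q_len = int(batch['QTokensLength'][b])
            total = int(batch['QnA_length'][b])
            labels[b, :q_len] = -100                 # ignore the question tokens
            labels[b, total:] = -100                 # ignore the padding
        # A forward pass would prepend batch['image_embedding'] to the token
        # embeddings before computing the LM loss, e.g.
        #   loss = model(input_ids=tokens.to("cuda"), labels=labels.to("cuda")).loss
        break                                        # one batch, for illustration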