Upload 7 files
Browse files

- README.md +57 -13
- configs.py +33 -0
- dataset.py +204 -0
- get_coco.py +41 -0
- main.py +41 -0
- model.py +378 -0
- requirements.txt +7 -0
README.md
CHANGED
# Multi-Modal LLM Gradio App

## Project Overview

This project is a **multi-modal language model** Gradio app that accepts **text**, **image**, and **audio inputs** and returns **text responses**. The app mimics a **ChatGPT-style interface**, allowing users to interact through multiple input modes.

The app leverages:

- **CLIP** for image processing
- **Whisper** for audio transcription (ASR)
- A **text-based model** (such as GPT or Phi) for generating text responses

## Features

- **Text Input**: Users can enter text directly for response generation.
- **Image Input**: Users can upload images, which are processed by the CLIP model.
- **Audio Input**: Users can upload or record audio, which is transcribed by the Whisper model and then processed for a response.
- **ChatGPT-Like Interface**: A simple, intuitive interface that handles multi-modal inputs and produces text output.

## Installation

1. Clone the repository:
```bash
git clone https://huggingface.co/spaces/Vasudevakrishna/MultiModel_LLM_ERAV2
cd MultiModel_LLM_ERAV2
```

2. Install dependencies:
```bash
pip install -r requirements.txt
```

3. Run the app:
```bash
python app.py
```

## How It Works

1. **Text Processing**: Input text is passed to a language model (such as GPT or Phi) to generate a response.
2. **Image Processing**: Images are processed using CLIP, which extracts patch embeddings. A trained projection layer converts these embeddings into a format the text model understands.
3. **Audio Processing**: Audio files are transcribed into text using Whisper, and the transcript is passed to the language model for response generation.

A sketch of this pipeline appears below.
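The sketch below shows how the three input paths could converge on the text model. It is a minimal illustration rather than the app's actual code: `app.py` is not part of this upload, and the projection-layer shape and checkpoint path are taken from `configs.py` and `model.py`.

```python
import torch
import whisperx
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer, CLIPVisionModel

device = "cuda" if torch.cuda.is_available() else "cpu"

# Audio -> text with WhisperX
asr = whisperx.load_model("base", device, compute_type="float16" if device == "cuda" else "int8")
audio = whisperx.load_audio("question.wav")
transcript = " ".join(seg["text"] for seg in asr.transcribe(audio)["segments"])

# Image -> CLIP patch embeddings -> Phi-2 embedding space via the trained projection layer
processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch16")
clip = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch16").to(device)
projection = torch.nn.Linear(768, 2560).to(device)  # clip_embed -> phi_embed, as in configs.py
projection.load_state_dict(torch.load("./ckpts/model_phase1.pth", map_location=device))

pixels = processor(images=Image.open("photo.jpg"), return_tensors="pt")["pixel_values"].to(device)
patches = clip(pixel_values=pixels).last_hidden_state[:, 1:, :]  # drop the CLS token
image_embeds = projection(patches)

# Text (typed or transcribed) -> Phi-2 token embeddings
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
phi2 = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", trust_remote_code=True).to(device)
text_ids = tokenizer(transcript, return_tensors="pt").input_ids.to(device)
text_embeds = phi2.model.embed_tokens(text_ids)

# The concatenated embedding sequence is then decoded greedily, as in model.py's generate()
combined_embeds = torch.cat([image_embeds, text_embeds], dim=1)
```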
## Usage

- **Text Input**: Enter text in the provided textbox and click "Submit" to generate a response.
- **Image Input**: Upload an image and click "Submit" to generate a response based on the image.
- **Audio Input**: Upload or record an audio file, then click "Submit" to transcribe it and generate a response.

A minimal sketch of such an interface follows this list.
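Since `app.py` is not included in this upload, the following is only a hedged sketch of what a Gradio interface with these three inputs could look like; the `respond` stub stands in for the CLIP/Whisper/Phi-2 pipeline described above.

```python
import gradio as gr

def respond(text, image, audio):
    # Stub: a real implementation would transcribe `audio` with Whisper,
    # embed `image` with CLIP plus the projection layer, and query the text model.
    received = [
        text or None,
        "<image received>" if image is not None else None,
        "<audio received>" if audio is not None else None,
    ]
    return " | ".join(p for p in received if p) or "No input given."

demo = gr.Interface(
    fn=respond,
    inputs=[
        gr.Textbox(label="Text"),
        gr.Image(type="pil", label="Image"),
        gr.Audio(type="filepath", label="Audio"),
    ],
    outputs=gr.Textbox(label="Response"),
)

if __name__ == "__main__":
    demo.launch()
```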
## Future Improvements

- Add conveniences such as drag-and-drop file upload and live audio recording for a better user experience.
- Speed up image handling by computing CLIP embeddings in real time, given more GPU resources.
- Implement end-to-end training of all components for better response quality.

## License

This project is licensed under the MIT License.
configs.py
ADDED
```python
import torch


def get_config_phase1():
    return {
        "data_dir": "./data",
        "clip_model_name": "openai/clip-vit-base-patch16",
        "phi2_model_name": "microsoft/phi-2",
        "train_batch_size": 2,
        "val_batch_size": 1,
        "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        "epochs": 2,
        "max_tokens": 20,
        "clip_embed": 768,    # hidden size of CLIP ViT-B/16
        "phi_embed": 2560,    # hidden size of Phi-2
        "num_workers": 4,
        "ckpts": "./ckpts"
    }


def get_config_phase2():
    return {
        "data_dir": "./data",
        "clip_model_name": "openai/clip-vit-base-patch16",
        "phi2_model_name": "microsoft/phi-2",
        "train_batch_size": 1,
        "val_batch_size": 1,
        "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        "epochs": 10,
        "max_tokens": 100,
        "clip_embed": 768,
        "phi_embed": 2560,
        "num_workers": 0,
        "ckpts": "./ckpts",
        "vocab_size": 51200   # Phi-2 vocabulary size
    }
```
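One caveat: the phase 2 data loaders in `dataset.py` (below) also read an `i150k_json` path and a `QA_datasetName` dataset id from this config, and `get_config_phase2()` does not define them. A sketch of how they could be supplied; both values are placeholders, not names confirmed by this upload:

```python
from configs import get_config_phase2

config = get_config_phase2()
# Keys read by dataset.get_i150_df / dataset.get_oas_df; the values are placeholders
config["i150k_json"] = "./data/llava_instruct_150k.json"
config["QA_datasetName"] = "OpenAssistant/oasst1"
```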
dataset.py
ADDED
```python
import os
import json
import pickle

import numpy as np
import pandas as pd
import requests
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from transformers import AutoProcessor
from datasets import load_dataset  # importing datasets.Dataset here would shadow torch.utils.data.Dataset


class ClipDataset(Dataset):
    '''Dataset of (COCO image, caption) pairs for phase 1 (projection-layer) training.'''
    def __init__(self, coco_data, model_name, tokenizer):
        self.tokenizer = tokenizer
        self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
        self.caption_dataset = coco_data

    def __len__(self):
        # Return the length of the dataset
        return len(self.caption_dataset)

    def __getitem__(self, idx):
        # Get the image URL and caption
        img_url = self.caption_dataset[idx]["image_url"]
        caption = self.caption_dataset[idx]["caption"]

        # Download the image and resize it to a height of 224 px, preserving the aspect ratio
        image = Image.open(requests.get(img_url, stream=True).raw)
        width, height = image.size
        new_height = 224
        new_width = new_height * width // height
        image = image.resize((new_width, new_height), Image.LANCZOS)

        # Preprocess the image for CLIP and tokenize the caption
        image_processed = self.processor(images=image, return_tensors="pt")['pixel_values']
        image_squeezed = image_processed.squeeze(0)
        tokenized_caption = self.tokenizer(caption, return_tensors="pt", return_attention_mask=False)
        tokenized_caption_ids = tokenized_caption['input_ids'].squeeze(0)
        return (image_squeezed, tokenized_caption_ids)


def collate_fn_phase1(batch):
    # Unzip the batch
    images, captions = zip(*batch)
    # Stack the preprocessed images
    images_stacked = torch.stack(images, dim=0)
    # Pad the captions; the padding value is the <eos> token (50256)
    captions_padded = torch.nn.utils.rnn.pad_sequence(captions, batch_first=True, padding_value=50256)
    return (images_stacked, captions_padded)


def get_data_loaders_phase1(data_dir, clip_model_name, tokenizer, train_batch_size, val_batch_size, num_workers):
    # Load the pickled COCO caption data produced by get_coco.py
    with open(os.path.join(data_dir, 'coco_train.pkl'), 'rb') as fp:
        train_pkl = pickle.load(fp)
    with open(os.path.join(data_dir, 'coco_val.pkl'), 'rb') as fp:
        val_pkl = pickle.load(fp)

    train_dataloader = DataLoader(ClipDataset(train_pkl, clip_model_name, tokenizer), collate_fn=collate_fn_phase1, batch_size=train_batch_size, num_workers=num_workers, shuffle=True, pin_memory=True)
    val_dataloader = DataLoader(ClipDataset(val_pkl, clip_model_name, tokenizer), collate_fn=collate_fn_phase1, batch_size=val_batch_size, num_workers=num_workers, shuffle=False, pin_memory=True)
    return train_dataloader, val_dataloader


##################################### Phase 2 #########################################


class ClipDatasetPhase2(Dataset):
    '''Dataset of (image, question, answer) triples for phase 2 (instruction) training.'''
    def __init__(self, data_frame, model_name, tokenizer):
        self.tokenizer = tokenizer
        self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
        self.df = data_frame

    def __len__(self):
        # Return the length of the dataset
        return len(self.df)

    def __getitem__(self, idx):
        # Get the image URL and the QA pair (the dataframe index is reset, so positional access works)
        img_url = self.df.ImageUrl[idx]
        que = self.df.Question[idx]
        ans = self.df.Answer[idx]

        # Text-only samples (from the QA dataset) have no image
        if img_url is None:
            image_squeezed = None
        else:
            image = Image.open(requests.get(img_url, stream=True).raw)
            width, height = image.size
            new_height = 224
            new_width = new_height * width // height
            image = image.resize((new_width, new_height), Image.LANCZOS)
            image_processed = self.processor(images=image, return_tensors="pt")['pixel_values']
            image_squeezed = image_processed.squeeze(0)
        que_ids = self.tokenizer(que, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0)
        ans_ids = self.tokenizer(ans, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0)
        return (image_squeezed, que_ids, ans_ids)


def collate_fn_phase2(batch):
    # Unzip the batch
    images, ques, ans = zip(*batch)
    # Stack the preprocessed images; text-only batches carry None
    if images[0] is None:
        images_stacked = None
    else:
        images_stacked = torch.stack(images, dim=0)
    # Pad the questions and answers; the padding value is the <eos> token
    ques_padded = torch.nn.utils.rnn.pad_sequence(ques, batch_first=True, padding_value=50256)
    ans_padded = torch.nn.utils.rnn.pad_sequence(ans, batch_first=True, padding_value=50256)
    return (images_stacked, ques_padded, ans_padded)


def prep_data(df):
    # Pair each top-ranked assistant reply with the prompt that produced it
    df_assistant = df[(df.role == "assistant") & (df["rank"] == 0.0)].copy()
    df_prompter = df[(df.role == "prompter")].copy()
    df_prompter = df_prompter.set_index("message_id")
    df_assistant["Answer"] = df_assistant["text"].values

    inputs = []
    for _, row in df_assistant.iterrows():
        prompt = df_prompter.loc[row.parent_id]
        inputs.append(prompt.text)

    df_assistant["Question"] = inputs
    df_assistant["ImageUrl"] = None  # text-only samples

    df_assistant = df_assistant[df_assistant.lang == "en"]

    df_assistant = df_assistant[
        ["ImageUrl", "Question", "Answer", "message_id"]
    ].rename(columns={"message_id": "Ids"})

    return df_assistant


def get_i150_df(config):
    # Build a QA dataframe from the Instruct-150k JSON (images live in COCO train2017)
    with open(config.get("i150k_json"), "r") as fp:
        i150k_json_read = json.load(fp)
    max_tokens = 100
    image_urls = []
    ques_list = []
    ans_list = []
    id_list = []
    for idx, data in enumerate(i150k_json_read):
        image = data['image']
        image_url = 'http://images.cocodataset.org/train2017/' + image
        id_ = data["id"]
        # Conversations alternate between a human question and a gpt answer
        iterator = iter(data['conversations'])
        for i in iterator:
            ques = i
            ans = next(iterator)
            # Skip overly long questions or answers
            if (len(ques["value"]) > 100 or len(ans["value"]) > max_tokens):
                continue
            if ques["from"] == "human" and ans["from"] == "gpt":
                image_urls.append(image_url)
                ques_list.append(ques["value"].replace("<image>\n", "").replace("<image>", ""))
                ans_list.append(ans["value"])
                id_list.append(id_)
    df_i150k = pd.DataFrame(list(zip(image_urls, ques_list, ans_list, id_list)),
                            columns=["ImageUrl", "Question", "Answer", "Ids"])
    # Random 96/4 train/test split
    msk = np.random.rand(len(df_i150k)) < 0.96

    train_df = df_i150k[msk]
    test_df = df_i150k[~msk]
    return train_df, test_df


def get_oas_df(config):
    # Text-only QA pairs from the configured dataset
    train_ds, val_ds = load_dataset(config.get("QA_datasetName"), split=["train", "validation"])
    train_df = prep_data(train_ds.to_pandas())
    test_df = prep_data(val_ds.to_pandas())
    return train_df, test_df


def get_data_loaders_phase2(tokenizer, config):
    train_i150k, test_i150k = get_i150_df(config)
    train_oas, test_oas = get_oas_df(config)

    train_df = pd.concat([train_i150k, train_oas]).reset_index(drop=True)
    val_df = pd.concat([test_i150k, test_oas]).reset_index(drop=True)

    train_dataloader = DataLoader(ClipDatasetPhase2(train_df, config.get("clip_model_name"), tokenizer), collate_fn=collate_fn_phase2, batch_size=config.get("train_batch_size"), num_workers=config.get("num_workers"), shuffle=True, pin_memory=True)
    val_dataloader = DataLoader(ClipDatasetPhase2(val_df, config.get("clip_model_name"), tokenizer), collate_fn=collate_fn_phase2, batch_size=config.get("val_batch_size"), num_workers=config.get("num_workers"), shuffle=False, pin_memory=True)
    return train_dataloader, val_dataloader
```
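A quick, hedged sanity check for the phase 1 loaders (it assumes get_coco.py has already written both pickle files into `./data`):

```python
from transformers import AutoTokenizer
from configs import get_config_phase1
from dataset import get_data_loaders_phase1

config = get_config_phase1()
tokenizer = AutoTokenizer.from_pretrained(config["phi2_model_name"], trust_remote_code=True)
train_dl, _ = get_data_loaders_phase1(config["data_dir"], config["clip_model_name"], tokenizer,
                                      config["train_batch_size"], config["val_batch_size"],
                                      config["num_workers"])
images, captions = next(iter(train_dl))
print(images.shape)    # (batch, 3, 224, 224) after CLIP preprocessing
print(captions.shape)  # (batch, longest_caption_len), padded with <eos> (50256)
```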
get_coco.py
ADDED
```python
"""Unpack the COCO captions annotations and save them as pickle files."""
import os, shutil, json
import pickle, argparse


def make_pkl(data_dir, dataset_json, train_flag=False):
    # Index the image metadata by id once, instead of scanning the image list per annotation
    images_by_id = {img['id']: img for img in dataset_json['images']}
    coco_data_list = []
    for data in dataset_json['annotations']:
        image_id = data['image_id']
        caption = data['caption']
        img = images_by_id[image_id]
        coco_data_list.append({'image_id': image_id, 'image_url': img['coco_url'], 'file_name': img['file_name'], 'caption': caption})
    out_name = 'coco_train.pkl' if train_flag else 'coco_val.pkl'
    with open(os.path.join(data_dir, out_name), 'wb') as f:
        pickle.dump(coco_data_list, f)


def main(coco_path, data_dir):
    coco_dir = os.path.dirname(coco_path)
    # shutil.unpack_archive(coco_path, coco_dir)  # uncomment on first run to unzip the archive
    with open(os.path.join(coco_dir, 'annotations/captions_train2017.json')) as f:
        coco_train_dataset = json.load(f)
    with open(os.path.join(coco_dir, 'annotations/captions_val2017.json')) as f:
        coco_val_dataset = json.load(f)
    make_pkl(data_dir, coco_train_dataset, train_flag=True)
    # make_pkl(data_dir, coco_val_dataset)  # note: get_data_loaders_phase1 also expects coco_val.pkl


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--coco_path', type=str, default='coco.zip')
    parser.add_argument('--data_dir', type=str, default='data')
    args = parser.parse_args()
    main(args.coco_path, args.data_dir)
```
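Typical invocation, using the script's own defaults (the zip is expected to contain the standard `annotations/` folder; uncomment the `unpack_archive` and `coco_val` lines as needed):

```bash
python get_coco.py --coco_path ./coco.zip --data_dir ./data
```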
main.py
ADDED
```python
import torch
from transformers import AutoTokenizer

from configs import get_config_phase1, get_config_phase2
from dataset import get_data_loaders_phase1, get_data_loaders_phase2
from model import CustomClipPhi2, MainQLoraModel, train_model_phase1, train_model_phase2


def phase_1():
    # get config
    config = get_config_phase1()
    # tokenizer
    tokenizer = AutoTokenizer.from_pretrained(config.get("phi2_model_name"), trust_remote_code=True)

    # data loaders
    train_dataloader, val_dataloader = get_data_loaders_phase1(config.get("data_dir"), config.get("clip_model_name"), tokenizer, config.get("train_batch_size"), config.get("val_batch_size"), config.get("num_workers"))

    llmModel = CustomClipPhi2(tokenizer, config.get("phi2_model_name"), config.get("clip_model_name"), clip_embed=config.get("clip_embed"), phi_embed=config.get("phi_embed")).to(config.get("device"))
    print(llmModel)
    # optimizer: only the projection layer has trainable parameters
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, llmModel.parameters()), lr=1e-3)
    # train model
    train_model_phase1(llmModel, train_dataloader, val_dataloader, optimizer, tokenizer, config)


def phase_2():
    # get config
    config = get_config_phase2()
    # tokenizer
    tokenizer = AutoTokenizer.from_pretrained(config.get("phi2_model_name"), trust_remote_code=True)

    # data loaders
    train_dataloader, val_dataloader = get_data_loaders_phase2(tokenizer, config)

    llmModel = MainQLoraModel(tokenizer, config).to(config.get("device"))
    print(llmModel)
    # train model (the optimizers are created inside train_model_phase2)
    train_model_phase2(llmModel, train_dataloader, val_dataloader, tokenizer, config)


if __name__ == "__main__":
    torch.set_float32_matmul_precision('medium')
    phase_1()
    # phase_2()
```
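Phase ordering matters here: `MainQLoraModel` initializes its projection layer from `./ckpts/model_phase1.pth`, so phase 1 must finish before phase 2 is switched on. The intended sequence, as a sketch:

```python
if __name__ == "__main__":
    torch.set_float32_matmul_precision('medium')
    phase_1()  # trains the projection layer -> ./ckpts/model_phase1.pth
    phase_2()  # QLoRA finetuning of Phi-2 plus the projection layer, resuming from phase 1
```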
model.py
ADDED
```python
import os

import torch
import torch.nn as nn
from torch.nn.functional import cross_entropy
from tqdm import tqdm
from transformers import CLIPVisionModel, AutoModelForCausalLM, BitsAndBytesConfig
import peft
from peft import LoraConfig


class CustomClipPhi2(nn.Module):
    '''Phase 1 model: frozen CLIP and frozen Phi-2 joined by a trainable linear projection.'''

    def __init__(self, tokenizer, phi2_model_name, clip_model_name, clip_embed=768, phi_embed=2560):
        super().__init__()

        self.tokenizer = tokenizer
        # These two models are not finetuned
        # pretrained Microsoft Phi-2 model
        self.phi2_model = AutoModelForCausalLM.from_pretrained(phi2_model_name, torch_dtype=torch.float32, trust_remote_code=True)
        # pretrained OpenAI CLIP vision model
        self.clip_model = CLIPVisionModel.from_pretrained(clip_model_name)

        self.EOS_TOKEN_ID = self.tokenizer.eos_token_id  # 50256
        self.IMAGE_TOKEN_ID = 23903  # token for "Comments"; marks the end of the image embeddings
        self.clip_embed = clip_embed
        self.phi_embed = phi_embed

        # Trainable projection layer from CLIP embedding space to Phi-2 embedding space
        self.projection_layer = torch.nn.Linear(clip_embed, phi_embed)

        # Freeze the pretrained weights
        for frozen_model in [self.phi2_model, self.clip_model]:
            for param in frozen_model.parameters():
                param.requires_grad_(False)

        # Load checkpoint weights if available
        if os.path.exists('./ckpts/model_phase1.pth'):
            self.projection_layer.load_state_dict(torch.load('./ckpts/model_phase1.pth', map_location='cpu'))
            print("Loaded checkpoint weights for projection layer")
        else:
            print("No checkpoint weights for projection layer")
            print("Initializing projection layer with random weights")
            self.projection_layer.weight.data.normal_(mean=0.0, std=0.02)
            self.projection_layer.bias.data.zero_()

    def generate(self, images, tokenizer, config):
        # Greedy decoding from the projected image embeddings
        clip_outputs = self.clip_model(**images)
        images = clip_outputs.last_hidden_state[:, 1:, :]  # remove the CLS token
        image_embeddings = self.projection_layer(images).to(torch.float16)

        batch_size = images.size()[0]
        predicted_caption = torch.full((batch_size, config.get("max_tokens")), self.EOS_TOKEN_ID, dtype=torch.long, device=config.get('device'))
        img_token_tensor = torch.tensor(self.IMAGE_TOKEN_ID).repeat(batch_size, 1)
        img_token_embeds = self.phi2_model.model.embed_tokens(img_token_tensor.to(image_embeddings.device))
        combined_embeds = torch.cat([image_embeddings, img_token_embeds], dim=1)

        for pos in range(config.get("max_tokens") - 1):
            model_output_logits = self.phi2_model.forward(inputs_embeds=combined_embeds)['logits']
            predicted_word_token_logits = model_output_logits[:, -1, :].unsqueeze(1)
            predicted_word_token = torch.argmax(predicted_word_token_logits, dim=-1)
            predicted_caption[:, pos] = predicted_word_token.view(1, -1).to('cpu')
            next_token_embeds = self.phi2_model.model.embed_tokens(predicted_word_token)
            combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)
        return predicted_caption

    def forward(self, images, target_captions):
        batch_size = target_captions.size()[0]
        target_length = target_captions.size()[1]

        # CLIP output for the image (see https://huggingface.co/openai/clip-vit-base-patch16 for loading)
        clip_outputs = self.clip_model(**images)
        images = clip_outputs.last_hidden_state[:, 1:, :]  # remove the CLS token

        # projection layer
        image_embeddings = self.projection_layer(images).to(torch.float16)

        # append the image-boundary token embedding from Phi-2
        img_token_tensor = torch.tensor(self.IMAGE_TOKEN_ID).repeat(batch_size, 1)
        img_token_embeds = self.phi2_model.model.embed_tokens(img_token_tensor.to(image_embeddings.device))
        combined_embeds = torch.cat([image_embeddings, img_token_embeds], dim=1)  # (batch, num_patches + 1, phi_embed)
        del clip_outputs
        del image_embeddings

        # Autoregressive rollout: at each step, score the next-token logits against the
        # target caption token, then feed the predicted token back in
        loss = 0
        for pos in range(target_length - 1):
            model_output_logits = self.phi2_model.forward(inputs_embeds=combined_embeds)['logits']
            predicted_word_token_logits = model_output_logits[:, -1, :].unsqueeze(1)
            pos_loss = cross_entropy(predicted_word_token_logits.view(-1, predicted_word_token_logits.size(-1)), target_captions[:, pos].contiguous().view(-1), ignore_index=self.EOS_TOKEN_ID, label_smoothing=0.1)
            loss += pos_loss

            predicted_word_token = torch.argmax(predicted_word_token_logits, dim=-1)
            next_token_embeds = self.phi2_model.model.embed_tokens(predicted_word_token)
            combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)
        loss = loss / target_length

        # Delete variables to free up memory
        del combined_embeds
        del model_output_logits
        torch.cuda.empty_cache()

        return loss


def show_results_for_samples_phase1(model, val_dataloader, tokenizer, config, num_samples=2):
    # Print target vs. predicted captions for a few validation batches
    model.eval()
    with torch.no_grad():
        for i in range(num_samples):
            for images, target_captions in val_dataloader:
                images = {'pixel_values': images.to(config.get('device'))}
                target_captions = target_captions.to(config.get('device'))
                target_captions_decoded = tokenizer.batch_decode(target_captions, skip_special_tokens=True)
                predicted_captions = model.generate(images, tokenizer, config)
                predicted_captions_decoded = tokenizer.batch_decode(predicted_captions, skip_special_tokens=True)

                for idx, pc in enumerate(predicted_captions_decoded):
                    print(f"{idx} - Target captions: {target_captions_decoded[idx]} \n {'-'*100} \n Predicted captions: {pc}")
                break


def validate_model_phase1(model, val_dataloader, tokenizer, config):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        try:
            for images, target_captions in tqdm(val_dataloader):
                images = {'pixel_values': images.to(config.get('device'))}
                target_captions = target_captions.to(config.get('device'))
                loss = model(images, target_captions)
                total_loss += loss.item()
            print(f"Validation Loss: {total_loss/len(val_dataloader)}")
        except Exception as e:
            print(f"Validation failed: {e}")
    model.train()


def train_model_phase1(model, train_loader, val_dataloader, optimizer, tokenizer, config):
    model.train()
    os.makedirs(config.get("ckpts"), exist_ok=True)

    for epoch in range(1, config.get("epochs") + 1):
        print(f"Epoch: {epoch}")
        torch.cuda.empty_cache()
        pbar = tqdm(train_loader)
        step = 1
        try:
            for idx, (images, target_captions) in enumerate(pbar):
                try:
                    # Skip batches whose captions exceed the token budget
                    if target_captions.shape[1] >= config.get("max_tokens"):
                        continue

                    images = {'pixel_values': images.to(config.get('device'))}
                    target_captions = target_captions.to(config.get('device'))

                    optimizer.zero_grad()
                    loss = model(images, target_captions)
                    loss.backward()
                    optimizer.step()
                    pbar.set_description(f"Epoch: {epoch}: Training Loss = {loss.item()}")
                    torch.cuda.empty_cache()
                    step += 1
                    # Periodically checkpoint the projection layer
                    if (step % 1000 == 0):
                        torch.save(model.projection_layer.state_dict(), './ckpts/model_phase1.pth')
                except Exception as e:
                    print(e)
                    continue

            # Only the latest checkpoint is kept
            validate_model_phase1(model, val_dataloader, tokenizer, config)
            show_results_for_samples_phase1(model, val_dataloader, tokenizer, config)
            torch.save(model.projection_layer.state_dict(), './ckpts/model_phase1.pth')

        except Exception as e:
            print(e)
            continue


######################################## Phase 2 #########################################


class MainQLoraModel(nn.Module):
    '''Phase 2 model: frozen CLIP, 4-bit Phi-2 finetuned with LoRA, and the projection layer.'''

    def __init__(self, tokenizer, config):
        super().__init__()
        self.tokenizer = tokenizer
        self.config = config
        self.clip_model = CLIPVisionModel.from_pretrained(config.get("clip_model_name"))

        # Load Phi-2 in 4-bit (the QLoRA base model)
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )

        phi2_model = AutoModelForCausalLM.from_pretrained(
            config.get("phi2_model_name"),
            quantization_config=bnb_config,
            trust_remote_code=True
        )
        phi2_model.config.use_cache = False

        # LoRA config
        lora_alpha = 16
        lora_dropout = 0.1
        lora_r = 64

        peft_config = LoraConfig(
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            r=lora_r,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=[
                "q_proj",
                "k_proj",
                "v_proj",
                "dense",
                "fc1",
                "fc2"
            ]
        )
        # Resume a saved LoRA adapter if one exists, otherwise inject fresh adapters
        if os.path.exists('./ckpts/Qlora_adaptor'):
            self.phi2_model = peft.PeftModel.from_pretrained(phi2_model, './ckpts/Qlora_adaptor', is_trainable=True).to(config.get("device"))
        else:
            self.phi2_model = peft.get_peft_model(phi2_model, peft_config).to(config.get("device"))

        self.EOS_TOKEN_ID = self.tokenizer.eos_token_id
        self.clip_embed = config.get("clip_embed")
        self.phi_embed = config.get("phi_embed")

        # Trainable projection layer
        self.projection_layer = torch.nn.Linear(self.clip_embed, self.phi_embed)

        # Freeze CLIP weights
        for param in self.clip_model.parameters():
            param.requires_grad_(False)

        # Load checkpoint weights for the projection layer
        if os.path.exists('./ckpts/model_phase2.pth'):
            self.projection_layer.load_state_dict(torch.load('./ckpts/model_phase2.pth', map_location=config.get("device")))
            print("Loaded checkpoint weights for projection layer")
        else:
            # Start from the phase 1 projection weights
            self.projection_layer.load_state_dict(torch.load('./ckpts/model_phase1.pth', map_location=config.get("device")))

    def generate(self, tokenizer, config, images=None, ques=None, max_tokens=100):
        batch_size = 1

        predicted_caption = torch.full((batch_size, max_tokens), self.EOS_TOKEN_ID, dtype=torch.long, device=self.config.get('device'))
        # <iQ> ... </iQ> delimit the (image +) question prefix
        start_iq = self.tokenizer.encode("<iQ>")
        end_iq = self.tokenizer.encode("</iQ>")
        start_iq_embeds = torch.tensor(start_iq).repeat(batch_size, 1)
        end_iq_embeds = torch.tensor(end_iq).repeat(batch_size, 1)
        start_iq_embeds = self.phi2_model.model.model.embed_tokens(start_iq_embeds.to(self.config.get("device")))
        end_iq_embeds = self.phi2_model.model.model.embed_tokens(end_iq_embeds.to(self.config.get("device")))
        questions_embed = self.phi2_model.model.model.embed_tokens(ques)
        if images is not None:
            clip_outputs = self.clip_model(**images)
            images = clip_outputs.last_hidden_state[:, 1:, :]  # remove the CLS token
            image_embeddings = self.projection_layer(images).to(torch.float16)
            combined_embeds = torch.cat([start_iq_embeds, image_embeddings, questions_embed, end_iq_embeds], dim=1)
        else:
            combined_embeds = torch.cat([start_iq_embeds, questions_embed, end_iq_embeds], dim=1)

        # Greedy decoding
        for pos in range(max_tokens - 1):
            model_output_logits = self.phi2_model.forward(inputs_embeds=combined_embeds)['logits']
            predicted_word_token_logits = model_output_logits[:, -1, :].unsqueeze(1)
            predicted_word_token = torch.argmax(predicted_word_token_logits, dim=-1)
            predicted_caption[:, pos] = predicted_word_token.view(1, -1).to('cpu')
            next_token_embeds = self.phi2_model.model.model.embed_tokens(predicted_word_token)
            combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)
        return predicted_caption

    def forward(self, images, ques, ans):
        batch_size = ques.size()[0]
        questions = ques.to(self.config.get("device"))
        answers = ans.to(self.config.get("device"))
        target_length = ans.size()[1]
        start_iq = self.tokenizer.encode("<iQ>")
        end_iq = self.tokenizer.encode("</iQ>")
        start_iq_embeds = torch.tensor(start_iq).repeat(batch_size, 1)
        end_iq_embeds = torch.tensor(end_iq).repeat(batch_size, 1)
        start_iq_embeds = self.phi2_model.model.model.embed_tokens(start_iq_embeds.to(self.config.get("device")))
        end_iq_embeds = self.phi2_model.model.model.embed_tokens(end_iq_embeds.to(self.config.get("device")))

        questions_embed = self.phi2_model.model.model.embed_tokens(questions)
        answers_embed = self.phi2_model.model.model.embed_tokens(answers)

        # Text-only batches carry no image
        if images is None or torch.all(images == 0).item():
            combined_embeds = torch.cat([start_iq_embeds, questions_embed, end_iq_embeds, answers_embed], dim=1)
        else:
            images = {'pixel_values': images.to(self.config.get("device"))}
            clip_outputs = self.clip_model(**images)
            images_embeds = clip_outputs.last_hidden_state[:, 1:, :]  # remove the CLS token

            # projection
            image_embeds = self.projection_layer(images_embeds).to(torch.float16)
            combined_embeds = torch.cat([start_iq_embeds, image_embeds, questions_embed, end_iq_embeds, answers_embed], dim=1)

        # Teacher-forced loss: the logits at position (prefix_len - 1 + pos) predict answer token pos
        model_output_logits = self.phi2_model.forward(inputs_embeds=combined_embeds)['logits']
        prefix_len = model_output_logits.size(1) - target_length
        loss = 0
        for pos in range(target_length - 1):
            predicted_word_token_logits = model_output_logits[:, prefix_len - 1 + pos, :].unsqueeze(1)
            pos_loss = cross_entropy(predicted_word_token_logits.view(-1, predicted_word_token_logits.size(-1)), answers[:, pos].contiguous().view(-1), ignore_index=self.EOS_TOKEN_ID, label_smoothing=0.1)
            loss += pos_loss
        loss = loss / target_length

        # Delete variables to free up memory
        del combined_embeds
        del model_output_logits
        torch.cuda.empty_cache()
        return loss


def validate_model_phase2(model, val_dataloader, tokenizer, config):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for images, ques, ans in tqdm(val_dataloader):
            loss = model(images, ques, ans)
            total_loss += loss.item()
        print(f"Validation Loss: {total_loss/len(val_dataloader)}")
    model.train()


def train_model_phase2(model, train_loader, val_dataloader, tokenizer, config):
    # Separate optimizers for the LoRA adapters and the projection layer
    phi2_optim = torch.optim.Adam(filter(lambda p: p.requires_grad, model.phi2_model.parameters()), lr=1e-5)
    proj_optim = torch.optim.Adam(filter(lambda p: p.requires_grad, model.projection_layer.parameters()), lr=1e-5)
    model.phi2_model.train()
    model.projection_layer.train()
    os.makedirs(config.get("ckpts"), exist_ok=True)

    for epoch in range(1, config.get("epochs") + 1):
        print(f"Epoch: {epoch}")
        torch.cuda.empty_cache()
        pbar = tqdm(train_loader)
        step = 1
        try:
            for idx, (images, ques, ans) in enumerate(pbar):
                try:
                    phi2_optim.zero_grad()
                    proj_optim.zero_grad()
                    loss = model(images, ques, ans)
                    loss.backward()
                    phi2_optim.step()
                    proj_optim.step()
                    pbar.set_description(f"Epoch: {epoch}: Training Loss = {loss.item()}")
                    torch.cuda.empty_cache()
                    step += 1
                    # Periodically checkpoint the projection layer and the LoRA adapter
                    if (step % 1000 == 0):
                        torch.save(model.projection_layer.state_dict(), './ckpts/model_phase2.pth')
                        model.phi2_model.save_pretrained('./ckpts/Qlora_adaptor/', save_adapter=True, save_config=True)
                except Exception as e:
                    print("in training step", e)
                    continue

            validate_model_phase2(model, val_dataloader, tokenizer, config)
            torch.save(model.projection_layer.state_dict(), './ckpts/model_phase2.pth')
            model.phi2_model.save_pretrained('./ckpts/Qlora_adaptor/', save_adapter=True, save_config=True)

        except Exception as e:
            print(e)
            continue
```
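For completeness, a hedged sketch of text-only generation with the phase 2 model; it uses only the classes above (an image could additionally be passed as a processed `pixel_values` dict, as `generate()` expects):

```python
import torch
from transformers import AutoTokenizer
from configs import get_config_phase2
from model import MainQLoraModel

config = get_config_phase2()
tokenizer = AutoTokenizer.from_pretrained(config.get("phi2_model_name"), trust_remote_code=True)
model = MainQLoraModel(tokenizer, config).to(config.get("device"))
model.eval()

ques_ids = tokenizer("What is the capital of France?", return_tensors="pt").input_ids.to(config.get("device"))
with torch.no_grad():
    token_ids = model.generate(tokenizer, config, images=None, ques=ques_ids, max_tokens=50)
print(tokenizer.batch_decode(token_ids, skip_special_tokens=True)[0])
```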
requirements.txt
ADDED
```
torch
torchvision
git+https://github.com/huggingface/peft.git
accelerate
transformers
einops
bitsandbytes  # required by the 4-bit BitsAndBytesConfig in model.py
datasets      # used by dataset.py
gradio        # serves the app UI
git+https://github.com/m-bain/whisperx.git
```