| import os |
| import deepspeed |
| |
| from tqdm import tqdm |
| import shutil |
# Route Hugging Face traffic through the hf-mirror.com endpoint.
# NOTE(review): this is set ahead of the qwenva import on purpose — presumably
# qwenva downloads/loads Hugging Face assets at import time; confirm before
# reordering these lines.
os.environ['HF_ENDPOINT'] = "https://hf-mirror.com"
from qwenva import tokenizer, processor, qwenva

# Directory that holds the COCO train2017 images referenced by the samples.
images_file_path = './data/download/llava-v1.5-instruct/coco/train2017'
|
|
| import torch |
| from torch.utils.data import Dataset, DataLoader |
| import os |
| import json |
| from PIL import Image |
| import json |
# Parse the whole LLaVA-Instruct-150K JSON file into memory once, at import.
with open(
    '/root/autodl-tmp/LLaVA-Instruct-150K/llava_instruct_150k.json',
    encoding='utf-8',
) as f:
    chat_data = json.load(f)
| import torch |
# Id of the '<image>' placeholder token.
# NOTE(review): taking element 0 assumes the tokenizer encodes '<image>' as a
# single id and prepends no special token — confirm against the qwenva
# tokenizer.
image_token = tokenizer.encode('<image>')[0]

# Id the tokenizer reports as its padding token.
pad_token = tokenizer.pad_token_id
def process_data(sample, max_len=8012):
    """Convert one conversation sample into fixed-length training tensors.

    The conversation is re-rendered through the tokenizer's chat template
    after every turn, so the tokens a turn contributes are the suffix the new
    rendering adds over the previous one. Tokens added by a human turn
    (rendered with the generation prompt) get the label -100; tokens added by
    a "gpt" turn keep their own ids as labels. The position of the image
    token is recorded, and the label at that position is set to the image
    token id.

    Args:
        sample: Mapping with a ``'conversations'`` list. Each item has a
            ``'from'`` key (``'human'`` or ``'gpt'``) and a ``'value'`` key;
            items with any other ``'from'`` value are skipped.
        max_len: Length every returned sequence is truncated or padded to.

    Returns:
        A dict with ``'input_ids'``, ``'attention_mask'`` and ``'labels'``
        (1-D tensors of length ``max_len``) and ``'image_idx'`` (a scalar
        tensor holding the image token's position in the untruncated
        sequence).

    Raises:
        SystemExit: If building the token sequence fails or no image token is
            found. The original code aborted the process in these cases too,
            but through ``exit()`` with status 0 and without the cause.
    """
    conversations = sample['conversations']
    messages = []
    input_ids = []
    labels = []
    image_index = None
    # Only the turn at this index is searched for the image token; a failed
    # search moves it one turn forward.
    search_turn = 0

    try:
        for index, item in enumerate(conversations):
            speaker = item['from']
            if speaker == 'human':
                previous_len = len(input_ids)
                messages.append({'role': 'user', 'content': item['value']})
                input_ids = tokenizer.apply_chat_template(
                    messages,
                    add_generation_prompt=True,
                )
                # Prompt tokens are never prediction targets.
                labels += [-100] * (len(input_ids) - previous_len)
                if index == search_turn:
                    try:
                        image_index = input_ids.index(image_token)
                    except ValueError:
                        print("image token not found")
                        search_turn = index + 1
                    else:
                        labels[image_index] = image_token
            elif speaker == 'gpt':
                previous_len = len(input_ids)
                messages.append({'role': 'assistant', 'content': item['value']})
                input_ids = tokenizer.apply_chat_template(messages)
                # The newly rendered suffix is the assistant reply: supervise it.
                labels += input_ids[previous_len:]
    except Exception as exc:
        # Keep the original abort-the-process contract, but exit non-zero and
        # keep the underlying cause visible.
        raise SystemExit(f"error in process_data_1: {exc!r}") from exc

    if image_index is None:
        # Previously this surfaced as a NameError hidden behind a bare except.
        raise SystemExit("error in tensor: no image token found in sample")

    # Truncate, then right-pad all three sequences to exactly max_len.
    # NOTE(review): image_index is not re-checked against max_len, so a
    # truncated image token would leave 'image_idx' pointing past the end.
    input_ids = input_ids[:max_len]
    labels = labels[:max_len]
    real_len = len(input_ids)
    attention_mask = [1] * real_len + [0] * (max_len - real_len)
    input_ids = input_ids + [pad_token] * (max_len - real_len)
    labels = labels + [-100] * (max_len - len(labels))

    return {
        'input_ids': torch.tensor(input_ids),
        'attention_mask': torch.tensor(attention_mask),
        'labels': torch.tensor(labels),
        'image_idx': torch.tensor(image_index),
    }
|
|
| |
| import os |
| import torch |
| from torch.utils.data import Dataset |
| from PIL import Image |
class MyDataset(Dataset):
    """Dataset pairing tokenized conversations with their processed images.

    Each item is the dict returned by ``process_data`` for one record, plus an
    ``'input_pixel'`` entry holding the processor's ``pixel_values`` for that
    record's image.
    """

    # Size of the blank stand-in used when an image cannot be read. The value
    # is arbitrary; the sample's labels are fully masked in that case.
    _PLACEHOLDER_SIZE = (224, 224)

    def __init__(self, images_file_path, data, max_len=1024):
        """Store the image directory, the records and the sequence length.

        Args:
            images_file_path: Directory joined with each record's ``'image'``
                value to locate the image file.
            data: Indexable collection of records; each is passed to
                ``process_data`` and must have an ``'image'`` key.
            max_len: Sequence length forwarded to ``process_data``.
        """
        self.images_file_path = images_file_path
        self.data = data
        self.max_len = max_len

    def __len__(self):
        """Return the number of records."""
        return len(self.data)

    def __getitem__(self, index):
        """Build the training dict for the record at ``index``.

        If the image cannot be opened or decoded, the sample's labels are
        replaced with all -100 and a blank image is substituted, so the item
        keeps its normal structure. Previously this path crashed with a
        NameError because ``img`` was never bound.
        """
        record = self.data[index]
        output_ = process_data(record, max_len=self.max_len)
        img_path = os.path.join(self.images_file_path, record['image'])
        try:
            # Image.open is lazy. Converting inside the ``with`` forces the
            # pixels to load and closes the file handle straight away; RGB
            # conversion gives every sample the same channel layout.
            with Image.open(img_path) as raw_image:
                img = raw_image.convert('RGB')
        except OSError:
            # Covers missing files and unreadable/corrupt images.
            print(f"image {img_path} not found")
            output_['labels'] = torch.tensor([-100] * self.max_len)
            img = Image.new('RGB', self._PLACEHOLDER_SIZE)
        input_pixel = processor(images=img, return_tensors="pt")
        output_['input_pixel'] = input_pixel['pixel_values'].squeeze()
        return output_
| |
|
|
| |
# Training data: sequences are truncated/padded to 2048 tokens.
# NOTE(review): batch_size=8 here is independent of the DeepSpeed config file;
# confirm it matches the configured micro-batch size.
dataset = MyDataset(images_file_path, chat_data, max_len=2048)
train_loader = DataLoader(dataset, batch_size=8, shuffle=True)
import argparse

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
qwenva = qwenva.to(device)

# Mark the trainable sub-modules BEFORE handing the model to DeepSpeed.
# The optimizer is constructed inside deepspeed.initialize, and DeepSpeed's
# optimizer wrappers pick up trainable parameters at that point; flipping
# requires_grad afterwards (as this script used to) may leave those weights
# out of the optimizer. ``qwenva`` is the object the engine exposes as
# ``model_engine.module``, so these are the same parameters the original
# loops reached through ``model_engine.module._orig_mod``.
for submodule_name in ('text_embedding', 'lm_head', 'transformer'):
    for param in getattr(qwenva._orig_mod, submodule_name).parameters():
        param.requires_grad = True

model_engine, optimizer, _, _ = deepspeed.initialize(
    model=qwenva,
    args=argparse.Namespace(),
    model_parameters=qwenva.parameters(),
    config_params="./deepspeed_config.json"
)

# Report every parameter that will receive gradients.
for name, param in model_engine.module._orig_mod.named_parameters():
    if param.requires_grad:
        print(f"Layer: {name}, Requires Grad: {param.requires_grad}")
| |
|
|
| |
| import torch.nn as nn |
# Mean cross-entropy over tokens; targets equal to -100 are skipped
# (-100 is CrossEntropyLoss's default ignore_index, spelled out here).
loss_fn = nn.CrossEntropyLoss(ignore_index=-100)

# Gradient-accumulation step count; 1 means no accumulation.
accumulation_steps = 1
| |
def _save_training_state(model_engine, batch_number):
    """Replace the on-disk checkpoint with the engine's current state."""
    if os.path.exists("./best_model_2"):
        shutil.rmtree("./best_model_2")
    os.makedirs("./best_model_2")
    model_engine.save_checkpoint("./best_model_2")
    torch.save(model_engine.module.state_dict(), "./compiled_model_2.pth")
    print(f" model saved at batch {batch_number}")


def train(model_engine, train_dataloader, loss_fn, device, epochs):
    """Run the next-token-prediction training loop on a DeepSpeed engine.

    For every batch the engine is called with the token ids, attention mask,
    image pixels and image position. The logits at position t are scored
    against the label at position t + 1, and the engine's ``backward`` and
    ``step`` are invoked. Every 6000 batches within an epoch the checkpoint in
    ``./best_model_2`` and the state dict ``./compiled_model_2.pth`` are
    rewritten.

    Args:
        model_engine: Engine returned by ``deepspeed.initialize``.
        train_dataloader: Iterable of batches with ``'input_ids'``,
            ``'attention_mask'``, ``'input_pixel'``, ``'labels'`` and
            ``'image_idx'`` tensors.
        loss_fn: Callable taking ``(logits_2d, targets_1d)`` and returning a
            scalar loss.
        device: Device the batch tensors are moved to.
        epochs: Number of passes over ``train_dataloader``.

    Note:
        An exception raised inside an epoch is printed and the loop moves on
        to the next epoch; it is not re-raised.
        NOTE(review): nothing is saved after the final batch unless the batch
        count is a multiple of 6000.
    """
    model_engine.train()

    for epoch in range(epochs):
        with tqdm(total=len(train_dataloader), desc=f'Epoch {epoch + 1}/{epochs}', unit='batch') as pbar:
            try:
                for batch_idx, batch in enumerate(train_dataloader):
                    input_ids = batch['input_ids'].to(device)
                    attention_mask = batch['attention_mask'].to(device)
                    input_pixel = batch['input_pixel'].to(device)
                    labels = batch['labels'].to(device)
                    image_idx = batch['image_idx'].to(device)
                    logits = model_engine(input_ids, attention_mask, input_pixel, image_idx)

                    # Shift each row so its maximum is 0; cross-entropy is
                    # invariant to a constant shift per row.
                    max_logits = logits.max(dim=-1, keepdim=True)[0]
                    stable_logits = logits - max_logits
                    # Position t predicts token t + 1: drop the last logit
                    # and the first label.
                    loss = loss_fn(
                        stable_logits[:, :-1, :].reshape(-1, stable_logits.shape[-1]),
                        labels[:, 1:].reshape(-1),
                    )
                    model_engine.backward(loss)
                    # DeepSpeed expects step() after every backward(); the
                    # engine tracks gradient-accumulation boundaries from its
                    # own config. Gating step() on a manual counter would
                    # desynchronise that bookkeeping, so the module-level
                    # accumulation_steps is no longer consulted here.
                    model_engine.step()
                    pbar.update(1)
                    pbar.set_postfix(loss=loss.item())
                    if (batch_idx + 1) % 6000 == 0:
                        _save_training_state(model_engine, batch_idx + 1)
            except Exception as e:
                print(f"error in train {e}")
| |
# Kick off training: two passes over the training loader.
train(
    model_engine=model_engine,
    train_dataloader=train_loader,
    loss_fn=loss_fn,
    device=device,
    epochs=2,
)
| |