Spaces:
Build error
Build error
| import glob | |
| import pandas as pd | |
| from PIL import Image | |
| from torch.utils.data import Dataset, random_split | |
| from transformers import TrainingArguments, Trainer, ViTFeatureExtractor, BertTokenizer, VisionEncoderDecoderModel | |
| import torch | |
| import gc | |
| import os | |
from pathlib import Path

# Seed the global PyTorch RNG so runs are repeatable.
torch.manual_seed(42)

# The default paths in this script are Linux paths from the original author's
# machine; adjust them (or set PHOTO_DATA_PATH) when running elsewhere,
# e.g. on Windows.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Directory holding the caption CSV files. Overridable via the PHOTO_DATA_PATH
# environment variable; defaults to the original hard-coded location.
path = os.environ.get('PHOTO_DATA_PATH', '/media/delta/S/Photos/Photo_Data')

# ViT image encoder paired with a BERT text decoder, both loaded from
# pretrained checkpoints.
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    "google/vit-base-patch16-224-in21k", "bert-base-uncased"
).to(device)

# Decoding starts from BERT's [CLS] token and pads with [PAD].
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id
# BERT's tokenizer terminates every encoded caption with [SEP]. Recording it
# as the end-of-sequence token in the saved config lets generate() stop there
# when the fine-tuned model is later used for inference.
model.config.eos_token_id = tokenizer.sep_token_id
# Gather every caption CSV in the data directory (each with IMAGEPATH and
# CAPTION columns). glob() returns files in arbitrary, filesystem-dependent
# order; sorting keeps the concatenated row order, and therefore the seeded
# train/val split, reproducible across machines.
list_of_csv = sorted(glob.glob(f'{path}/*.csv'))  # to change
if not list_of_csv:
    # pd.concat([]) would otherwise fail with an unhelpful
    # "No objects to concatenate" error.
    raise FileNotFoundError(f'No CSV files found in {path!r}')
DF = [pd.read_csv(f) for f in list_of_csv]
# ignore_index gives the combined table a clean 0..n-1 index instead of
# repeating each file's own row numbers.
ds = pd.concat(DF, ignore_index=True)
class CustomDataset(Dataset):
    """Image-captioning dataset built from a table of image paths and captions.

    All usable rows are preprocessed eagerly in ``__init__`` and kept in
    memory. ``Pixel_Values[i]`` holds the feature extractor's ``pixel_values``
    tensor for one image. ``Labels[i]`` holds the matching caption token ids,
    padded/truncated to 128 tokens, with padding positions set to -100.
    -100 is PyTorch cross-entropy's default ``ignore_index``, so those
    positions are excluded from the loss.
    """

    def __init__(self, ds, tokenizer, feature_extractor, image_root=None):
        """Load and preprocess every usable row of ``ds``.

        Args:
            ds: DataFrame with ``IMAGEPATH`` and ``CAPTION`` columns. Image
                paths are Windows-style (``\\``-separated; ``/`` also
                accepted). Only their last three components are kept and
                re-rooted under ``image_root``.
            tokenizer: Hugging Face tokenizer used to encode the captions.
            feature_extractor: image preprocessor whose output provides
                ``pixel_values``.
            image_root: directory the last three path components are joined
                onto. Defaults to the current working directory (the original
                behaviour).

        Rows whose image path or caption is missing or shorter than 10
        characters are skipped.

        Raises:
            Whatever ``Image.open`` raises (e.g. ``FileNotFoundError``) when a
            resolved image file cannot be read.
        """
        if image_root is None:
            image_root = os.getcwd()
        self.Pixel_Values = []
        self.Labels = []
        # NOTE(review): every image is decoded and held in memory here; for
        # large photo collections, loading lazily in __getitem__ would cut
        # RAM use substantially.
        for _, row in ds.iterrows():
            raw_path = row['IMAGEPATH']
            caption = str(row['CAPTION'])
            # Empty CSV cells arrive from pandas as float NaN; treat any
            # non-string path as empty so the row is skipped instead of
            # crashing on len().
            if not isinstance(raw_path, str):
                raw_path = ''
            if len(raw_path) < 10 or len(caption) < 10:
                continue
            image_path = self._resolve_image_path(raw_path, image_root)
            # The context manager releases the file handle even for
            # multi-frame formats, which Pillow would otherwise keep open.
            with Image.open(str(image_path)) as image:
                pixel_values = feature_extractor(
                    image.convert("RGB"), return_tensors="pt"
                ).pixel_values
            labels = tokenizer(
                caption,
                return_tensors="pt",
                truncation=True,
                max_length=128,
                padding="max_length",
            ).input_ids
            labels[labels == tokenizer.pad_token_id] = -100
            self.Pixel_Values.append(pixel_values)
            self.Labels.append(labels)

    @staticmethod
    def _resolve_image_path(raw_path, image_root):
        """Re-root the last three components of a Windows- or POSIX-style path.

        For example ``'C:\\data\\Photos\\album\\x.jpg'`` becomes
        ``image_root/Photos/album/x.jpg``. Any drive/root anchor is dropped
        first, so it never becomes a path component.
        """
        # PureWindowsPath splits on both '\\' and '/' on every platform.
        from pathlib import PureWindowsPath

        win_path = PureWindowsPath(raw_path)
        parts = win_path.parts[1:] if win_path.anchor else win_path.parts
        return Path(image_root, *parts[-3:])

    def __len__(self):
        """Return the number of rows that were successfully loaded."""
        return len(self.Pixel_Values)

    def __getitem__(self, idx):
        """Return the preprocessed image and caption labels at ``idx``."""
        return {"pixel_values": self.Pixel_Values[idx], "labels": self.Labels[idx]}
dataset = CustomDataset(ds, tokenizer, feature_extractor)
if len(dataset) == 0:
    # Fail here with a clear message rather than deep inside Trainer.
    raise ValueError(
        'No usable rows found: every IMAGEPATH/CAPTION pair was missing or '
        'shorter than 10 characters.'
    )

# 90% train / 10% validation. A dedicated seeded generator keeps the split
# reproducible regardless of how much of the global RNG stream was consumed
# earlier (e.g. by model weight initialisation).
train_size = int(0.9 * len(dataset))
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, len(dataset) - train_size],
    generator=torch.Generator().manual_seed(42),
)

# Drop unreferenced Python objects and cached CUDA blocks before training.
gc.collect()
torch.cuda.empty_cache()
# Hyper-parameters for the Hugging Face Trainer, grouped by concern.
# NOTE(review): no evaluation strategy is set here, so with the library
# default the validation split passed to Trainer is not evaluated during
# training — confirm this is intended.
training_args = TrainingArguments(
    # Where checkpoints and logs go.
    output_dir=str(Path.cwd() / 'results'),
    logging_dir='/home/delta/Downloads/logs',  # TensorBoard loss curves
    report_to='tensorboard',
    logging_steps=300,
    save_steps=14770,  # checkpoint every 14770 optimizer steps
    # Schedule and batching.
    num_train_epochs=6,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=1,
    warmup_steps=1,
    # Memory/precision trade-offs.
    gradient_checkpointing=False,
    fp16=False,  # the author reports fp16 does not work for this model
    # Optimizer. Other accepted values include 'adamw_hf', 'adamw_torch_xla',
    # 'adamw_apex_fused', 'adafactor', 'adamw_bnb_8bit', 'sgd' and 'adagrad';
    # 'adafactor' or 'adamw_bnb_8bit' are the usual choices when optimizer
    # memory is tight.
    optim='adamw_torch',
    weight_decay=0.05,
)
def collate_fn(examples):
    """Batch dataset items into ``pixel_values`` and ``labels`` tensors.

    Each item stores tensors that carry a leading dimension of size 1 (e.g.
    ``[1, 3, 224, 224]`` for images). Indexing ``[0]`` drops that dimension
    per item, and stacking recreates it as the batch dimension.
    """
    batch = {}
    for key in ("pixel_values", "labels"):
        batch[key] = torch.stack([item[key][0] for item in examples])
    return batch
# Fine-tune the encoder-decoder model on the training split.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=collate_fn,
)
trainer.train()

# Persist the fine-tuned model together with the tokenizer and image
# preprocessor it was trained with, each in its own directory.
for component, target_dir in (
    (model, '/media/delta/S/model_caption'),
    (tokenizer, '/media/delta/S/tokenizer_caption'),
    (feature_extractor, '/media/delta/S/feature_extractor_caption'),
):
    component.save_pretrained(target_dir)