nroggendorff commited on
Commit
bb6501e
·
verified ·
1 Parent(s): 83184b9

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +22 -10
train.py CHANGED
@@ -4,6 +4,9 @@ import datasets
4
  from datasets import Dataset
5
  from typing import cast
6
  import os
 
 
 
7
 
8
 
9
  def load_model(model_name, device_id=0):
@@ -29,8 +32,16 @@ def load_model(model_name, device_id=0):
29
  def caption_batch(batch, processor, model):
30
  images = batch["image"]
31
 
32
- encoded_list = []
33
  for image in images:
 
 
 
 
 
 
 
 
34
  msg = [
35
  {
36
  "role": "user",
@@ -79,11 +90,16 @@ def caption_batch(batch, processor, model):
79
  for d in decoded:
80
  if "<|im_start|>assistant" in d:
81
  d = d.split("<|im_start|>assistant")[-1].strip()
82
- d = d.replace("<|im_end|>", "").strip()
 
 
 
 
 
83
  captions.append(d)
84
 
85
  return {
86
- "image": images,
87
  "text": captions,
88
  }
89
 
@@ -120,12 +136,10 @@ def process_shard_worker(
120
 
121
 
122
  def main():
123
- import multiprocessing as mp
124
-
125
- input_dataset = "none-yet/anime-captions"
126
- output_dataset = "none-yet/anime-captions"
127
  model_name = "datalab-to/chandra"
128
- batch_size = 32
129
 
130
  print("Loading dataset info...")
131
  loaded = datasets.load_dataset(input_dataset, split="train")
@@ -174,8 +188,6 @@ def main():
174
  final_ds.push_to_hub(output_dataset, create_pr=True)
175
 
176
  print("Cleaning up temporary files...")
177
- import shutil
178
-
179
  for f in temp_files:
180
  if os.path.exists(f):
181
  shutil.rmtree(f)
 
4
  from datasets import Dataset
5
  from typing import cast
6
  import os
7
+ import shutil
8
+ import multiprocessing as mp
9
+ from PIL import Image
10
 
11
 
12
  def load_model(model_name, device_id=0):
 
32
  def caption_batch(batch, processor, model):
33
  images = batch["image"]
34
 
35
+ processed_images = []
36
  for image in images:
37
+ if not isinstance(image, Image.Image):
38
+ image = Image.fromarray(image)
39
+ if image.mode != "RGB":
40
+ image = image.convert("RGB")
41
+ processed_images.append(image)
42
+
43
+ encoded_list = []
44
+ for image in processed_images:
45
  msg = [
46
  {
47
  "role": "user",
 
90
  for d in decoded:
91
  if "<|im_start|>assistant" in d:
92
  d = d.split("<|im_start|>assistant")[-1].strip()
93
+
94
+ special_tokens = set(processor.tokenizer.all_special_tokens)
95
+ for token in special_tokens:
96
+ d = d.replace(token, "")
97
+
98
+ d = d.strip()
99
  captions.append(d)
100
 
101
  return {
102
+ "image": processed_images,
103
  "text": captions,
104
  }
105
 
 
136
 
137
 
138
  def main():
139
+ input_dataset = "nroggendorff/fries"
140
+ output_dataset = "nroggendorff/fries"
 
 
141
  model_name = "datalab-to/chandra"
142
+ batch_size = 2
143
 
144
  print("Loading dataset info...")
145
  loaded = datasets.load_dataset(input_dataset, split="train")
 
188
  final_ds.push_to_hub(output_dataset, create_pr=True)
189
 
190
  print("Cleaning up temporary files...")
 
 
191
  for f in temp_files:
192
  if os.path.exists(f):
193
  shutil.rmtree(f)