nroggendorff committed on
Commit
9936b41
·
verified ·
1 Parent(s): ae0c2b0

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +4 -12
train.py CHANGED
@@ -36,14 +36,10 @@ def caption_batch(batch, processor, model):
36
 
37
  pil_images = []
38
  for image in images:
39
- if isinstance(image, torch.Tensor):
40
- image = image.cpu().numpy()
41
-
42
- if not isinstance(image, Image.Image):
43
- image = Image.fromarray(image)
44
- if image.mode != "RGB":
45
- image = image.convert("RGB")
46
- pil_images.append(image)
47
 
48
  msg = [
49
  {
@@ -110,9 +106,6 @@ def process_shard_worker(
110
  else:
111
  shard = cast(Dataset, loaded)
112
 
113
- shard = shard.with_format("torch")
114
- shard.set_format(type="torch", columns=["image"])
115
-
116
  print(f"[GPU {gpu_id}] Processing {len(shard)} examples...", flush=True)
117
  result = shard.map(
118
  lambda batch: caption_batch(batch, processor, model),
@@ -120,7 +113,6 @@ def process_shard_worker(
120
  batch_size=batch_size,
121
  remove_columns=[col for col in shard.column_names if col != "image"],
122
  writer_batch_size=1000,
123
- keep_in_memory=True,
124
  )
125
 
126
  print(f"[GPU {gpu_id}] Saving results to {output_file}...", flush=True)
 
36
 
37
  pil_images = []
38
  for image in images:
39
+ if isinstance(image, Image.Image):
40
+ if image.mode != "RGB":
41
+ image = image.convert("RGB")
42
+ pil_images.append(image)
 
 
 
 
43
 
44
  msg = [
45
  {
 
106
  else:
107
  shard = cast(Dataset, loaded)
108
 
 
 
 
109
  print(f"[GPU {gpu_id}] Processing {len(shard)} examples...", flush=True)
110
  result = shard.map(
111
  lambda batch: caption_batch(batch, processor, model),
 
113
  batch_size=batch_size,
114
  remove_columns=[col for col in shard.column_names if col != "image"],
115
  writer_batch_size=1000,
 
116
  )
117
 
118
  print(f"[GPU {gpu_id}] Saving results to {output_file}...", flush=True)