nroggendorff commited on
Commit
748af01
·
verified ·
1 Parent(s): 337a414

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +11 -0
train.py CHANGED
@@ -29,7 +29,11 @@ def load_model(model_name, device_id=0):
29
  return processor, model
30
 
31
 
 
 
32
  def caption_batch(batch, processor, model):
 
 
33
  images = batch["image"]
34
 
35
  pil_images = []
@@ -84,6 +88,10 @@ def caption_batch(batch, processor, model):
84
  d = d.strip()
85
  captions.append(d)
86
 
 
 
 
 
87
  return {
88
  "image": images,
89
  "text": captions,
@@ -93,6 +101,9 @@ def caption_batch(batch, processor, model):
93
  def process_shard_worker(
94
  gpu_id, start, end, model_name, batch_size, input_dataset, output_file
95
  ):
 
 
 
96
  torch.cuda.set_device(gpu_id)
97
 
98
  print(f"[GPU {gpu_id}] Loading model...", flush=True)
 
29
  return processor, model
30
 
31
 
32
+ processed_count = 0
33
+
34
  def caption_batch(batch, processor, model):
35
+ global processed_count
36
+
37
  images = batch["image"]
38
 
39
  pil_images = []
 
88
  d = d.strip()
89
  captions.append(d)
90
 
91
+ processed_count += len(images)
92
+ if processed_count > 100:
93
+ print(f"Processed {processed_count} examples so far...")
94
+
95
  return {
96
  "image": images,
97
  "text": captions,
 
101
  def process_shard_worker(
102
  gpu_id, start, end, model_name, batch_size, input_dataset, output_file
103
  ):
104
+ global processed_count
105
+ processed_count = 0
106
+
107
  torch.cuda.set_device(gpu_id)
108
 
109
  print(f"[GPU {gpu_id}] Loading model...", flush=True)