nroggendorff commited on
Commit
b5d09ec
·
verified ·
1 Parent(s): bb6501e

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +20 -33
train.py CHANGED
@@ -31,56 +31,43 @@ def load_model(model_name, device_id=0):
31
 
32
  def caption_batch(batch, processor, model):
33
  images = batch["image"]
34
-
35
- processed_images = []
36
  for image in images:
37
  if not isinstance(image, Image.Image):
38
  image = Image.fromarray(image)
39
  if image.mode != "RGB":
40
  image = image.convert("RGB")
41
- processed_images.append(image)
42
 
43
- encoded_list = []
44
- for image in processed_images:
45
  msg = [
46
  {
47
  "role": "user",
48
  "content": [
49
- {"type": "image", "image": image},
50
- {
51
- "type": "text",
52
- "text": "Describe the image, and skip mentioning that it's illustrated or from anime.",
53
- },
54
  ],
55
  }
56
  ]
 
57
 
58
- enc = processor.apply_chat_template(
59
- msg,
60
- tokenize=True,
61
- add_generation_prompt=True,
62
- return_dict=True,
63
- return_tensors="pt",
64
- )
65
-
66
- encoded_list.append(enc)
67
-
68
- input_ids = torch.nn.utils.rnn.pad_sequence(
69
- [e.input_ids[0] for e in encoded_list],
70
- batch_first=True,
71
- padding_value=processor.tokenizer.pad_token_id,
72
- ).to(model.device)
73
 
74
- attention_mask = torch.nn.utils.rnn.pad_sequence(
75
- [e.attention_mask[0] for e in encoded_list],
76
- batch_first=True,
77
- padding_value=0,
78
- ).to(model.device)
79
 
80
  with torch.no_grad():
81
  generated = model.generate(
82
  input_ids=input_ids,
83
  attention_mask=attention_mask,
 
84
  max_new_tokens=256,
85
  )
86
 
@@ -90,16 +77,16 @@ def caption_batch(batch, processor, model):
90
  for d in decoded:
91
  if "<|im_start|>assistant" in d:
92
  d = d.split("<|im_start|>assistant")[-1].strip()
93
-
94
  special_tokens = set(processor.tokenizer.all_special_tokens)
95
  for token in special_tokens:
96
  d = d.replace(token, "")
97
-
98
  d = d.strip()
99
  captions.append(d)
100
 
101
  return {
102
- "image": processed_images,
103
  "text": captions,
104
  }
105
 
 
31
 
32
  def caption_batch(batch, processor, model):
33
  images = batch["image"]
34
+
35
+ pil_images = []
36
  for image in images:
37
  if not isinstance(image, Image.Image):
38
  image = Image.fromarray(image)
39
  if image.mode != "RGB":
40
  image = image.convert("RGB")
41
+ pil_images.append(image)
42
 
43
+ text_inputs = []
44
+ for _ in pil_images:
45
  msg = [
46
  {
47
  "role": "user",
48
  "content": [
49
+ {"type": "text", "text": "Describe the image, and skip mentioning that it's illustrated or from anime."},
 
 
 
 
50
  ],
51
  }
52
  ]
53
+ text_inputs.append(processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True))
54
 
55
+ inputs = processor(
56
+ text=text_inputs,
57
+ images=pil_images,
58
+ return_tensors="pt",
59
+ padding=True
60
+ )
 
 
 
 
 
 
 
 
 
61
 
62
+ input_ids = inputs.input_ids.to(model.device)
63
+ attention_mask = inputs.attention_mask.to(model.device)
64
+ pixel_values = inputs.pixel_values.to(model.device)
 
 
65
 
66
  with torch.no_grad():
67
  generated = model.generate(
68
  input_ids=input_ids,
69
  attention_mask=attention_mask,
70
+ pixel_values=pixel_values,
71
  max_new_tokens=256,
72
  )
73
 
 
77
  for d in decoded:
78
  if "<|im_start|>assistant" in d:
79
  d = d.split("<|im_start|>assistant")[-1].strip()
80
+
81
  special_tokens = set(processor.tokenizer.all_special_tokens)
82
  for token in special_tokens:
83
  d = d.replace(token, "")
84
+
85
  d = d.strip()
86
  captions.append(d)
87
 
88
  return {
89
+ "image": images,
90
  "text": captions,
91
  }
92