nroggendorff committed on
Commit
a04116c
·
verified ·
1 Parent(s): b5d09ec

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +9 -10
train.py CHANGED
@@ -40,34 +40,33 @@ def caption_batch(batch, processor, model):
40
  image = image.convert("RGB")
41
  pil_images.append(image)
42
 
43
- text_inputs = []
44
- for _ in pil_images:
45
  msg = [
46
  {
47
  "role": "user",
48
  "content": [
 
49
  {"type": "text", "text": "Describe the image, and skip mentioning that it's illustrated or from anime."},
50
  ],
51
  }
52
  ]
53
- text_inputs.append(processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True))
54
 
 
 
55
  inputs = processor(
56
- text=text_inputs,
57
  images=pil_images,
58
  return_tensors="pt",
59
  padding=True
60
  )
61
 
62
- input_ids = inputs.input_ids.to(model.device)
63
- attention_mask = inputs.attention_mask.to(model.device)
64
- pixel_values = inputs.pixel_values.to(model.device)
65
 
66
  with torch.no_grad():
67
  generated = model.generate(
68
- input_ids=input_ids,
69
- attention_mask=attention_mask,
70
- pixel_values=pixel_values,
71
  max_new_tokens=256,
72
  )
73
 
 
40
  image = image.convert("RGB")
41
  pil_images.append(image)
42
 
43
+ messages_list = []
44
+ for pil_image in pil_images:
45
  msg = [
46
  {
47
  "role": "user",
48
  "content": [
49
+ {"type": "image"},
50
  {"type": "text", "text": "Describe the image, and skip mentioning that it's illustrated or from anime."},
51
  ],
52
  }
53
  ]
54
+ messages_list.append(msg)
55
 
56
+ texts = processor.apply_chat_template(messages_list, add_generation_prompt=True, tokenize=False)
57
+
58
  inputs = processor(
59
+ text=texts,
60
  images=pil_images,
61
  return_tensors="pt",
62
  padding=True
63
  )
64
 
65
+ inputs = {k: v.to(model.device) for k, v in inputs.items()}
 
 
66
 
67
  with torch.no_grad():
68
  generated = model.generate(
69
+ **inputs,
 
 
70
  max_new_tokens=256,
71
  )
72