Prashanthsrn committed on
Commit
5d238a1
·
verified ·
1 Parent(s): 1e8d9e4

Update caption_refiner.py

Browse files
Files changed (1) hide show
  1. caption_refiner.py +4 -2
caption_refiner.py CHANGED
@@ -1,4 +1,5 @@
1
  from transformers import T5Tokenizer, T5ForConditionalGeneration
 
2
 
3
  # Load the T5 model and tokenizer
4
  tokenizer = T5Tokenizer.from_pretrained("t5-base")
@@ -7,10 +8,11 @@ model = T5ForConditionalGeneration.from_pretrained("t5-base")
7
  def refine_caption(initial_caption):
8
  # Prepare the input for the T5 model
9
  input_text = f"refine caption: {initial_caption}"
10
- input_ids = tokenizer(input_text, return_tensors="pt").input_ids
11
 
12
  # Generate the refined caption
13
- outputs = model.generate(input_ids, max_new_tokens=100, num_return_sequences=1)
 
14
 
15
  # Decode the refined caption
16
  refined_caption = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
1
  from transformers import T5Tokenizer, T5ForConditionalGeneration
2
+ import torch
3
 
4
  # Load the T5 model and tokenizer
5
  tokenizer = T5Tokenizer.from_pretrained("t5-base")
 
8
  def refine_caption(initial_caption):
9
  # Prepare the input for the T5 model
10
  input_text = f"refine caption: {initial_caption}"
11
+ inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512)
12
 
13
  # Generate the refined caption
14
+ with torch.no_grad():
15
+ outputs = model.generate(inputs.input_ids, max_new_tokens=100, num_return_sequences=1)
16
 
17
  # Decode the refined caption
18
  refined_caption = tokenizer.decode(outputs[0], skip_special_tokens=True)