Prashanthsrn commited on
Commit
ef06464
·
verified ·
1 Parent(s): 18ac5a7

Delete caption_refiner.py

Browse files
Files changed (1) hide show
  1. caption_refiner.py +0 -20
caption_refiner.py DELETED
@@ -1,20 +0,0 @@
1
- from transformers import T5Tokenizer, T5ForConditionalGeneration
2
- import torch
3
-
4
- # Load the T5 model and tokenizer
5
- tokenizer = T5Tokenizer.from_pretrained("t5-base")
6
- model = T5ForConditionalGeneration.from_pretrained("t5-base")
7
-
8
- def refine_caption(initial_caption):
9
- # Prepare the input for the T5 model
10
- input_text = f"refine caption: {initial_caption}"
11
- inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512)
12
-
13
- # Generate the refined caption
14
- with torch.no_grad():
15
- outputs = model.generate(inputs.input_ids, max_new_tokens=100, num_return_sequences=1)
16
-
17
- # Decode the refined caption
18
- refined_caption = tokenizer.decode(outputs[0], skip_special_tokens=True)
19
-
20
- return refined_caption