MLInAi committed on
Commit
65aba2a
·
verified ·
1 Parent(s): c17b215

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -18
app.py CHANGED
@@ -9,10 +9,7 @@ model = BlipForConditionalGeneration.from_pretrained("MLInAi/CartoonCaptionGen")
9
  tokenizer = AutoTokenizer.from_pretrained("MLInAi/CartoonCaptionGen")
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
 
12
- # Fixed prompt
13
- FIXED_PROMPT = "Generate a funny caption from image"
14
-
15
- # Function to generate caption for the uploaded image with the fixed prompt
16
  def generate_caption(image):
17
  # Preprocess the image
18
  image = Image.open(image).convert("RGB")
@@ -25,20 +22,6 @@ def generate_caption(image):
25
  ])
26
  image_tensor = transform(image).unsqueeze(0).to(device)
27
 
28
- # Prepend the fixed prompt to the input tensor
29
- fixed_prompt_tensor = tokenizer(FIXED_PROMPT, return_tensors="pt").input_ids.to(device)
30
-
31
- # Repeat the prompt tensor to match the batch size of the image tensor
32
- batch_size = image_tensor.shape[0]
33
- repeated_prompt_tensor = fixed_prompt_tensor.repeat(batch_size, 1)
34
-
35
- # Reshape the image tensor to match the shape of the prompt tensor
36
- # The reshaping depends on the model's input requirements
37
- image_tensor = image_tensor.view(batch_size, -1)
38
-
39
- # Concatenate the prompt tensor with the reshaped image tensor along the sequence dimension
40
- input_tensor = torch.cat((repeated_prompt_tensor, image_tensor), dim=1)
41
-
42
  # Generate caption
43
  output = model.generate(pixel_values=image_tensor)
44
  caption = tokenizer.decode(output[0], skip_special_tokens=True)
 
9
  tokenizer = AutoTokenizer.from_pretrained("MLInAi/CartoonCaptionGen")
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
 
12
+ # Function to generate caption for the uploaded image
 
 
 
13
  def generate_caption(image):
14
  # Preprocess the image
15
  image = Image.open(image).convert("RGB")
 
22
  ])
23
  image_tensor = transform(image).unsqueeze(0).to(device)
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  # Generate caption
26
  output = model.generate(pixel_values=image_tensor)
27
  caption = tokenizer.decode(output[0], skip_special_tokens=True)