MLInAi commited on
Commit
b59ed0f
·
verified ·
1 Parent(s): e8d2c0e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -1
app.py CHANGED
@@ -12,6 +12,7 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
  # Fixed prompt
13
  FIXED_PROMPT = "Generate a funny caption from image"
14
 
 
15
  # Function to generate caption for the uploaded image with the fixed prompt
16
  def generate_caption(image):
17
  # Preprocess the image
@@ -27,13 +28,20 @@ def generate_caption(image):
27
 
28
  # Prepend the fixed prompt to the input tensor
29
  fixed_prompt_tensor = tokenizer(FIXED_PROMPT, return_tensors="pt").input_ids.to(device)
30
- input_tensor = torch.cat((fixed_prompt_tensor, image_tensor), dim=1)
 
 
 
 
 
 
31
 
32
  # Generate caption
33
  output = model.generate(pixel_values=image_tensor)
34
  caption = tokenizer.decode(output[0], skip_special_tokens=True)
35
  return caption
36
 
 
37
  # Streamlit app
38
  st.title("Image Caption Generator")
39
 
 
12
  # Fixed prompt
13
  FIXED_PROMPT = "Generate a funny caption from image"
14
 
15
+ # Function to generate caption for the uploaded image with the fixed prompt
16
  # Function to generate caption for the uploaded image with the fixed prompt
17
  def generate_caption(image):
18
  # Preprocess the image
 
28
 
29
  # Prepend the fixed prompt to the input tensor
30
  fixed_prompt_tensor = tokenizer(FIXED_PROMPT, return_tensors="pt").input_ids.to(device)
31
+
32
+ # Repeat the prompt tensor to match the batch size of the image tensor
33
+ batch_size = image_tensor.shape[0]
34
+ repeated_prompt_tensor = fixed_prompt_tensor.repeat(batch_size, 1)
35
+
36
+ # Concatenate the prompt tensor with the image tensor along the sequence dimension
37
+ input_tensor = torch.cat((repeated_prompt_tensor, image_tensor), dim=1)
38
 
39
  # Generate caption
40
  output = model.generate(pixel_values=image_tensor)
41
  caption = tokenizer.decode(output[0], skip_special_tokens=True)
42
  return caption
43
 
44
+
45
  # Streamlit app
46
  st.title("Image Caption Generator")
47