MLInAi commited on
Commit
c17b215
·
verified ·
1 Parent(s): 3120115

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -2
app.py CHANGED
@@ -12,7 +12,6 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
  # Fixed prompt
13
  FIXED_PROMPT = "Generate a funny caption from image"
14
 
15
- # Function to generate caption for the uploaded image with the fixed prompt
16
  # Function to generate caption for the uploaded image with the fixed prompt
17
  def generate_caption(image):
18
  # Preprocess the image
@@ -33,7 +32,11 @@ def generate_caption(image):
33
  batch_size = image_tensor.shape[0]
34
  repeated_prompt_tensor = fixed_prompt_tensor.repeat(batch_size, 1)
35
 
36
- # Concatenate the prompt tensor with the image tensor along the sequence dimension
 
 
 
 
37
  input_tensor = torch.cat((repeated_prompt_tensor, image_tensor), dim=1)
38
 
39
  # Generate caption
@@ -42,6 +45,7 @@ def generate_caption(image):
42
  return caption
43
 
44
 
 
45
  # Streamlit app
46
  st.title("Cartoon Caption Generator")
47
 
 
12
  # Fixed prompt
13
  FIXED_PROMPT = "Generate a funny caption from image"
14
 
 
15
  # Function to generate caption for the uploaded image with the fixed prompt
16
  def generate_caption(image):
17
  # Preprocess the image
 
32
  batch_size = image_tensor.shape[0]
33
  repeated_prompt_tensor = fixed_prompt_tensor.repeat(batch_size, 1)
34
 
35
+ # Reshape the image tensor to match the shape of the prompt tensor
36
+ # The reshaping depends on the model's input requirements
37
+ image_tensor = image_tensor.view(batch_size, -1)
38
+
39
+ # Concatenate the prompt tensor with the reshaped image tensor along the sequence dimension
40
  input_tensor = torch.cat((repeated_prompt_tensor, image_tensor), dim=1)
41
 
42
  # Generate caption
 
45
  return caption
46
 
47
 
48
+
49
  # Streamlit app
50
  st.title("Cartoon Caption Generator")
51