Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -9,10 +9,7 @@ model = BlipForConditionalGeneration.from_pretrained("MLInAi/CartoonCaptionGen")
|
|
| 9 |
tokenizer = AutoTokenizer.from_pretrained("MLInAi/CartoonCaptionGen")
|
| 10 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 11 |
|
| 12 |
-
#
|
| 13 |
-
FIXED_PROMPT = "Generate a funny caption from image"
|
| 14 |
-
|
| 15 |
-
# Function to generate caption for the uploaded image with the fixed prompt
|
| 16 |
def generate_caption(image):
|
| 17 |
# Preprocess the image
|
| 18 |
image = Image.open(image).convert("RGB")
|
|
@@ -25,20 +22,6 @@ def generate_caption(image):
|
|
| 25 |
])
|
| 26 |
image_tensor = transform(image).unsqueeze(0).to(device)
|
| 27 |
|
| 28 |
-
# Prepend the fixed prompt to the input tensor
|
| 29 |
-
fixed_prompt_tensor = tokenizer(FIXED_PROMPT, return_tensors="pt").input_ids.to(device)
|
| 30 |
-
|
| 31 |
-
# Repeat the prompt tensor to match the batch size of the image tensor
|
| 32 |
-
batch_size = image_tensor.shape[0]
|
| 33 |
-
repeated_prompt_tensor = fixed_prompt_tensor.repeat(batch_size, 1)
|
| 34 |
-
|
| 35 |
-
# Reshape the image tensor to match the shape of the prompt tensor
|
| 36 |
-
# The reshaping depends on the model's input requirements
|
| 37 |
-
image_tensor = image_tensor.view(batch_size, -1)
|
| 38 |
-
|
| 39 |
-
# Concatenate the prompt tensor with the reshaped image tensor along the sequence dimension
|
| 40 |
-
input_tensor = torch.cat((repeated_prompt_tensor, image_tensor), dim=1)
|
| 41 |
-
|
| 42 |
# Generate caption
|
| 43 |
output = model.generate(pixel_values=image_tensor)
|
| 44 |
caption = tokenizer.decode(output[0], skip_special_tokens=True)
|
|
|
|
| 9 |
tokenizer = AutoTokenizer.from_pretrained("MLInAi/CartoonCaptionGen")
|
| 10 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 11 |
|
| 12 |
+
# Function to generate caption for the uploaded image
|
|
|
|
|
|
|
|
|
|
| 13 |
def generate_caption(image):
|
| 14 |
# Preprocess the image
|
| 15 |
image = Image.open(image).convert("RGB")
|
|
|
|
| 22 |
])
|
| 23 |
image_tensor = transform(image).unsqueeze(0).to(device)
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
# Generate caption
|
| 26 |
output = model.generate(pixel_values=image_tensor)
|
| 27 |
caption = tokenizer.decode(output[0], skip_special_tokens=True)
|