MLInAi committed on
Commit
5a8c443
·
verified ·
1 Parent(s): 5252660

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -33
app.py CHANGED
@@ -1,39 +1,18 @@
1
- # import os
2
-
3
- # os.system('pip install --upgrade pip')
4
-
5
  import streamlit as st
6
  from transformers import BlipForConditionalGeneration, AutoTokenizer
7
- import torchvision.transforms as transforms
8
  import torch
9
  from PIL import Image
10
- from io import BytesIO
11
-
12
- # # Load the fine-tuned model
13
- # model_path = '/content/model_after_5_epochs.pth'
14
- # model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
15
- # model.load_state_dict(torch.load(model_path))
16
- # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
- # model.to(device)
18
- # tokenizer = AutoTokenizer.from_pretrained("Salesforce/blip-image-captioning-base")
19
-
20
 
 
21
  model = BlipForConditionalGeneration.from_pretrained("MLInAi/CartoonCaptionGen")
22
  tokenizer = AutoTokenizer.from_pretrained("MLInAi/CartoonCaptionGen")
 
23
 
24
- # Function to generate caption for the uploaded image
25
- # def generate_caption(image):
26
- # # Preprocess the image
27
- # image = Image.open(image).convert("RGB")
28
- # image = image.resize((224, 224)) # Resize the image to match model input size
29
- # image_tensor = torch.tensor([torch.Tensor(image)]).permute(0, 3, 1, 2).to(device)
30
-
31
- # # Generate caption
32
- # output = model.generate(pixel_values=image_tensor)
33
- # caption = tokenizer.decode(output[0], skip_special_tokens=True)
34
- # return caption
35
-
36
 
 
37
  def generate_caption(image):
38
  # Preprocess the image
39
  image = Image.open(image).convert("RGB")
@@ -44,13 +23,14 @@ def generate_caption(image):
44
  transforms.ToTensor(),
45
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
46
  ])
47
- image_tensor = transform(image).unsqueeze(0)
48
-
49
- # Generate caption
50
- output = model.generate(pixel_values=image_tensor)
 
51
  caption = tokenizer.decode(output[0], skip_special_tokens=True)
52
  return caption
53
-
54
  # Streamlit app
55
  st.title("Image Caption Generator")
56
 
@@ -61,6 +41,6 @@ if uploaded_image is not None:
61
  st.write("")
62
  st.write("Generating caption...")
63
 
64
- # Generate caption for the uploaded image
65
  caption = generate_caption(uploaded_image)
66
  st.write("Caption:", caption)
 
 
 
 
 
1
  import streamlit as st
2
  from transformers import BlipForConditionalGeneration, AutoTokenizer
 
3
  import torch
4
  from PIL import Image
5
+ import torchvision.transforms as transforms
 
 
 
 
 
 
 
 
 
6
 
7
# Load the fine-tuned BLIP captioning model and its tokenizer from the Hub.
model = BlipForConditionalGeneration.from_pretrained("MLInAi/CartoonCaptionGen")
tokenizer = AutoTokenizer.from_pretrained("MLInAi/CartoonCaptionGen")

# Run on GPU when available; fall back to CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# FIX: the original defined `device` but never moved the model onto it, while
# generate_caption() moves the image tensor to `device` — that mismatch crashes
# generation whenever CUDA is available.
model.to(device)
model.eval()  # inference only; disable dropout/batch-norm training behavior

# Fixed prompt prepended to every generation request.
FIXED_PROMPT = "Generate a funny caption from image"
 
 
 
 
 
 
 
 
 
 
14
 
15
# Generate a caption for an uploaded image, conditioned on FIXED_PROMPT.
def generate_caption(image):
    """Return a caption string for *image* (a file-like object or path).

    The image is resized and normalized with ImageNet statistics, then a
    caption is decoded from the BLIP model conditioned on FIXED_PROMPT.
    """
    # Preprocess: PIL load -> RGB -> 224x224 tensor with ImageNet stats.
    image = Image.open(image).convert("RGB")
    transform = transforms.Compose([
        # NOTE(review): this Resize line sits in a diff-hunk gap in the
        # original — reconstructed from the surrounding code; confirm.
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    image_tensor = transform(image).unsqueeze(0).to(device)

    # FIX: `model.generate(..., input_text=...)` is not a valid argument —
    # transformers' generate() takes token ids, not raw text, so the original
    # call failed with a TypeError. Tokenize the fixed prompt and pass
    # `input_ids` instead.
    prompt_ids = tokenizer(FIXED_PROMPT, return_tensors="pt").input_ids.to(device)
    with torch.no_grad():  # inference only — skip autograd bookkeeping
        output = model.generate(pixel_values=image_tensor, input_ids=prompt_ids)

    # Decode the first (and only) generated sequence back to text.
    caption = tokenizer.decode(output[0], skip_special_tokens=True)
    return caption
33
+
34
# --- Streamlit UI ---
st.title("Image Caption Generator")
 
 
41
  st.write("")
42
  st.write("Generating caption...")
43
 
44
+ # Generate caption for the uploaded image with the fixed prompt
45
  caption = generate_caption(uploaded_image)
46
  st.write("Caption:", caption)