haris018 committed on
Commit
c56ecdd
·
verified ·
1 Parent(s): f0c6892

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -0
app.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ from PIL import Image
4
+ from transformers import Blip2Processor, Blip2ForConditionalGeneration
5
+
6
# Cache the heavyweight processor/model so Streamlit re-runs reuse the same
# objects instead of reloading the weights on every interaction.
@st.cache_resource
def load_blip_model():
    """
    Load the BLIP-2 processor and model from the Hugging Face Hub.

    Returns:
        tuple: (processor, model) ready for image captioning.
    """
    checkpoint = "Salesforce/blip2-opt-2.7b"
    processor = Blip2Processor.from_pretrained(checkpoint)
    # device_map="auto" places weights on GPU when one is available and falls
    # back to CPU otherwise; float16 halves the weight memory footprint.
    model = Blip2ForConditionalGeneration.from_pretrained(
        checkpoint,
        device_map="auto",
        torch_dtype=torch.float16,
    )
    return processor, model
25
+
26
# Configure the page before any other Streamlit command — Streamlit requires
# set_page_config() to be the first call executed in a script run, so it must
# precede the (cached) model load below.
st.set_page_config(
    page_title="BLIP-2 Image Captioning",
    page_icon="📸",
    layout="centered"
)

# Load the cached processor and model (downloaded once per server process).
processor, model = load_blip_model()

st.title("📸 BLIP-2 Image Captioning")
st.markdown("### Generate captions for your images using a powerful vision-language model.")
st.markdown("---")

# File uploader widget for the user to upload an image.
uploaded_file = st.file_uploader(
    "Upload an image",
    type=["jpg", "jpeg", "png", "webp"],
    help="Drag and drop or click to upload your image."
)

if uploaded_file is not None:
    try:
        # Normalize to RGB so palette/alpha images (PNG, WebP) are accepted
        # by the BLIP-2 processor.
        image = Image.open(uploaded_file).convert('RGB')

        # use_container_width replaces the deprecated use_column_width flag.
        st.image(image, caption="Uploaded Image", use_container_width=True)

        # Create a button to generate the caption.
        if st.button("Generate Caption"):
            with st.spinner("Generating caption..."):
                # Preprocess the image and move tensors to the model's device,
                # casting to float16 to match the model weights.
                inputs = processor(images=image, return_tensors="pt").to(model.device, torch.float16)

                # max_new_tokens bounds only the generated text; max_length
                # also counts prompt tokens and is discouraged by transformers.
                outputs = model.generate(**inputs, max_new_tokens=50)

                # Decode the generated caption tokens to a string.
                caption = processor.decode(outputs[0], skip_special_tokens=True)

                st.success("Caption generated!")
                st.markdown("### **Generated Caption:**")
                st.info(caption.capitalize())

    except Exception as e:
        # Surface the failure to the user instead of crashing the app run.
        st.error(f"An error occurred: {e}")
        st.markdown("Please try uploading a different image or check the model availability.")
else:
    st.info("Upload an image to get started!")