Alpha108 commited on
Commit
508ae61
·
verified ·
1 Parent(s): 8eeea5e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +120 -0
app.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ """
3
+ Streamlit BLIP-2 Image Captioning demo
4
+ - Uses HuggingFace transformers' Blip2Processor + Blip2ForConditionalGeneration
5
+ - Caches the model & processor with st.cache_resource so they load once per Space/session.
6
+ - Designed for deployment on Hugging Face Spaces (use Docker SDK / Streamlit template).
7
+ """
8
+
9
+ import streamlit as st
10
+ from PIL import Image
11
+ import io
12
+ import torch
13
+ from transformers import Blip2Processor, Blip2ForConditionalGeneration
14
+
15
+ st.set_page_config(
16
+ page_title="BLIP-2 Image Captioning",
17
+ layout="wide",
18
+ initial_sidebar_state="expanded",
19
+ )
20
+
21
+ # --- Sidebar / Info ---
22
+ st.sidebar.title("BLIP-2 Caption Demo")
23
+ st.sidebar.markdown(
24
+ """
25
+ Upload an image and BLIP-2 will generate a caption.
26
+ - Model choices: choose a BLIP-2 model (large models may need GPU / won’t fit on CPU).
27
+ - For Spaces deployment, prefer smaller/flan-xl variants or use inference API.
28
+ """
29
+ )
30
+
31
+ # Recommended default model (change if you want)
32
+ DEFAULT_MODEL = "Salesforce/blip2-opt-2.7b"
33
+
34
+ @st.cache_resource(show_spinner=False)
35
+ def load_model_and_processor(model_name: str):
36
+ """Load and cache the BLIP-2 processor and model."""
37
+ # Note: large models will require a GPU; smaller variants or hosted inference endpoints recommended for CPU-only Spaces.
38
+ processor = Blip2Processor.from_pretrained(model_name)
39
+ model = Blip2ForConditionalGeneration.from_pretrained(model_name)
40
+ # move to GPU if available
41
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
42
+ model.to(device)
43
+ return processor, model, device
44
+
45
+ def generate_caption(processor, model, device, pil_image: Image.Image, max_new_tokens=50, num_beams=4):
46
+ """Generate caption text for a PIL image using BLIP-2."""
47
+ if pil_image.mode != "RGB":
48
+ pil_image = pil_image.convert("RGB")
49
+
50
+ inputs = processor(images=pil_image, return_tensors="pt").to(device)
51
+
52
+ # Generate - tune generation args as needed
53
+ generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens, num_beams=num_beams)
54
+ # decode using the tokenizer in the processor
55
+ caption = processor.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
56
+ return caption
57
+
58
+ # --- UI layout ---
59
+ col1, col2 = st.columns([1, 1.2])
60
+
61
+ with col1:
62
+ st.header("Upload image")
63
+ uploaded = st.file_uploader("Choose an image", type=["png", "jpg", "jpeg"], accept_multiple_files=False)
64
+
65
+ st.markdown("**Model selection**")
66
+ model_name = st.selectbox(
67
+ "Pick BLIP-2 model (large models may not run on CPU)",
68
+ options=[
69
+ "Salesforce/blip2-flan-t5-xl",
70
+ "Salesforce/blip2-opt-2.7b",
71
+ "Salesforce/blip2-flan-t5-xxl",
72
+ ],
73
+ index=1 if DEFAULT_MODEL.endswith("2.7b") else 0,
74
+ help="Large models require GPU or HF Inference API; choose smaller if you have no GPU.",
75
+ )
76
+
77
+ max_tokens = st.slider("Max caption length (tokens)", min_value=10, max_value=200, value=50)
78
+ num_beams = st.slider("Beam search width (num_beams)", min_value=1, max_value=8, value=4)
79
+
80
+ st.write("---")
81
+ st.markdown("Tips:")
82
+ st.markdown(
83
+ "- If deploying on CPU-only Spaces, use a smaller/flan model or use the Hugging Face Inference API.\n"
84
+ "- Model loading is cached to speed up subsequent requests."
85
+ )
86
+
87
+ with col2:
88
+ st.header("Preview & Caption")
89
+ if uploaded is None:
90
+ st.info("Upload an image on the left to generate a caption.")
91
+ st.empty()
92
+ else:
93
+ # display image
94
+ image_bytes = uploaded.read()
95
+ pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
96
+ st.image(pil_image, use_column_width=True)
97
+
98
+ # Load model & processor (cached)
99
+ with st.spinner("Loading model (cached after first load)..."):
100
+ processor, model, device = load_model_and_processor(model_name)
101
+
102
+ # Generate caption
103
+ if st.button("Generate caption"):
104
+ with st.spinner("Generating caption..."):
105
+ try:
106
+ caption = generate_caption(processor, model, device, pil_image, max_new_tokens=max_tokens, num_beams=num_beams)
107
+ st.success("Caption generated")
108
+ st.markdown(f"**Caption:** {caption}")
109
+ # Provide a copy button and simple download
110
+ st.download_button("Download caption (.txt)", caption, file_name="caption.txt")
111
+ except Exception as e:
112
+ st.error(f"Error during generation: {e}")
113
+ st.info("If model is too large or out-of-memory, try a smaller model or use GPU.")
114
+
115
+ # --- Footer / Resources ---
116
+ st.markdown("---")
117
+ st.markdown(
118
+ "Built with BLIP-2 + Transformers. For production or public Spaces hosting, consider using Hugging Face Inference API or a smaller model variant to avoid OOM on CPU-only hosts."
119
+ )
120
+ st.caption("Docs: BLIP-2 (Transformers), Hugging Face Spaces (Streamlit), Streamlit caching & uploader.")