moslem committed on
Commit
2ce6e8c
·
verified ·
1 Parent(s): 3d4fa97

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +137 -161
app.py CHANGED
@@ -1,161 +1,137 @@
1
- # app.py
2
- """
3
- Image Captioning demo with Gradio + Hugging Face transformers.
4
-
5
- Environment variables:
6
- MODEL_ID - huggingface model id (default: Salesforce/blip-image-captioning-base)
7
- TRUST_REMOTE_CODE - "true"/"false" to allow custom repo code (default: false)
8
- HUGGINGFACE_HUB_TOKEN - optional, if your model is private
9
-
10
- Run:
11
- python app.py
12
- """
13
- import os
14
- import logging
15
- from typing import Optional
16
-
17
- import torch
18
- from PIL import Image
19
- from transformers import pipeline
20
- import gradio as gr
21
-
22
- # ----------------------------
23
- # Configuration & logging
24
- # ----------------------------
25
- MODEL_ID = os.environ.get("MODEL_ID", "Salesforce/blip-image-captioning-base")
26
- TRUST_REMOTE_CODE = os.environ.get("TRUST_REMOTE_CODE", "false").lower() in ("1", "true", "yes")
27
- HUB_TOKEN = os.environ.get("HUGGINGFACE_HUB_TOKEN") # optional (for private models)
28
-
29
- logging.basicConfig(level=logging.INFO)
30
- logger = logging.getLogger("image-caption-gradio")
31
-
32
- # ----------------------------
33
- # Device helper
34
- # ----------------------------
35
- def get_pipeline_device() -> int:
36
- """
37
- Return device index for transformers pipeline:
38
- 0 (GPU) if available, else -1 (CPU)
39
- """
40
- return 0 if torch.cuda.is_available() else -1
41
-
42
-
43
- # ----------------------------
44
- # Load pipeline (global)
45
- # ----------------------------
46
- caption_pipe = None
47
- _load_error: Optional[str] = None
48
-
49
- def load_caption_pipeline():
50
- """
51
- Load the image-to-text pipeline once (global).
52
- Uses HUGGINGFACE_HUB_TOKEN if set for private models.
53
- """
54
- global caption_pipe, _load_error
55
- if caption_pipe is not None or _load_error:
56
- return
57
-
58
- device = get_pipeline_device()
59
- logger.info("Loading model '%s' (trust_remote_code=%s) on device %s", MODEL_ID, TRUST_REMOTE_CODE, device)
60
-
61
- try:
62
- # If HUB_TOKEN is provided, transformers/huggingface_hub will pick it up from env.
63
- caption_pipe = pipeline(
64
- "image-to-text",
65
- model=MODEL_ID,
66
- device=device,
67
- trust_remote_code=TRUST_REMOTE_CODE,
68
- )
69
- logger.info("Model loaded successfully.")
70
- except Exception as e:
71
- _load_error = str(e)
72
- logger.exception("Failed to load model: %s", e)
73
-
74
-
75
- # Preload model at startup (best-effort)
76
- load_caption_pipeline()
77
-
78
-
79
- # ----------------------------
80
- # Inference function used by Gradio
81
- # ----------------------------
82
- def caption_image(img: Image.Image) -> str:
83
- """
84
- Run the captioning pipeline on a PIL image and return the caption text.
85
- """
86
- if _load_error:
87
- # If loading failed earlier, return the error for the UI
88
- return f"Error loading model: {_load_error}"
89
-
90
- if caption_pipe is None:
91
- # Try loading lazily if not loaded yet
92
- load_caption_pipeline()
93
- if caption_pipe is None:
94
- return f"Model not loaded. Try again in a moment."
95
-
96
- try:
97
- outputs = caption_pipe(img)
98
- # pipeline usually returns a list of dicts with 'generated_text'
99
- if isinstance(outputs, list) and outputs:
100
- caption = outputs[0].get("generated_text") or outputs[0].get("caption") or str(outputs[0])
101
- else:
102
- caption = str(outputs)
103
- return caption.strip()
104
- except Exception as e:
105
- logger.exception("Captioning error: %s", e)
106
- return f"Captioning failed: {e}"
107
-
108
-
109
- # ----------------------------
110
- # Gradio UI
111
- # ----------------------------
112
- title = "Image Captioning"
113
- description = (
114
- "Upload an image and the model will generate a short descriptive caption. "
115
- "Model: <b>{}</b>. ".format(MODEL_ID)
116
- )
117
-
118
- examples = [
119
- # If you want, place example image paths here (local files in repo), or leave empty.
120
- # ["examples/cat.jpg"],
121
- ]
122
-
123
- with gr.Blocks(title=title) as demo:
124
- gr.Markdown(f"# {title}")
125
- gr.Markdown(description)
126
-
127
- # Status row
128
- with gr.Row():
129
- model_info = gr.Textbox(label="Model", value=MODEL_ID, interactive=False)
130
- device_info = gr.Textbox(label="Device", value=("cuda" if torch.cuda.is_available() else "cpu"), interactive=False)
131
- status_info = gr.Textbox(label="Model status", value=("loaded" if caption_pipe is not None and not _load_error else f"error: {_load_error}" if _load_error else "loading"), interactive=False)
132
-
133
- gr.Markdown("## Upload image")
134
- with gr.Row():
135
- image_input = gr.Image(type="pil", label="Image", tool="editor")
136
- with gr.Column():
137
- run_btn = gr.Button("Generate Caption")
138
- clear_btn = gr.Button("Clear")
139
- gr.Markdown("**Tips:** use clear photos; try different crops in the editor for better captions.")
140
-
141
- output = gr.Textbox(label="Caption", interactive=False)
142
-
143
- # Example images (optional)
144
- if examples:
145
- gr.Examples(examples=examples, inputs=image_input, label="Examples")
146
-
147
- # Actions
148
- run_btn.click(fn=caption_image, inputs=image_input, outputs=output)
149
- clear_btn.click(fn=lambda: (None, ""), inputs=None, outputs=[image_input, output])
150
-
151
- gr.Markdown("---")
152
- gr.Markdown("**Notes**: If the model is private, set `HUGGINGFACE_HUB_TOKEN` environment variable. "
153
- "For large models you may need GPU and more memory.")
154
-
155
- # ----------------------------
156
- # Launch
157
- # ----------------------------
158
- if __name__ == "__main__":
159
- # Respect PORT env var (used by Hugging Face Spaces)
160
- port = int(os.environ.get("PORT", 7860))
161
- demo.launch(server_name="0.0.0.0", server_port=port, share=False)
 
1
+ # app.py
2
+ """
3
+ Image Captioning demo with Gradio + Hugging Face transformers.
4
+
5
+ Environment variables:
6
+ MODEL_ID - huggingface model id (default: Salesforce/blip-image-captioning-base)
7
+ TRUST_REMOTE_CODE - "true"/"false" to allow custom repo code (default: false)
8
+ HUGGINGFACE_HUB_TOKEN - optional, if your model is private
9
+ """
10
+ import os
11
+ import logging
12
+ from typing import Optional
13
+
14
+ import torch
15
+ from PIL import Image
16
+ from transformers import pipeline
17
+ import gradio as gr
18
+
19
+ # ----------------------------
20
+ # Configuration & logging
21
+ # ----------------------------
22
+ MODEL_ID = os.environ.get("MODEL_ID", "Salesforce/blip-image-captioning-base")
23
+ TRUST_REMOTE_CODE = os.environ.get("TRUST_REMOTE_CODE", "false").lower() in ("1", "true", "yes")
24
+ HUB_TOKEN = os.environ.get("HUGGINGFACE_HUB_TOKEN") # optional (for private models)
25
+
26
+ logging.basicConfig(level=logging.INFO)
27
+ logger = logging.getLogger("image-caption-gradio")
28
+
29
+ logger.info("Gradio version: %s", gr.__version__)
30
+
31
+ # ----------------------------
32
+ # Device helper
33
+ # ----------------------------
34
+ def get_pipeline_device() -> int:
35
+ """Return device index for transformers pipeline: 0 (GPU) or -1 (CPU)."""
36
+ return 0 if torch.cuda.is_available() else -1
37
+
38
+
39
+ # ----------------------------
40
+ # Load pipeline (global)
41
+ # ----------------------------
42
+ caption_pipe = None
43
+ _load_error: Optional[str] = None
44
+
45
+ def load_caption_pipeline():
46
+ """Load the image-to-text pipeline once (global)."""
47
+ global caption_pipe, _load_error
48
+ if caption_pipe is not None or _load_error:
49
+ return
50
+
51
+ device = get_pipeline_device()
52
+ logger.info("Loading model '%s' (trust_remote_code=%s) on device %s", MODEL_ID, TRUST_REMOTE_CODE, device)
53
+ try:
54
+ caption_pipe = pipeline(
55
+ "image-to-text",
56
+ model=MODEL_ID,
57
+ device=device,
58
+ trust_remote_code=TRUST_REMOTE_CODE,
59
+ )
60
+ logger.info("Model loaded successfully.")
61
+ except Exception as e:
62
+ _load_error = str(e)
63
+ logger.exception("Failed to load model: %s", e)
64
+
65
+
66
+ # Preload model at startup (best-effort)
67
+ load_caption_pipeline()
68
+
69
+
70
+ # ----------------------------
71
+ # Inference function used by Gradio
72
+ # ----------------------------
73
+ def caption_image(img: Image.Image) -> str:
74
+ """Run the captioning pipeline on a PIL image and return the caption text."""
75
+ if _load_error:
76
+ return f"Error loading model: {_load_error}"
77
+ if caption_pipe is None:
78
+ load_caption_pipeline()
79
+ if caption_pipe is None:
80
+ return "Model not loaded yet. Please try again in a moment."
81
+
82
+ try:
83
+ outputs = caption_pipe(img)
84
+ if isinstance(outputs, list) and outputs:
85
+ caption = outputs[0].get("generated_text") or outputs[0].get("caption") or str(outputs[0])
86
+ else:
87
+ caption = str(outputs)
88
+ return caption.strip()
89
+ except Exception as e:
90
+ logger.exception("Captioning error: %s", e)
91
+ return f"Captioning failed: {e}"
92
+
93
+
94
+ # ----------------------------
95
+ # Gradio UI
96
+ # ----------------------------
97
+ title = "Image Captioning"
98
+ description = (
99
+ "Upload an image and the model will generate a short descriptive caption.\n"
100
+ f"Model: {MODEL_ID}"
101
+ )
102
+
103
+ with gr.Blocks(title=title) as demo:
104
+ gr.Markdown(f"# {title}")
105
+ gr.Markdown(description)
106
+
107
+ with gr.Row():
108
+ model_info = gr.Textbox(label="Model", value=MODEL_ID, interactive=False)
109
+ device_info = gr.Textbox(label="Device", value=("cuda" if torch.cuda.is_available() else "cpu"), interactive=False)
110
+ status_info = gr.Textbox(label="Model status", value=("loaded" if caption_pipe is not None and not _load_error else f"error: {_load_error}" if _load_error else "loading"), interactive=False)
111
+
112
+ gr.Markdown("## Upload image")
113
+ with gr.Row():
114
+ # NOTE: removed 'tool' kw to support more Gradio versions
115
+ image_input = gr.Image(type="pil", label="Image")
116
+ with gr.Column():
117
+ run_btn = gr.Button("Generate Caption")
118
+ clear_btn = gr.Button("Clear")
119
+
120
+ output = gr.Textbox(label="Caption", interactive=False)
121
+
122
+ run_btn.click(fn=caption_image, inputs=image_input, outputs=output)
123
+
124
+ # Clear button: the returned (None, "") pair maps to [image_input, output] — None for the image input, "" for the caption box.
125
+ def _clear():
126
+ return None, ""
127
+ clear_btn.click(fn=_clear, inputs=None, outputs=[image_input, output])
128
+
129
+ gr.Markdown("---")
130
+ gr.Markdown("Notes: If the model is private, set HUGGINGFACE_HUB_TOKEN. Large models need more memory/GPU.")
131
+
132
+ # ----------------------------
133
+ # Launch
134
+ # ----------------------------
135
+ if __name__ == "__main__":
136
+ port = int(os.environ.get("PORT", 7860))
137
+ demo.launch(server_name="0.0.0.0", server_port=port, share=False)