moslem committed on
Commit
3d4fa97
·
verified ·
1 Parent(s): 5591cd5

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +161 -0
  2. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ """
3
+ Image Captioning demo with Gradio + Hugging Face transformers.
4
+
5
+ Environment variables:
6
+ MODEL_ID - huggingface model id (default: Salesforce/blip-image-captioning-base)
7
+ TRUST_REMOTE_CODE - "true"/"false" to allow custom repo code (default: false)
8
+ HUGGINGFACE_HUB_TOKEN - optional, if your model is private
9
+
10
+ Run:
11
+ python app.py
12
+ """
13
+ import os
14
+ import logging
15
+ from typing import Optional
16
+
17
+ import torch
18
+ from PIL import Image
19
+ from transformers import pipeline
20
+ import gradio as gr
21
+
22
+ # ----------------------------
23
+ # Configuration & logging
24
+ # ----------------------------
25
# Hugging Face model id to load; overridable through the MODEL_ID env var.
MODEL_ID = os.getenv("MODEL_ID", "Salesforce/blip-image-captioning-base")

# Opt-in switch for executing custom code shipped inside the model repo.
# Only "1", "true" or "yes" (case-insensitive) enable it.
_trust_flag = os.getenv("TRUST_REMOTE_CODE", "false").lower()
TRUST_REMOTE_CODE = _trust_flag in ("1", "true", "yes")

# Optional access token (for private models); None when the variable is unset.
HUB_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("image-caption-gradio")
31
+
32
+ # ----------------------------
33
+ # Device helper
34
+ # ----------------------------
35
def get_pipeline_device() -> int:
    """
    Pick the device index to hand to the transformers pipeline.

    Returns 0 (first CUDA GPU) when torch reports CUDA as available,
    otherwise -1 (CPU).
    """
    if torch.cuda.is_available():
        return 0
    return -1
41
+
42
+
43
+ # ----------------------------
44
+ # Load pipeline (global)
45
+ # ----------------------------
46
# Module-level cache: the loaded pipeline, or the message of the load failure.
caption_pipe = None
_load_error: Optional[str] = None


def load_caption_pipeline():
    """
    Load the image-to-text pipeline once and cache it in ``caption_pipe``.

    The call is a no-op when the pipeline is already loaded or when a previous
    attempt failed (the failure message is kept in ``_load_error``).  Any
    exception raised while loading is logged and recorded, never propagated.

    When ``HUB_TOKEN`` is set it is passed explicitly to the pipeline so that
    private models can be downloaded.
    """
    global caption_pipe, _load_error
    if caption_pipe is not None or _load_error:
        return

    device = get_pipeline_device()
    logger.info("Loading model '%s' (trust_remote_code=%s) on device %s", MODEL_ID, TRUST_REMOTE_CODE, device)

    pipeline_kwargs = {
        "model": MODEL_ID,
        "device": device,
        "trust_remote_code": TRUST_REMOTE_CODE,
    }
    if HUB_TOKEN:
        # The token must be passed explicitly: huggingface_hub only reads
        # HF_TOKEN / HUGGING_FACE_HUB_TOKEN from the environment, not the
        # HUGGINGFACE_HUB_TOKEN name this app uses.
        # NOTE(review): the `token` keyword needs a transformers release that
        # supports it (older releases used `use_auth_token`) - confirm the
        # minimum version pinned in requirements.txt.
        pipeline_kwargs["token"] = HUB_TOKEN

    try:
        caption_pipe = pipeline("image-to-text", **pipeline_kwargs)
        logger.info("Model loaded successfully.")
    except Exception as e:
        # Fall back to the exception type so the stored error is never an
        # empty (falsy) string, which would make the failure look like
        # "not loaded yet".
        _load_error = str(e) or type(e).__name__
        logger.exception("Failed to load model: %s", e)


# Preload model at startup (best-effort)
load_caption_pipeline()
77
+
78
+
79
+ # ----------------------------
80
+ # Inference function used by Gradio
81
+ # ----------------------------
82
def caption_image(img: "Optional[Image.Image]") -> str:
    """
    Run the captioning pipeline on a PIL image and return the caption text.

    Args:
        img: The uploaded image, or None when the user has not provided one.

    Returns:
        The stripped caption, or a human-readable message when there is no
        image, the model could not be loaded, or inference failed.  This
        function never raises; errors are logged and reported as text.
    """
    if img is None:
        # The button can be clicked before an image is uploaded.
        return "Please upload an image first."

    if _load_error:
        # If loading failed earlier, return the error for the UI
        return f"Error loading model: {_load_error}"

    if caption_pipe is None:
        # Try loading lazily if not loaded yet
        load_caption_pipeline()
        if _load_error:
            return f"Error loading model: {_load_error}"
        if caption_pipe is None:
            return "Model not loaded. Try again in a moment."

    try:
        outputs = caption_pipe(img)
        # pipeline usually returns a list of dicts with 'generated_text'
        first = outputs[0] if isinstance(outputs, list) and outputs else outputs
        if isinstance(first, dict):
            caption = first.get("generated_text") or first.get("caption") or str(first)
        else:
            caption = str(first)
        return caption.strip()
    except Exception as e:
        logger.exception("Captioning error: %s", e)
        return f"Captioning failed: {e}"
107
+
108
+
109
+ # ----------------------------
110
+ # Gradio UI
111
+ # ----------------------------
112
title = "Image Captioning"
description = (
    "Upload an image and the model will generate a short descriptive caption. "
    "Model: <b>{}</b>. ".format(MODEL_ID)
)

examples = [
    # If you want, place example image paths here (local files in repo), or leave empty.
    # ["examples/cat.jpg"],
]

# Status shown in the UI; computed once, when the interface is built.
if _load_error:
    _status_text = f"error: {_load_error}"
elif caption_pipe is not None:
    _status_text = "loaded"
else:
    _status_text = "loading"

# The `tool` argument of gr.Image exists only in Gradio 3.x; passing it to
# Gradio 4+ raises TypeError, so request the editor only where it is supported.
_image_kwargs = {"type": "pil", "label": "Image"}
if str(getattr(gr, "__version__", "")).split(".")[0] == "3":
    _image_kwargs["tool"] = "editor"

with gr.Blocks(title=title) as demo:
    gr.Markdown(f"# {title}")
    gr.Markdown(description)

    # Status row
    with gr.Row():
        model_info = gr.Textbox(label="Model", value=MODEL_ID, interactive=False)
        device_info = gr.Textbox(
            label="Device",
            value=("cuda" if torch.cuda.is_available() else "cpu"),
            interactive=False,
        )
        status_info = gr.Textbox(label="Model status", value=_status_text, interactive=False)

    gr.Markdown("## Upload image")
    with gr.Row():
        image_input = gr.Image(**_image_kwargs)
        with gr.Column():
            run_btn = gr.Button("Generate Caption")
            clear_btn = gr.Button("Clear")
            gr.Markdown("**Tips:** use clear photos; try different crops in the editor for better captions.")

    output = gr.Textbox(label="Caption", interactive=False)

    # Example images (optional)
    if examples:
        gr.Examples(examples=examples, inputs=image_input, label="Examples")

    # Actions
    run_btn.click(fn=caption_image, inputs=image_input, outputs=output)
    # Reset both the image and the caption box.
    clear_btn.click(fn=lambda: (None, ""), inputs=None, outputs=[image_input, output])

    gr.Markdown("---")
    gr.Markdown("**Notes**: If the model is private, set `HUGGINGFACE_HUB_TOKEN` environment variable. "
                "For large models you may need GPU and more memory.")
154
+
155
+ # ----------------------------
156
+ # Launch
157
+ # ----------------------------
158
if __name__ == "__main__":
    # Hugging Face Spaces supplies the port through PORT; fall back to 7860.
    demo.launch(
        server_name="0.0.0.0",
        server_port=int(os.environ.get("PORT", 7860)),
        share=False,
    )
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ gradio>=3.40.0
2
+ transformers>=4.30.0
3
+ torch>=2.0.0
4
+ pillow>=9.0.0
5
+ # Optional extras (uncomment if needed by the model)
6
+ # accelerate>=0.20.3
7
+ # diffusers>=0.11.0
8
+ # safetensors>=0.3.0