moslem commited on
Commit
0b8d5a4
·
verified ·
1 Parent(s): 2ce6e8c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -128
app.py CHANGED
@@ -1,137 +1,52 @@
1
- # app.py
2
- """
3
- Image Captioning demo with Gradio + Hugging Face transformers.
4
-
5
- Environment variables:
6
- MODEL_ID - huggingface model id (default: Salesforce/blip-image-captioning-base)
7
- TRUST_REMOTE_CODE - "true"/"false" to allow custom repo code (default: false)
8
- HUGGINGFACE_HUB_TOKEN - optional, if your model is private
9
- """
10
- import os
11
- import logging
12
- from typing import Optional
13
-
14
- import torch
15
- from PIL import Image
16
- from transformers import pipeline
17
  import gradio as gr
18
-
19
- # ----------------------------
20
- # Configuration & logging
21
- # ----------------------------
22
- MODEL_ID = os.environ.get("MODEL_ID", "Salesforce/blip-image-captioning-base")
23
- TRUST_REMOTE_CODE = os.environ.get("TRUST_REMOTE_CODE", "false").lower() in ("1", "true", "yes")
24
- HUB_TOKEN = os.environ.get("HUGGINGFACE_HUB_TOKEN") # optional (for private models)
25
-
26
- logging.basicConfig(level=logging.INFO)
27
- logger = logging.getLogger("image-caption-gradio")
28
-
29
- logger.info("Gradio version: %s", gr.__version__)
30
-
31
- # ----------------------------
32
- # Device helper
33
- # ----------------------------
34
- def get_pipeline_device() -> int:
35
- """Return device index for transformers pipeline: 0 (GPU) or -1 (CPU)."""
36
- return 0 if torch.cuda.is_available() else -1
37
-
38
-
39
- # ----------------------------
40
- # Load pipeline (global)
41
- # ----------------------------
42
- caption_pipe = None
43
- _load_error: Optional[str] = None
44
-
45
- def load_caption_pipeline():
46
- """Load the image-to-text pipeline once (global)."""
47
- global caption_pipe, _load_error
48
- if caption_pipe is not None or _load_error:
49
- return
50
-
51
- device = get_pipeline_device()
52
- logger.info("Loading model '%s' (trust_remote_code=%s) on device %s", MODEL_ID, TRUST_REMOTE_CODE, device)
53
- try:
54
- caption_pipe = pipeline(
55
- "image-to-text",
56
- model=MODEL_ID,
57
- device=device,
58
- trust_remote_code=TRUST_REMOTE_CODE,
59
- )
60
- logger.info("Model loaded successfully.")
61
- except Exception as e:
62
- _load_error = str(e)
63
- logger.exception("Failed to load model: %s", e)
64
-
65
-
66
- # Preload model at startup (best-effort)
67
- load_caption_pipeline()
68
-
69
-
70
- # ----------------------------
71
- # Inference function used by Gradio
72
- # ----------------------------
73
- def caption_image(img: Image.Image) -> str:
74
- """Run the captioning pipeline on a PIL image and return the caption text."""
75
  if _load_error:
76
- return f"Error loading model: {_load_error}"
77
- if caption_pipe is None:
78
- load_caption_pipeline()
79
- if caption_pipe is None:
80
- return "Model not loaded yet. Please try again in a moment."
81
-
82
- try:
83
- outputs = caption_pipe(img)
84
- if isinstance(outputs, list) and outputs:
85
- caption = outputs[0].get("generated_text") or outputs[0].get("caption") or str(outputs[0])
86
- else:
87
- caption = str(outputs)
88
- return caption.strip()
89
- except Exception as e:
90
- logger.exception("Captioning error: %s", e)
91
- return f"Captioning failed: {e}"
92
-
93
-
94
- # ----------------------------
95
- # Gradio UI
96
- # ----------------------------
97
- title = "Image Captioning"
98
- description = (
99
- "Upload an image and the model will generate a short descriptive caption.\n"
100
- f"Model: {MODEL_ID}"
101
  )
102
 
103
- with gr.Blocks(title=title) as demo:
104
- gr.Markdown(f"# {title}")
105
- gr.Markdown(description)
106
-
107
- with gr.Row():
108
- model_info = gr.Textbox(label="Model", value=MODEL_ID, interactive=False)
109
- device_info = gr.Textbox(label="Device", value=("cuda" if torch.cuda.is_available() else "cpu"), interactive=False)
110
- status_info = gr.Textbox(label="Model status", value=("loaded" if caption_pipe is not None and not _load_error else f\"error: {_load_error}\" if _load_error else "loading"), interactive=False)
111
 
112
- gr.Markdown("## Upload image")
113
  with gr.Row():
114
- # NOTE: removed 'tool' kw to support more Gradio versions
115
- image_input = gr.Image(type="pil", label="Image")
116
- with gr.Column():
117
- run_btn = gr.Button("Generate Caption")
118
- clear_btn = gr.Button("Clear")
119
-
120
- output = gr.Textbox(label="Caption", interactive=False)
121
-
122
- run_btn.click(fn=caption_image, inputs=image_input, outputs=output)
123
 
124
- # clear button: resets output and clears image (Gradio sometimes resets image by returning None)
125
- def _clear():
126
- return None, ""
127
- clear_btn.click(fn=_clear, inputs=None, outputs=[image_input, output])
128
 
129
- gr.Markdown("---")
130
- gr.Markdown("Notes: If the model is private, set HUGGINGFACE_HUB_TOKEN. Large models need more memory/GPU.")
131
 
132
- # ----------------------------
133
- # Launch
134
- # ----------------------------
135
- if __name__ == "__main__":
136
- port = int(os.environ.get("PORT", 7860))
137
- demo.launch(server_name="0.0.0.0", server_port=port, share=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import torch
3
+ from transformers import BlipProcessor, BlipForConditionalGeneration
4
+
5
+ MODEL_NAME = "Salesforce/blip-image-captioning-base"
6
+
7
+ # --- مدل را بارگیری کن
8
+ try:
9
+ device = "cuda" if torch.cuda.is_available() else "cpu"
10
+ caption_processor = BlipProcessor.from_pretrained(MODEL_NAME)
11
+ caption_model = BlipForConditionalGeneration.from_pretrained(MODEL_NAME)
12
+ caption_model.to(device)
13
+ _load_error = None
14
+ except Exception as e:
15
+ caption_processor = None
16
+ caption_model = None
17
+ _load_error = str(e)
18
+
19
+ # --- تابع captioning
20
+ def caption_image(image):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  if _load_error:
22
+ return f" Model load error: {_load_error}"
23
+ if image is None:
24
+ return "⚠️ لطفاً یک تصویر آپلود کنید."
25
+
26
+ inputs = caption_processor(image, return_tensors="pt").to(device)
27
+ out = caption_model.generate(**inputs, max_new_tokens=30)
28
+ caption = caption_processor.decode(out[0], skip_special_tokens=True)
29
+ return caption
30
+
31
+ # --- رابط کاربری Gradio
32
+ status_text = (
33
+ "✅ Model loaded successfully"
34
+ if caption_model is not None and not _load_error
35
+ else f"❌ Error: {_load_error}"
36
+ if _load_error
37
+ else " Loading model..."
 
 
 
 
 
 
 
 
 
38
  )
39
 
40
+ with gr.Blocks(title="Image Captioning App") as demo:
41
+ gr.Markdown("## 🖼️ Image Captioning with BLIP\nUpload an image and get an automatic caption.")
42
+ gr.Markdown(f"**Status:** {status_text}")
 
 
 
 
 
43
 
 
44
  with gr.Row():
45
+ image_input = gr.Image(type="pil", label="Upload Image")
46
+ caption_output = gr.Textbox(label="Generated Caption", interactive=False)
 
 
 
 
 
 
 
47
 
48
+ generate_btn = gr.Button("Generate Caption")
 
 
 
49
 
50
+ generate_btn.click(fn=caption_image, inputs=image_input, outputs=caption_output)
 
51
 
52
+ demo.launch()