LiKenun commited on
Commit
d56b9d9
·
1 Parent(s): dc382c8

Add image captioning sample

Browse files
Files changed (5) hide show
  1. app.py +31 -22
  2. image_classification.py +11 -32
  3. image_to_text.py +19 -0
  4. requirements.txt +5 -1
  5. utils.py +33 -1
app.py CHANGED
@@ -3,7 +3,9 @@ from functools import partial
3
  import gradio as gr
4
  from huggingface_hub import InferenceClient
5
  from image_classification import image_classification
 
6
  from text_to_image import text_to_image
 
7
 
8
 
9
  class App:
@@ -18,7 +20,7 @@ class App:
18
  with gr.Tabs():
19
  with gr.Tab("Text-to-image Generation"):
20
  gr.Markdown("Generate an image from a text prompt.")
21
- text_to_image_prompt = gr.Textbox(label="Prompt", value="A panda under a giant mushroom next to a pumpkin")
22
  text_to_image_generate_button = gr.Button("Generate")
23
  text_to_image_output = gr.Image(label="Image", type="pil")
24
  text_to_image_generate_button.click(
@@ -26,32 +28,39 @@ class App:
26
  inputs=text_to_image_prompt,
27
  outputs=text_to_image_output
28
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  with gr.Tab("Image Classification"):
30
  gr.Markdown("Classify a recyclable item as one of: cardboard, glass, metal, paper, plastic, or other using [Trash-Net](https://huggingface.co/prithivMLmods/Trash-Net).")
31
- with gr.Row():
32
- with gr.Column():
33
- image_classification_url_input = gr.Textbox(
34
- label="Image URL",
35
- value="https://campuslifeservices.ucsf.edu/upload/facilities/galleries/cardboard_0.jpg",
36
- placeholder="Enter the URL of the image to classify",
37
- scale=2
38
- )
39
- image_classification_image_preview = gr.Image(label="Image Preview", type="pil")
40
- image_classification_upload_input = gr.Image(
41
- label="Or Upload Image",
42
- type="pil",
43
- scale=2
44
- )
45
- image_classification_button = gr.Button("Classify")
46
- image_classification_output = gr.Dataframe(
47
- label="Classification Results",
48
- headers=["Label", "Probability"],
49
- interactive=False
50
  )
 
 
51
  image_classification_button.click(
52
  fn=partial(image_classification, self.client),
53
- inputs=[image_classification_url_input, image_classification_upload_input],
54
- outputs=[image_classification_image_preview, image_classification_output]
55
  )
56
 
57
  demo.launch()
 
3
  import gradio as gr
4
  from huggingface_hub import InferenceClient
5
  from image_classification import image_classification
6
+ from image_to_text import image_to_text
7
  from text_to_image import text_to_image
8
+ from utils import request_image
9
 
10
 
11
  class App:
 
20
  with gr.Tabs():
21
  with gr.Tab("Text-to-image Generation"):
22
  gr.Markdown("Generate an image from a text prompt.")
23
+ text_to_image_prompt = gr.Textbox(label="Prompt")
24
  text_to_image_generate_button = gr.Button("Generate")
25
  text_to_image_output = gr.Image(label="Image", type="pil")
26
  text_to_image_generate_button.click(
 
28
  inputs=text_to_image_prompt,
29
  outputs=text_to_image_output
30
  )
31
+ with gr.Tab("Image-to-text or Image Captioning"):
32
+ gr.Markdown("Generate a text description of an image.")
33
+ image_to_text_url_input = gr.Textbox(label="Image URL")
34
+ image_to_text_image_request_button = gr.Button("Get Image")
35
+ image_to_text_image_input = gr.Image(label="Image", type="pil")
36
+ image_to_text_image_request_button.click(
37
+ fn=request_image,
38
+ inputs=image_to_text_url_input,
39
+ outputs=image_to_text_image_input
40
+ )
41
+ image_to_text_output = gr.List(label="Captions", headers=["Caption"])
42
+ image_to_text_button = gr.Button("Caption")
43
+ image_to_text_button.click(
44
+ fn=image_to_text,
45
+ inputs=image_to_text_image_input,
46
+ outputs=image_to_text_output
47
+ )
48
  with gr.Tab("Image Classification"):
49
  gr.Markdown("Classify a recyclable item as one of: cardboard, glass, metal, paper, plastic, or other using [Trash-Net](https://huggingface.co/prithivMLmods/Trash-Net).")
50
+ image_classification_url_input = gr.Textbox(label="Image URL")
51
+ image_classification_image_request_button = gr.Button("Get Image")
52
+ image_classification_image_input = gr.Image(label="Image",type="pil")
53
+ image_classification_image_request_button.click(
54
+ fn=request_image,
55
+ inputs=image_classification_url_input,
56
+ outputs=image_classification_image_input
 
 
 
 
 
 
 
 
 
 
 
 
57
  )
58
+ image_classification_button = gr.Button("Classify")
59
+ image_classification_output = gr.Dataframe(label="Classification", headers=["Label", "Probability"], interactive=False)
60
  image_classification_button.click(
61
  fn=partial(image_classification, self.client),
62
+ inputs=image_classification_image_input,
63
+ outputs=image_classification_output
64
  )
65
 
66
  demo.launch()
image_classification.py CHANGED
@@ -1,44 +1,23 @@
1
- import gradio as gr
2
  from huggingface_hub import InferenceClient
3
- from io import BytesIO
4
  from os import path, unlink, getenv
5
- from PIL.Image import Image, open as open_image
6
  import pandas as pd
7
  from pandas import DataFrame
8
- import requests
9
  from utils import save_image_to_temp_file
10
 
11
 
12
- def image_classification(client: InferenceClient, image_url: str | None, image: Image | None) -> tuple[Image | None, DataFrame]:
13
- temp_file_path = None
14
  try:
15
- if image is not None and image_url and image_url.strip():
16
- raise gr.Error("Both an image URL and an uploaded image were provided. Please provide only one or the other.")
17
- elif image is not None:
18
- temp_file_path = save_image_to_temp_file(image)
19
- classifications = client.image_classification(temp_file_path, model=getenv("IMAGE_CLASSIFICATION_MODEL"))
20
- image = None
21
- elif image_url and image_url.strip():
22
- try:
23
- response = requests.get(image_url, timeout=int(getenv("REQUEST_TIMEOUT")))
24
- response.raise_for_status()
25
- image = open_image(BytesIO(response.content))
26
- temp_file_path = save_image_to_temp_file(image)
27
- classifications = client.image_classification(temp_file_path, model=getenv("IMAGE_CLASSIFICATION_MODEL"))
28
- except Exception as e:
29
- raise gr.Error(f"Failed to fetch image from URL: {str(e)}")
30
- else:
31
- raise gr.Error("Please either provide an image URL or upload an image.")
32
- df = pd.DataFrame({
33
- "Label": classification.label,
34
- "Probability": f"{classification.score:.2%}"
35
- }
36
- for classification
37
- in classifications)
38
- return image, df
39
  finally:
40
- # Clean up temporary file.
41
- if temp_file_path and path.exists(temp_file_path):
42
  try:
43
  unlink(temp_file_path)
44
  except Exception:
 
 
1
  from huggingface_hub import InferenceClient
 
2
  from os import path, unlink, getenv
3
+ from PIL.Image import Image
4
  import pandas as pd
5
  from pandas import DataFrame
 
6
  from utils import save_image_to_temp_file
7
 
8
 
9
+ def image_classification(client: InferenceClient, image: Image) -> DataFrame:
 
10
  try:
11
+ temp_file_path = save_image_to_temp_file(image) # Needed because InferenceClient does not accept PIL Images directly.
12
+ classifications = client.image_classification(temp_file_path, model=getenv("IMAGE_CLASSIFICATION_MODEL"))
13
+ return pd.DataFrame({
14
+ "Label": classification.label,
15
+ "Probability": f"{classification.score:.2%}"
16
+ }
17
+ for classification
18
+ in classifications)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  finally:
20
+ if temp_file_path and path.exists(temp_file_path): # Clean up temporary file.
 
21
  try:
22
  unlink(temp_file_path)
23
  except Exception:
image_to_text.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ from os import getenv
3
+ from PIL.Image import Image
4
+ from transformers import AutoProcessor, BlipForConditionalGeneration
5
+ from utils import get_pytorch_device, spaces_gpu
6
+
7
+
8
+ @spaces_gpu
9
+ def image_to_text(image: Image) -> list[str]:
10
+ image_to_text_model_id = getenv("IMAGE_TO_TEXT_MODEL")
11
+ pytorch_device = get_pytorch_device()
12
+ processor = AutoProcessor.from_pretrained(image_to_text_model_id)
13
+ model = BlipForConditionalGeneration.from_pretrained(image_to_text_model_id).to(pytorch_device)
14
+ inputs = processor(images=image, return_tensors="pt").to(pytorch_device)
15
+ generated_ids = model.generate(pixel_values=inputs.pixel_values, num_beams=3, max_length=20, min_length=5)
16
+ results = processor.batch_decode(generated_ids, skip_special_tokens=True)
17
+ del model, inputs
18
+ gc.collect()
19
+ return results
requirements.txt CHANGED
@@ -1,6 +1,10 @@
1
  gradio>=5.49.1
2
- huggingface-hub>=1.0.1
3
  python-dotenv>=1.0.0
4
  pandas>=2.0.0
5
  pillow>=10.0.0
6
  requests>=2.31.0
 
 
 
 
 
1
  gradio>=5.49.1
2
+ huggingface-hub>=0.34.0,<1.0
3
  python-dotenv>=1.0.0
4
  pandas>=2.0.0
5
  pillow>=10.0.0
6
  requests>=2.31.0
7
+ transformers>=4.40.0
8
+ torch>=2.0.0
9
+ torchvision>=0.15.0
10
+ torchaudio>=2.0.0
utils.py CHANGED
@@ -1,7 +1,39 @@
1
- from PIL.Image import Image
 
 
 
 
2
  from tempfile import NamedTemporaryFile
 
3
 
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  def save_image_to_temp_file(image: Image) -> str:
6
  image_format = image.format if image.format else 'PNG'
7
  format_extension = image_format.lower() if image_format else 'png'
 
1
+ import gradio as gr
2
+ from io import BytesIO
3
+ from PIL.Image import Image, open as open_image
4
+ from os import getenv
5
+ import requests
6
  from tempfile import NamedTemporaryFile
7
+ import torch
8
 
9
 
10
+ # Try to import spaces decorator (for Hugging Face Spaces), otherwise use no-op decorator.
11
+ try:
12
+ import spaces
13
+ spaces_gpu = spaces.GPU
14
+ except ImportError:
15
+ # For local development, use a no-op decorator because spaces is not available.
16
+ def spaces_gpu(func):
17
+ return func
18
+
19
+ def get_pytorch_device() -> str:
20
+ return ("cuda" if torch.cuda.is_available() # Nvidia CUDA and AMD ROCm
21
+ else "xpu" if torch.xpu.is_available() # Intel XPU
22
+ else "mps" if torch.mps.is_available() # Apple Silicon
23
+ else "cpu") # gl bro 🫠
24
+
25
+ def request_image(url: str) -> Image:
26
+ try:
27
+ response = requests.get(url, timeout=int(getenv("REQUEST_TIMEOUT")))
28
+ response.raise_for_status()
29
+ return open_image(BytesIO(response.content))
30
+ except requests.HTTPError as e:
31
+ raise gr.Error(f"Failed to fetch image from URL because of HTTP error: {e.response.status_code} {e.response.text}")
32
+ except requests.Timeout as e:
33
+ raise gr.Error(f"Failed to fetch image from URL because the request timed out.")
34
+ except requests.RequestException as e:
35
+ raise gr.Error(f"Failed to fetch image from URL: {str(e)}")
36
+
37
  def save_image_to_temp_file(image: Image) -> str:
38
  image_format = image.format if image.format else 'PNG'
39
  format_extension = image_format.lower() if image_format else 'png'