dlaima committed on
Commit
705859f
·
verified ·
1 Parent(s): ee8a4e6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -4
app.py CHANGED
@@ -3,6 +3,7 @@ import os
3
  import io
4
  from PIL import Image
5
  import requests
 
6
  import warnings
7
  import gradio as gr
8
 
@@ -24,23 +25,25 @@ if not endpoint_url:
24
  def generate_caption(image):
25
  """
26
  Sends an image to the Hugging Face Inference Endpoint for caption generation.
 
27
  :param image: An image in PIL format.
28
  :return: Generated caption or error message.
29
  """
30
  try:
31
  headers = {"Authorization": f"Bearer {hf_api_key}"}
32
 
33
- # Convert the PIL image to a binary stream in JPEG format
34
  buffered = io.BytesIO()
35
  image = image.convert("RGB") # Ensure the image is in RGB mode
36
  image.save(buffered, format="JPEG")
37
  buffered.seek(0)
 
38
 
39
- # Create the appropriate payload for the API
40
- files = {"file": ("image.jpg", buffered, "image/jpeg")}
41
 
42
  # Make the POST request to the endpoint
43
- response = requests.post(endpoint_url, headers=headers, files=files)
44
 
45
  if response.status_code == 200:
46
  return response.json().get("generated_text", "No caption generated.")
@@ -88,3 +91,4 @@ if __name__ == "__main__":
88
  # Launch the Gradio demo
89
  demo.launch()
90
 
 
 
3
  import io
4
  from PIL import Image
5
  import requests
6
+ import base64
7
  import warnings
8
  import gradio as gr
9
 
 
25
  def generate_caption(image):
26
  """
27
  Sends an image to the Hugging Face Inference Endpoint for caption generation.
28
+ Uses base64 encoding for compatibility.
29
  :param image: An image in PIL format.
30
  :return: Generated caption or error message.
31
  """
32
  try:
33
  headers = {"Authorization": f"Bearer {hf_api_key}"}
34
 
35
+ # Convert the image to RGB and encode it in base64
36
  buffered = io.BytesIO()
37
  image = image.convert("RGB") # Ensure the image is in RGB mode
38
  image.save(buffered, format="JPEG")
39
  buffered.seek(0)
40
+ image_base64 = base64.b64encode(buffered.read()).decode("utf-8")
41
 
42
+ # Prepare the JSON payload
43
+ payload = {"inputs": image_base64}
44
 
45
  # Make the POST request to the endpoint
46
+ response = requests.post(endpoint_url, headers=headers, json=payload)
47
 
48
  if response.status_code == 200:
49
  return response.json().get("generated_text", "No caption generated.")
 
91
  # Launch the Gradio demo
92
  demo.launch()
93
 
94
+