dlaima commited on
Commit
d907d8c
·
verified ·
1 Parent(s): 8332de3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -40
app.py CHANGED
@@ -1,72 +1,70 @@
1
  from dotenv import load_dotenv, find_dotenv
2
- load_dotenv(find_dotenv())
3
-
4
- import os # Provides a way of using operating system-dependent functionality
5
- import io # Provides core tools for working with streams of data
6
  from io import BytesIO
7
- import IPython.display # Used for displaying rich content (e.g., images, HTML) in Jupyter Notebooks
8
- from PIL import Image # Python Imaging Library for opening, manipulating, and saving image files
9
- import base64 # Encodes and decodes data in base64 format
10
  import requests
11
  import json
12
- import torch
13
- import torch.nn as nn
14
  import warnings
15
  import gradio as gr
16
 
17
- # Ignore specific UserWarnings related to max_length in transformers
18
  warnings.filterwarnings("ignore", message=".*Using the model-agnostic default `max_length`.*")
19
 
20
  # Load environment variables from .env file
 
21
  hf_api_key = os.getenv('API_TOKEN')
22
  endpoint_url = os.getenv('INFERENCE_ENDPOINT')
23
 
 
24
 
25
- # Set your Inference Endpoint URL and API key
26
- #INFERENCE_ENDPOINT = "https://your-endpoint-url" # Replace with your endpoint URL
27
- #API_TOKEN = "your-api-token" # Replace with your Hugging Face API token
28
-
29
- #Image-to-text endpoint - Helper function
30
- def get_completion(inputs, parameters=None, endpoint_url=endpoint_url):
31
  headers = {
32
  "Authorization": f"Bearer {hf_api_key}",
33
  "Content-Type": "application/json"
34
  }
35
- data = {"inputs": inputs}
 
 
 
 
 
36
  if parameters is not None:
37
  data.update({"parameters": parameters})
 
38
  response = requests.post(endpoint_url, headers=headers, data=json.dumps(data))
39
- return json.loads(response.content.decode("utf-8"))
40
 
41
- def get_generation(model, processor, image, dtype):
42
- inputs = processor(image, return_tensors="pt").to(dtype)
43
- out = model.generate(**inputs)
44
- return processor.decode(out[0], skip_special_tokens=True)
45
 
46
- def load_image(img_url):
47
- image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')
48
- return image
49
-
50
- #Gradio interface
51
  def caption_image(image_url):
52
- # Download the image from the URL
53
- response = requests.get(image_url)
54
- response.raise_for_status() # Ensure the request was successful
55
- image = Image.open(BytesIO(response.content)) # Load image with PIL
56
-
57
- # Call your captioning function here (replace `get_completion` with the actual implementation)
58
- #caption = get_completion(image)
59
- caption = get_completion(image_url)
60
- return caption
61
 
62
- # Gradio interface
 
 
 
 
63
 
 
 
 
 
64
  demo = gr.Interface(
65
  fn=caption_image,
66
- inputs=gr.Textbox(label="Image URL"), # Input as a URL
67
  outputs="text",
68
- #examples=[Image1, Image2, Image3],
69
- #examples=[image],
70
  title="Image Captioning App",
71
  description=(
72
  "Upload an image or use one of the predefined samples to generate a caption. "
@@ -76,3 +74,4 @@ demo = gr.Interface(
76
 
77
  if __name__ == "__main__":
78
  demo.launch()
 
 
1
  from dotenv import load_dotenv, find_dotenv
2
+ import os
3
+ import io
 
 
4
  from io import BytesIO
5
+ from PIL import Image
6
+ import base64
 
7
  import requests
8
  import json
 
 
9
  import warnings
10
  import gradio as gr
11
 
12
+ # Suppress specific warnings
13
  warnings.filterwarnings("ignore", message=".*Using the model-agnostic default `max_length`.*")
14
 
15
  # Load environment variables from .env file
16
+ load_dotenv(find_dotenv())
17
  hf_api_key = os.getenv('API_TOKEN')
18
  endpoint_url = os.getenv('INFERENCE_ENDPOINT')
19
 
20
# Helper function for image-to-text API
def get_completion(image, parameters=None, endpoint_url=endpoint_url):
    """Send a PIL image to the inference endpoint and return the parsed JSON reply.

    Parameters
    ----------
    image : PIL.Image.Image
        Image to caption. Converted to RGB if needed so JPEG encoding succeeds.
    parameters : dict, optional
        Extra generation parameters forwarded under the ``"parameters"`` key.
    endpoint_url : str
        Inference endpoint URL (defaults to the module-level INFERENCE_ENDPOINT).

    Returns
    -------
    dict
        Parsed JSON response on success, or ``{"error": <text>}`` on HTTP failure.
    """
    headers = {
        "Authorization": f"Bearer {hf_api_key}",
        "Content-Type": "application/json"
    }

    # JPEG has no alpha channel: RGBA/P/LA images make .save() raise, so
    # normalize to RGB first (no-op for images that are already RGB).
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Convert image to base64 format so it can travel inside a JSON payload.
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

    data = {"inputs": {"image": image_base64}}
    if parameters is not None:
        data.update({"parameters": parameters})

    # Bounded timeout so a stalled endpoint cannot hang the Gradio worker forever.
    response = requests.post(
        endpoint_url, headers=headers, data=json.dumps(data), timeout=60
    )

    # Surface HTTP errors to the caller as data instead of raising.
    if response.status_code != 200:
        return {"error": response.text}

    return json.loads(response.content.decode("utf-8"))
43
+
44
# Helper function to download and process the image from a URL

def caption_image(image_url):
    """Download an image from *image_url* and return a generated caption string.

    Every failure (bad URL, HTTP error, undecodable image, API error) is
    reported as a human-readable error string rather than raised, because the
    return value is rendered directly in the Gradio text output.
    """
    try:
        # Bounded timeout so a dead host cannot hang the UI request.
        response = requests.get(image_url, timeout=30)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert("RGB")

        # Get caption from API
        caption_response = get_completion(image)

        # Hugging Face inference endpoints may return either a dict or a list
        # of dicts ([{"generated_text": ...}]); normalize to a dict so the
        # lookups below cannot raise AttributeError on a list.
        if isinstance(caption_response, list):
            caption_response = caption_response[0] if caption_response else {}

        # Handle API response
        if "error" in caption_response:
            return f"Error: {caption_response['error']}"

        return caption_response.get("generated_text", "No caption generated.")

    except Exception as e:
        # Broad catch is deliberate: every failure becomes a UI message.
        return f"Error processing image: {str(e)}"
62
+
63
+ # Gradio interface
64
  demo = gr.Interface(
65
  fn=caption_image,
66
+ inputs=gr.Textbox(label="Image URL"),
67
  outputs="text",
 
 
68
  title="Image Captioning App",
69
  description=(
70
  "Upload an image or use one of the predefined samples to generate a caption. "
 
74
 
75
  if __name__ == "__main__":
76
  demo.launch()
77
+