"""AutoCaptioner -- Gradio app that captions images via a Hugging Face Inference Endpoint."""
from dotenv import load_dotenv, find_dotenv
import os
import io
from PIL import Image
import requests
import warnings
import gradio as gr
# Silence the transformers warning about the model-agnostic default `max_length`;
# it is emitted by the remote endpoint's client stack and is not actionable here.
warnings.filterwarnings("ignore", message=".*Using the model-agnostic default `max_length`.*")

# Load environment variables from the nearest .env file (find_dotenv walks up
# the directory tree), then read the two required settings.
load_dotenv(find_dotenv())
hf_api_key = os.getenv('HF_API_KEY')        # bearer token for the HF Inference Endpoint
endpoint_url = os.getenv('HF_API_ITT_BASE')  # image-to-text endpoint base URL

# Fail fast at import time: the app cannot function without credentials/endpoint.
if not hf_api_key:
    raise ValueError("HF_API_KEY is not set in the .env file.")
if not endpoint_url:
    raise ValueError("HF_API_ITT_BASE is not set in the .env file.")
def generate_caption(image):
    """
    Send an image to the Hugging Face Inference Endpoint for caption generation.

    :param image: An image in PIL format (any mode; converted to RGB for JPEG).
    :return: The generated caption string, or a human-readable error message.
    """
    try:
        headers = {"Authorization": f"Bearer {hf_api_key}"}

        # Serialize the PIL image to an in-memory JPEG stream.
        buffered = io.BytesIO()
        image = image.convert("RGB")  # JPEG cannot encode RGBA/P-mode images
        image.save(buffered, format="JPEG")
        buffered.seek(0)

        files = {"file": ("image.jpg", buffered, "image/jpeg")}

        # Bounded timeout so a stalled endpoint cannot hang the UI forever.
        response = requests.post(endpoint_url, headers=headers, files=files, timeout=60)

        if response.status_code == 200:
            payload = response.json()
            # HF inference endpoints may return either a dict or a list of
            # dicts (e.g. [{"generated_text": ...}]); accept both shapes.
            if isinstance(payload, list) and payload:
                payload = payload[0]
            if isinstance(payload, dict):
                return payload.get("generated_text", "No caption generated.")
            return "No caption generated."

        # Report the failure, but never echo `headers` back to the user:
        # it contains the bearer token (credential leak in the original code).
        return f"Error: {response.status_code} - {response.text}\nEndpoint: {endpoint_url}"
    except Exception as e:  # broad by design: surface any failure in the UI
        return f"An error occurred: {str(e)}"
# Predefined sample images
def get_sample_images():
    """
    Return a sorted list of sample image paths from the samples directory.

    Scans the hard-coded ``CreatureCaptures`` directory for PNG/JPEG files.
    Results are sorted so the Gradio examples appear in a stable order
    (``os.listdir`` order is filesystem-dependent).

    :return: List of relative file paths; empty list if the directory is missing.
    """
    sample_dir = "CreatureCaptures"  # Ensure this directory exists and contains sample images
    try:
        entries = os.listdir(sample_dir)
    except FileNotFoundError:
        # No samples shipped with this deployment -- the UI simply shows none.
        return []
    return sorted(
        os.path.join(sample_dir, name)
        for name in entries
        if name.lower().endswith((".png", ".jpg", ".jpeg"))
    )
# ---- Gradio UI wiring ----
sample_images = get_sample_images()  # predefined example images (may be empty)

_description = (
    "Upload an image or use one of the predefined samples to generate a caption. "
    "This app uses a Hugging Face Inference Endpoint for the `Salesforce/blip-image-captioning-base` model."
)

demo = gr.Interface(
    fn=generate_caption,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.Textbox(label="Generated Caption"),
    examples=sample_images,
    title="Image Captioning App",
    description=_description,
)

# Launch only when run as a script (HF Spaces imports `demo` directly).
if __name__ == "__main__":
    demo.launch()