import os                  # Read configuration from environment variables
import base64              # Base64-encode image bytes for the JSON payload
import warnings
from io import BytesIO     # In-memory binary streams for downloaded images

import requests            # HTTP calls to the Inference Endpoint
import torch               # Supplies the dtypes passed to get_generation below
import gradio as gr
from PIL import Image      # Pillow, for opening and validating image files
from dotenv import load_dotenv, find_dotenv

# Load environment variables from a local .env file, if present
load_dotenv(find_dotenv())
# Ignore specific UserWarnings related to max_length in transformers
warnings.filterwarnings("ignore", message=".*Using the model-agnostic default `max_length`.*")
# Read the API token and endpoint URL from the environment
# (see the example .env sketch below)
hf_api_key = os.getenv('API_TOKEN')
endpoint_url = os.getenv('INFERENCE_ENDPOINT')
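# A minimal example .env file this app could read; the values below are
# placeholders, not real credentials:
#
#   API_TOKEN=hf_xxxxxxxxxxxxxxxxxxxxxxxx
#   INFERENCE_ENDPOINT=https://your-endpoint.endpoints.huggingface.cloud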
# Image-to-text helper: POST the inputs to the Hugging Face Inference Endpoint
def get_completion(inputs, parameters=None, endpoint_url=endpoint_url):
    headers = {
        "Authorization": f"Bearer {hf_api_key}",
        "Content-Type": "application/json",
    }
    data = {"inputs": inputs}
    if parameters is not None:
        data["parameters"] = parameters
    response = requests.post(endpoint_url, headers=headers, json=data)
    response.raise_for_status()  # Surface HTTP errors instead of parsing them
    return response.json()
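# Hedged usage sketch for get_completion: Hugging Face image endpoints
# generally accept the image as a base64-encoded string in "inputs".
# The file name and the output shown are illustrative, not from this repo:
#
#   with open("sample.jpg", "rb") as f:
#       b64 = base64.b64encode(f.read()).decode("utf-8")
#   get_completion(b64)  # -> e.g. [{"generated_text": "a dog on a beach"}]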
# Local-inference helper (alternative to the endpoint; not used by the Gradio app)
def get_generation(model, processor, image, dtype):
    inputs = processor(image, return_tensors="pt").to(dtype)
    out = model.generate(**inputs)
    return processor.decode(out[0], skip_special_tokens=True)
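# Hedged sketch of running get_generation locally instead of via the endpoint.
# It assumes the transformers package is installed; the model id comes from the
# app description below, and the URL is hypothetical:
#
#   from transformers import BlipProcessor, BlipForConditionalGeneration
#   processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
#   model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
#   image = load_image("https://example.com/sample.jpg")
#   print(get_generation(model, processor, image, torch.float32))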
# Fetch an image from a URL and convert it to RGB (companion to get_generation)
def load_image(img_url):
    image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')
    return image
# Gradio handler: download the image, send it to the endpoint, return the caption
def caption_image(image_url):
    response = requests.get(image_url)
    response.raise_for_status()  # Ensure the download was successful
    Image.open(BytesIO(response.content)).verify()  # Confirm it is a valid image
    # Send the image as a base64-encoded string, the format Hugging Face
    # image endpoints generally accept in "inputs"
    encoded_image = base64.b64encode(response.content).decode("utf-8")
    result = get_completion(encoded_image)
    # Image-to-text endpoints typically respond with [{"generated_text": ...}]
    if isinstance(result, list) and result and "generated_text" in result[0]:
        return result[0]["generated_text"]
    return str(result)
# Gradio interface
demo = gr.Interface(
    fn=caption_image,
    inputs=gr.Textbox(label="Image URL"),  # Input as a URL
    outputs="text",
    title="Image Captioning App",
    description=(
        "Enter an image URL to generate a caption. This app calls a Hugging Face "
        "Inference Endpoint serving the `Salesforce/blip-image-captioning-base` model."
    ),
)
if __name__ == "__main__":
    demo.launch()