|
|
from dotenv import load_dotenv, find_dotenv |
|
|
load_dotenv(find_dotenv()) |
|
|
|
|
|
import os |
|
|
import io |
|
|
from io import BytesIO |
|
|
import IPython.display |
|
|
from PIL import Image |
|
|
import base64 |
|
|
import requests |
|
|
import json |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import warnings |
|
|
import gradio as gr |
|
|
|
|
|
|
|
|
# Suppress warnings whose message matches the "model-agnostic default
# `max_length`" notice so they do not clutter the console output.
warnings.filterwarnings(
    action="ignore",
    message=".*Using the model-agnostic default `max_length`.*",
)

# Credentials and endpoint location are read from the process environment;
# either value is None when the corresponding variable is not set.
hf_api_key = os.environ.get("API_TOKEN")
endpoint_url = os.environ.get("INFERENCE_ENDPOINT")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_completion(inputs, parameters=None, endpoint_url=endpoint_url, timeout=60):
    """POST ``inputs`` to the inference endpoint and return the decoded JSON reply.

    Args:
        inputs: Value sent under the ``"inputs"`` key of the JSON request body.
        parameters: Optional value sent under the ``"parameters"`` key; the key
            is omitted entirely when this is ``None``.
        endpoint_url: URL to POST to. Defaults to the module-level
            ``endpoint_url`` as it was when this function was defined.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The response body parsed as JSON.

    Raises:
        requests.HTTPError: If the endpoint answers with a 4xx/5xx status.
        requests.Timeout: If the endpoint does not respond within ``timeout``.
    """
    headers = {
        "Authorization": f"Bearer {hf_api_key}",
        "Content-Type": "application/json",
    }
    payload = {"inputs": inputs}
    if parameters is not None:
        payload["parameters"] = parameters

    # Without an explicit timeout, requests waits indefinitely on a stalled
    # endpoint and the caller never gets control back.
    response = requests.post(
        endpoint_url,
        headers=headers,
        data=json.dumps(payload),
        timeout=timeout,
    )
    # Surface HTTP failures as exceptions instead of handing an error payload
    # back to the caller as though it were a successful result.
    response.raise_for_status()
    return json.loads(response.content.decode("utf-8"))
|
|
|
|
|
def get_generation(model, processor, image, dtype):
    """Generate text for ``image`` with a local model/processor pair.

    Args:
        model: Object exposing ``generate(**inputs)``.
        processor: Callable that encodes ``image`` (invoked with
            ``return_tensors="pt"``) and also exposes ``decode``.
        image: The image handed to ``processor``.
        dtype: Passed to ``.to(...)`` on the processor's output.

    Returns:
        The first generated sequence, decoded by ``processor.decode`` with
        special tokens skipped.
    """
    encoded = processor(image, return_tensors="pt")
    encoded = encoded.to(dtype)
    generated = model.generate(**encoded)
    first_sequence = generated[0]
    return processor.decode(first_sequence, skip_special_tokens=True)
|
|
|
|
|
def load_image(img_url, timeout=30):
    """Download the image at ``img_url`` and return it as an RGB PIL image.

    Args:
        img_url: URL of the image to fetch.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The downloaded image converted to RGB mode.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    response = requests.get(img_url, timeout=timeout)
    # Fail on HTTP errors here rather than letting Pillow try to parse an
    # error page as image data.
    response.raise_for_status()
    # Decode from the fully buffered body: unlike the raw stream, ``content``
    # has any transfer/content encoding already undone, and the connection is
    # released once the body has been read.
    return Image.open(BytesIO(response.content)).convert("RGB")
|
|
|
|
|
|
|
|
def caption_image(image_url):
    """Check that ``image_url`` serves an image, then request a caption for it.

    Args:
        image_url: URL of the image to caption.

    Returns:
        Whatever ``get_completion`` returns for ``image_url`` (the endpoint's
        decoded JSON reply), unchanged.

    Raises:
        requests.HTTPError: If downloading the image yields a 4xx/5xx status.
        requests.Timeout: If the image host does not respond in time.
        PIL.UnidentifiedImageError: If the downloaded bytes are not an image
            Pillow can identify.
    """
    # Bound the download so a stalled image host cannot hang the request.
    response = requests.get(image_url, timeout=30)
    response.raise_for_status()

    # Opening the payload is only a sanity check that the URL really points at
    # an image; the decoded image itself is not needed because the endpoint is
    # given the URL, so it is closed again straight away.
    with Image.open(BytesIO(response.content)):
        pass

    # NOTE(review): the endpoint reply is returned as-is; presumably it is a
    # structure such as [{"generated_text": ...}] -- confirm whether only the
    # caption string should be shown to the user.
    return get_completion(image_url)
|
|
|
|
|
|
|
|
|
|
|
# Gradio UI: a single text box takes the image URL, and the value returned by
# caption_image is rendered as text.
demo = gr.Interface(
    fn=caption_image,
    inputs=gr.Textbox(label="Image URL"),
    outputs="text",
    title="Image Captioning App",
    # The description matches the actual input: the interface accepts a URL
    # only, with no file upload and no predefined example images.
    description=(
        "Paste the URL of an image to generate a caption. "
        "This app uses a Hugging Face Inference Endpoint for the `Salesforce/blip-image-captioning-base` model."
    ),
)

# Start the web server only when executed as a script, so importing this
# module does not launch the app.
if __name__ == "__main__":
    demo.launch()
|
|
|