"""Gradio demo for BLIP-2, a multimodal chatbot from Salesforce Research."""

import string
from io import BytesIO

import gradio as gr
import requests

from utils import Endpoint

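# `Endpoint` comes from this Space's local `utils` module (not shown here). Judging
# from its use in `query_api` below, it only needs to expose the inference server's
# address as a `.url` attribute; a minimal stand-in, with a hypothetical URL, would be:
#
#     class Endpoint:
#         url = "https://example.com/blip2/generate"  # hypothetical backend address
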
def encode_image(image):
    """Serialize a PIL image into an in-memory JPEG buffer for upload."""
    buffered = BytesIO()
    # JPEG cannot store an alpha channel, so normalize to RGB first;
    # RGBA uploads (e.g. transparent PNGs) would otherwise fail to save.
    image.convert("RGB").save(buffered, format="JPEG")
    buffered.seek(0)
    return buffered

def query_api(
    image, prompt, decoding_method, temperature, len_penalty, repetition_penalty
):
    """POST the image and generation parameters to the BLIP-2 endpoint."""
    url = endpoint.url
    headers = {"User-Agent": "BLIP-2 HuggingFace Space"}

    data = {
        "prompt": prompt,
        "use_nucleus_sampling": decoding_method == "Nucleus sampling",
        "temperature": temperature,
        "length_penalty": len_penalty,
        "repetition_penalty": repetition_penalty,
    }

    files = {"image": encode_image(image)}

    response = requests.post(url, data=data, files=files, headers=headers)

    if response.status_code == 200:
        return response.json()
    # Wrap the error in a list so callers that index into the output
    # (postprocess_output, the history bookkeeping) do not crash on a bare string.
    return ["Error: " + response.text]

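# Note: the backend's response shape is assumed (from how the result is consumed
# below) to be a JSON array holding the generated string, e.g.
# ["a large house with a red front door"].
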
def postprocess_output(output):
    """Append a period when the generated answer does not end in punctuation."""
    if output[0] and output[0][-1] not in string.punctuation:
        output[0] += "."
    return output

def inference(
    image,
    text_input,
    decoding_method,
    temperature,
    length_penalty,
    repetition_penalty,
    history=None,
):
    # Avoid a mutable default argument; Gradio passes the session state explicitly.
    if history is None:
        history = []
    history.append(text_input)

    # The conversation so far is concatenated into a single prompt.
    prompt = " ".join(history)

    output = query_api(
        image, prompt, decoding_method, temperature, length_penalty, repetition_penalty
    )
    output = postprocess_output(output)
    history += output

    # Pair consecutive history entries into (user, bot) tuples for gr.Chatbot.
    chat = [(history[i], history[i + 1]) for i in range(0, len(history) - 1, 2)]

    return {chatbot: chat, state: history}

| title = """<h1 align="center">BLIP-2</h1>""" |
| description = """Gradio demo for BLIP-2, a multimodal chatbot from Salesforce Research. To use it, simply upload your image, or click one of the examples to load them. Please visit our <a href='https://github.com/salesforce/LAVIS/tree/main/projects/blip2' target='_blank'>project webpage</a>.</p> |
| <p> <strong>Disclaimer</strong>: This is a research prototype and is not intended for production use. No data including but not restricted to text and images is collected. </p>""" |
| article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2201.12086' target='_blank'>BLIP-2: Bootstrapping Language-Image Pre-training with Frozen Image Encoders and Large Language Models</a>" |
|
|
| endpoint = Endpoint() |
|
|
examples = [
    ["house.png", "How could someone get out of the house?"],
]

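# "house.png" is resolved relative to the working directory, so it is expected to
# ship alongside this script in the Space repository.
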
with gr.Blocks() as iface:
    state = gr.State([])

    gr.Markdown(title)
    gr.Markdown(description)
    gr.Markdown(article)

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil")
            text_input = gr.Textbox(lines=2, label="Text input")

            sampling = gr.Radio(
                choices=["Beam search", "Nucleus sampling"],
                value="Beam search",
                label="Text Decoding Method",
                interactive=True,
            )

            with gr.Row():
                temperature = gr.Slider(
                    minimum=0.5,
                    maximum=1.0,
                    value=0.8,
                    interactive=True,
                    label="Temperature",
                )

                len_penalty = gr.Slider(
                    minimum=-2.0,
                    maximum=2.0,
                    value=1.0,
                    step=0.5,
                    interactive=True,
                    label="Length Penalty",
                )

                rep_penalty = gr.Slider(
                    minimum=1.0,
                    maximum=20.0,
                    value=10.0,
                    step=0.5,
                    interactive=True,
                    label="Repetition Penalty",
                )

        with gr.Column():
            with gr.Row():
                chatbot = gr.Chatbot()
                # Uploading a new image clears the chat window and the stored history.
                image_input.change(lambda: (None, []), [], [chatbot, state])

            with gr.Row():
                clear_button = gr.Button(value="Clear", interactive=True)
                clear_button.click(
                    lambda: ("", None, [], []),
                    [],
                    [text_input, image_input, chatbot, state],
                )

                submit_button = gr.Button(
                    value="Submit", interactive=True, variant="primary"
                )
                submit_button.click(
                    inference,
                    [
                        image_input,
                        text_input,
                        sampling,
                        temperature,
                        len_penalty,
                        rep_penalty,
                        state,
                    ],
                    [chatbot, state],
                )

    gr.Examples(
        examples=examples,
        inputs=[image_input, text_input],
    )

iface.queue(concurrency_count=1, api_open=False, max_size=20)
# `queue()` already enables queuing, so the deprecated `enable_queue` launch flag is dropped.
iface.launch()
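# To try the demo locally (assuming `utils.py` and `house.png` sit alongside this
# script): run it with Python and open the URL Gradio prints, which defaults to
# http://localhost:7860.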