Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import spaces | |
| import time | |
| from PIL import Image | |
| from models.mllava import MLlavaProcessor, LlavaForConditionalGeneration, chat_mllava, MLlavaForConditionalGeneration | |
| from typing import List | |
# Checkpoint shared by the processor and the model; both are loaded once at import time.
_CHECKPOINT = "MFuyu/mllava_llava_debug_nlvr2_v5_4096"

processor = MLlavaProcessor.from_pretrained(_CHECKPOINT)
model = LlavaForConditionalGeneration.from_pretrained(_CHECKPOINT)
# NOTE(review): on ZeroGPU Spaces, CUDA is only usable inside functions wrapped
# with `spaces.GPU`; the decorator is documented as effect-free elsewhere.
# Confirm the hardware tier of the Space this file is deployed to.
@spaces.GPU
def generate(text: str, images: List[Image.Image], history: List[dict], **kwargs):
    """Stream a reply from the module-level model via `chat_mllava`.

    Args:
        text: Prompt text forwarded to `chat_mllava` (the caller in this file
            passes `None` and supplies the conversation through `history`).
        images: Images for the conversation; an empty/falsy value is
            forwarded as `None`.
        history: Conversation turns forwarded as `history=` to `chat_mllava`.
        **kwargs: Extra generation options forwarded unchanged.

    Yields:
        The first element of every pair produced by
        `chat_mllava(..., stream=True)` (presumably the response text
        generated so far -- verify against `chat_mllava`).

    Returns:
        The last yielded value, or `text` unchanged if nothing was produced
        (available as `StopIteration.value`).
    """
    global processor, model
    # Rebinds the global so later calls reuse the model already on the GPU.
    model = model.to("cuda")
    if not images:
        images = None

    # Separate loop names so the `text` / `history` parameters are not shadowed.
    partial = text
    stream = chat_mllava(
        text, images, model, processor, history=history, stream=True, **kwargs
    )
    for partial, _updated_history in stream:
        yield partial
    return partial
def enable_next_image(uploaded_images, image):
    """Append `image` to `uploaded_images` and reset the input box.

    Returns the same (mutated) `uploaded_images` object together with a
    `gr.MultimodalTextbox` that has no value and is not interactive.
    """
    uploaded_images.append(image)
    locked_textbox = gr.MultimodalTextbox(value=None, interactive=False)
    return uploaded_images, locked_textbox
def add_message(history, message):
    """Append a submitted multimodal message to the chat history.

    Each entry of `message["files"]` becomes its own `[(file,), None]` row,
    followed by a `[text, None]` row when `message["text"]` is non-empty.
    `history` is mutated in place and returned together with an emptied
    `gr.MultimodalTextbox`.
    """
    attachments = message["files"]
    if attachments:
        # One row per file; the 1-tuple marks the row as a file message.
        history.extend([(path,), None] for path in attachments)

    typed_text = message["text"]
    if typed_text:
        history.append([typed_text, None])

    cleared_input = gr.MultimodalTextbox(value=None)
    return history, cleared_input
def print_like_dislike(x: gr.LikeData):
    """Print the index, value and liked flag carried by a like/dislike event."""
    index, value, liked = x.index, x.value, x.liked
    print(index, value, liked)
def get_chat_history(history):
    """Convert chatbot rows into a flat list of role/text dicts.

    Only rows whose first element is a string are considered; other rows
    (e.g. file tuples) are skipped. Every kept row contributes a "user"
    entry followed by an "assistant" entry. The assistant text is the row's
    second element, except for the final row of `history`, which must not
    have a reply yet and gets an empty assistant text.

    Raises:
        AssertionError: if a non-final text row has no reply, or the final
            row already has one.
    """
    final_position = len(history) - 1
    turns = []
    for position, row in enumerate(history):
        user_part = row[0]
        if not isinstance(user_part, str):
            continue
        turns.append({"role": "user", "text": user_part})

        if position == final_position:
            # The newest user message is still waiting for its answer.
            assert not row[1], "the bot message internal error, get: {}".format(row[1])
            reply = ""
        else:
            assert row[1], "The bot message is not provided, internal error"
            reply = row[1]
        turns.append({"role": "assistant", "text": reply})
    return turns
def get_chat_images(history):
    """Collect, in order, every item of the tuple-typed rows in `history`.

    A row counts as a file row when its first element is a tuple; all
    elements of that tuple are added to the result. Other rows are ignored.
    """
    return [
        item
        for row in history
        if isinstance(row[0], tuple)
        for item in row[0]
    ]
def bot(history):
    """Generate the assistant reply for the pending user input, streaming it.

    Gathers every trailing row of `history` that has no reply yet, merges
    their texts and files into one pending message, reconciles the number of
    "<image>" placeholders with the number of pending files, then calls
    `generate` and writes each streamed chunk into `history[-1][1]`.

    Yields:
        The (mutated) `history` after every streamed chunk.

    Raises:
        gr.Error: if the pending message contains no text.
    """
    print(history)

    # Walk backwards until the first row that already has a reply.
    pending_text = ""
    pending_images = []
    for row in reversed(history):
        if row[1]:
            break
        content = row[0]
        if isinstance(content, str):
            pending_text = content + " " + pending_text
        elif isinstance(content, tuple):
            pending_images.extend(content)
    pending_text = pending_text.strip()
    # Collected newest-first above; restore chronological order.
    pending_images.reverse()

    if not pending_text:
        raise gr.Error("Please enter a message")

    image_total = len(pending_images)
    placeholder_total = pending_text.count("<image>")
    if placeholder_total < image_total:
        gr.Warning("The number of images uploaded is more than the number of <image> placeholders in the text. Will automatically prepend <image> to the text.")
        pending_text = "<image> " * (image_total - placeholder_total) + pending_text
        history[-1][0] = pending_text
    elif placeholder_total > image_total:
        gr.Warning("The number of images uploaded is less than the number of <image> placeholders in the text. Will automatically remove extra <image> placeholders from the text.")
        surplus = placeholder_total - image_total
        # Reverse / replace / reverse drops the *last* `surplus` placeholders.
        pending_text = pending_text[::-1].replace("<image>"[::-1], "", surplus)[::-1]
        history[-1][0] = pending_text

    chat_history = get_chat_history(history)
    chat_images = get_chat_images(history)
    generation_kwargs = dict(
        max_new_tokens=4096,
        temperature=0.7,
        top_p=1.0,
        do_sample=True,
    )
    print(None, chat_images, chat_history, generation_kwargs)

    for chunk in generate(None, chat_images, chat_history, **generation_kwargs):
        history[-1][1] = chunk
        time.sleep(0.05)
        yield history
def build_demo():
    """Assemble the chat UI and wire its events; return the `gr.Blocks` app.

    Both pressing Enter in the input box and clicking "Send" first run
    `add_message` and then `bot` on the chatbot component.
    """
    with gr.Blocks() as demo:
        chatbot = gr.Chatbot(line_breaks=True)
        chat_input = gr.MultimodalTextbox(
            interactive=True,
            file_types=["image"],
            placeholder="Enter message or upload images. Please use <image> to indicate the position of uploaded images",
            show_label=True,
        )

        # Enter key: append the message, then (only on success) run the bot.
        submitted = chat_input.submit(
            add_message, [chatbot, chat_input], [chatbot, chat_input]
        )
        submitted.success(bot, chatbot, chatbot, api_name="bot_response")

        chatbot.like(print_like_dislike, None, None)

        with gr.Row():
            send_button = gr.Button("Send")
            gr.ClearButton([chatbot, chat_input])

        # "Send" button: same two-step pipeline, chained with `.then`.
        clicked = send_button.click(
            add_message, [chatbot, chat_input], [chatbot, chat_input]
        )
        clicked.then(bot, chatbot, chatbot, api_name="bot_response")
    return demo
# Build and serve the UI only when this file is executed as a script.
if __name__ == "__main__":
    demo = build_demo()
    demo.launch()