Spaces:
Runtime error
Runtime error
| import os | |
| import string | |
| import gradio as gr | |
| import PIL.Image | |
| import torch | |
| from transformers import BitsAndBytesConfig, pipeline | |
| import re | |
# Hugging Face Hub id of the checkpoint served by this demo.
model_id = "llava-hf/llava-1.5-7b-hf"

# Load the model weights in 4-bit and use float16 as the compute dtype.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)

# Built once at import time; `chat` calls this object for every request.
pipe = pipeline(
    "image-to-text",
    model=model_id,
    model_kwargs={"quantization_config": quantization_config},
)

# Markdown rendered at the top of the page.
# The file used to assign DESCRIPTION = "# LLaVA 🌋" first and then overwrite
# it with the value below before anything read it; only the effective value
# is kept.
DESCRIPTION = "LLaVA is now available in transformers!"
# One conversation turn: the user part (including its "USER:" marker) followed
# by the assistant reply. The look-ahead ends the reply at the next "USER:"
# marker WITHOUT consuming it, so consecutive turns are all matched.
_TURN_PATTERN = re.compile(r"(USER:.*?)ASSISTANT:(.*?)(?=USER:|\Z)", re.DOTALL)


def extract_response_pairs(text):
    """Split a "USER: ... ASSISTANT: ..." transcript into chat pairs.

    Args:
        text: Transcript containing zero or more "USER: ... ASSISTANT: ..."
            turns. ``None`` is treated as an empty transcript.

    Returns:
        A list of ``(user, assistant)`` tuples in transcript order, each part
        stripped of surrounding whitespace. The user part keeps its leading
        "USER:" marker; the assistant part does not include "ASSISTANT:".
    """
    return [
        (user.strip(), assistant.strip())
        for user, assistant in _TURN_PATTERN.findall(text or "")
    ]


def postprocess_output(output: str) -> str:
    """Append a period to ``output`` unless it is empty or ends in punctuation.

    Args:
        output: Text to finish.

    Returns:
        ``output`` unchanged when it is empty or its last character is in
        ``string.punctuation``; otherwise ``output`` followed by ".".
    """
    if not output or output[-1] in string.punctuation:
        return output
    return output + "."


def chat(image, text, max_length, history_chat):
    """Run one chat turn through the module-level ``pipe``.

    Args:
        image: Image forwarded to the pipeline as its first argument.
        text: The user's message for this turn.
        max_length: Maximum number of tokens to generate for the reply.
        history_chat: List of transcript strings from earlier turns (may be
            empty or ``None``). It is not mutated.

    Returns:
        A ``(chat_val, history)`` tuple: ``chat_val`` is the list of
        ``(user, assistant)`` pairs parsed from the transcript, and
        ``history`` is the list to pass back in as ``history_chat`` on the
        next call.

    Raises:
        ValueError: If ``image`` is ``None``.
    """
    if image is None:
        raise ValueError("An image is required to chat with the model.")

    # Join earlier transcript text and the new turn with single spaces so the
    # new "USER:" marker is never glued onto the previous assistant reply.
    new_turn = f"USER: <image>\n{text}\nASSISTANT:"
    prompt = " ".join([*(history_chat or []), new_turn])

    outputs = pipe(
        image,
        prompt=prompt,
        # `max_length` would also count the prompt tokens, leaving little or
        # no room for the reply as the conversation grows; bound only the
        # newly generated tokens instead.
        generate_kwargs={"max_new_tokens": int(max_length)},
    )
    transcript = outputs[0]["generated_text"]

    # NOTE(review): the USER:/ASSISTANT: parsing only works if generated_text
    # echoes the prompt (and therefore all earlier turns) — confirm against
    # the installed transformers version. Under that assumption the transcript
    # already is the full conversation, so it replaces the history instead of
    # being appended, which would duplicate every earlier turn.
    return extract_response_pairs(transcript), [transcript]
# Inline CSS for an element with id "mkd".
# NOTE(review): this string is never used — gr.Blocks below is given the
# literal "style.css" instead. Confirm which of the two is intended.
css = """
#mkd {
    height: 500px;
    overflow: auto;
    border: 1px solid #ccc;
}
"""

with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    chatbot = gr.Chatbot(label="Chat", show_label=False)
    with gr.Row():
        image = gr.Image(type="pil")
        text_input = gr.Text(
            label="Chat Input",
            show_label=False,
            max_lines=1,
            container=False,
        )
    # Conversation history handed to `chat` and replaced by its second
    # return value on every turn.
    history_chat = gr.State(value=[])
    with gr.Row():
        clear_chat_button = gr.Button("Clear")
        chat_button = gr.Button("Submit", variant="primary")
    with gr.Accordion(label="Advanced settings", open=False):
        max_length = gr.Slider(
            label="Max Length",
            minimum=1,
            maximum=200,
            step=1,
            value=100,
        )

    # `chat` takes these four values and returns (chat pairs, history).
    chat_inputs = [image, text_input, max_length, history_chat]
    chat_output = [chatbot, history_chat]
    # Both "reset" handlers empty the visible chat and the stored history.
    reset_targets = [chatbot, history_chat]

    # Submit via the button.
    chat_button.click(
        fn=chat,
        inputs=chat_inputs,
        outputs=chat_output,
        api_name="Chat",
    ).success(
        # Clear the text box once the reply has been produced, matching the
        # Enter-key path below.
        fn=lambda: "",
        outputs=text_input,
        queue=False,
        api_name=False,
    )

    # Submit via Enter in the text box.
    text_input.submit(
        fn=chat,
        inputs=chat_inputs,
        outputs=chat_output,
    ).success(
        # The callback returns a single value, so it must target a single
        # component. It previously listed all four chat inputs as outputs,
        # which does not match the one returned value.
        fn=lambda: "",
        outputs=text_input,
        queue=False,
        api_name=False,
    )

    clear_chat_button.click(
        fn=lambda: ([], []),
        inputs=None,
        outputs=reset_targets,
        queue=False,
        api_name="clear",
    )

    # A different image starts a new conversation.
    image.change(
        fn=lambda: ([], []),
        inputs=None,
        outputs=reset_targets,
        queue=False,
    )

if __name__ == "__main__":
    demo.queue(max_size=10).launch(debug=True)