Spaces:
Runtime error
Runtime error
| from threading import Thread | |
| from llava_llama3.serve.cli import chat_llava | |
| from llava_llama3.model.builder import load_pretrained_model | |
| import gradio as gr | |
| import torch | |
| from PIL import Image | |
| import argparse | |
| import spaces | |
| import os | |
| import time | |
# Directory containing this script (handy for resolving relative asset paths).
root_path = os.path.dirname(os.path.abspath(__file__))
print(root_path)

# Command-line configuration for model loading and generation.
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, default="TheFinAI/FinLLaVA")
parser.add_argument("--device", type=str, default="cuda:0")
parser.add_argument("--conv-mode", type=str, default="llama_3")
parser.add_argument("--temperature", type=float, default=0.7)
parser.add_argument("--max-new-tokens", type=int, default=512)
parser.add_argument("--load-8bit", action="store_true")
parser.add_argument("--load-4bit", action="store_true")
args = parser.parse_args()

# Load the tokenizer, model, image processor, and context length once at startup.
tokenizer, llava_model, image_processor, context_len = load_pretrained_model(
    args.model_path, None, 'llava_llama3',
    args.load_8bit, args.load_4bit,
    device=args.device,
)
def bot_streaming(message, history):
    """Stream a chat response for a multimodal (text + image) message.

    Args:
        message: Gradio multimodal message dict with "text" and "files" keys.
        history: Prior chat turns; tuple entries in a turn hold uploaded
            file paths (this is how Gradio records image uploads).

    Yields:
        str: The accumulated generated text so far, for incremental display.

    Raises:
        gr.Error: If no image is attached to the message or found in history.
    """
    # BUG FIX: TextStreamer was referenced but never imported, and a plain
    # TextStreamer is not iterable — the streaming loop below requires
    # TextIteratorStreamer (a TextStreamer subclass). Imported locally because
    # `transformers` is not among the module-level imports.
    from transformers import TextIteratorStreamer

    print(message)
    image_path = None
    # Prefer an image attached to the current message.
    if message["files"]:
        # The last entry may be a dict (newer Gradio) or a plain path string.
        last_file = message["files"][-1]
        image_path = last_file["path"] if isinstance(last_file, dict) else last_file
    else:
        # Otherwise reuse the most recently uploaded image from the history.
        for hist in history:
            if isinstance(hist[0], tuple):
                image_path = hist[0][0]

    if image_path is None:
        raise gr.Error("You need to upload an image for LLaVA to work.")

    # Debug trace of the resolved image path (red ANSI color).
    print(f"\033[91m{image_path}, {type(image_path)}\033[0m")

    prompt = message['text']
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # BUG FIX: the original dict literal was missing a comma after
    # `streamer=streamer` (a SyntaxError).
    generation_kwargs = dict(
        args=args,
        image_file=image_path,
        text=prompt,
        tokenizer=tokenizer,
        model=llava_model,
        streamer=streamer,
        image_processor=image_processor,  # todo: input model name or path
        context_len=context_len,
    )

    # BUG FIX: Thread(..., kwargs=generation_kwargs) unpacked the dict as
    # **kwargs into a wrapper taking one positional argument (TypeError at
    # call time). Target chat_llava directly with the keyword arguments.
    thread = Thread(target=chat_llava, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    # Give the generation thread a moment to start producing tokens.
    time.sleep(0.5)
    for new_text in streamer:
        # Strip the Llama-3 end-of-turn token if it appears in a chunk.
        if "<|eot_id|>" in new_text:
            new_text = new_text.split("<|eot_id|>")[0]
        buffer += new_text
        # Pace the UI updates slightly.
        time.sleep(0.06)
        # BUG FIX: the streamer is built with skip_prompt=True, so the prompt
        # never appears in `buffer`; the original `buffer[len(prompt):]`
        # silently dropped the first len(prompt) characters of the answer.
        yield buffer
# Build the Gradio chat UI: a multimodal textbox feeding bot_streaming.
chatbot = gr.Chatbot(scale=1)
chat_input = gr.MultimodalTextbox(
    interactive=True,
    file_types=["image"],
    placeholder="Enter message or upload file...",
    show_label=False,
)

with gr.Blocks(fill_height=True) as demo:
    gr.ChatInterface(
        fn=bot_streaming,
        title="FinLLaVA",
        examples=[
            {"text": "What is on the flower?", "files": ["./bee.jpg"]},
            {"text": "How to make this pastry?", "files": ["./baklava.png"]},
            {"text": "What is this?", "files": ["http://images.cocodataset.org/val2017/000000039769.jpg"]},
        ],
        stop_btn="Stop Generation",
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot,
    )

# Disable the public API surface and launch locally.
demo.queue(api_open=False)
demo.launch(show_api=False, share=False)