"""Gradio multimodal chat demo that streams text generated by a vision-language model.

The model and processor are loaded once in ``main`` and shared with the
streaming callback through module-level globals.
"""

import argparse
import warnings
from functools import partial
from threading import Thread

import gradio as gr
from PIL import Image
from qwen_vl_utils import process_vision_info
from transformers import TextIteratorStreamer

from src.utils import load_pretrained_model, get_model_name_from_path, disable_torch_init

warnings.filterwarnings("ignore")

# File suffixes treated as video; anything else uploaded is handled as an image.
VIDEO_EXTENSIONS = ('.mp4', '.avi', '.mkv', '.mov', '.wmv', '.flv', '.webm', '.mpeg')

# Value of the "fps" field attached to every video entry in the conversation.
VIDEO_FPS = 1.0


def is_video_file(filename):
    """Return True if *filename* ends with a known video extension (case-insensitive)."""
    return filename.lower().endswith(VIDEO_EXTENSIONS)


def _file_path(file_item):
    """Return the path of an uploaded file given either a ``{"path": ...}`` dict or a plain path."""
    if isinstance(file_item, dict):
        return file_item["path"]
    return file_item


def _media_entry(file_path):
    """Build the conversation content entry (video or image) for one file path."""
    if is_video_file(file_path):
        return {"type": "video", "video": file_path, "fps": VIDEO_FPS}
    return {"type": "image", "image": file_path}


def _history_user_content(user_turn):
    """Convert the user half of one history pair into a list of content entries.

    A tuple turn is read as ``(files[, text])`` where ``files`` is a single
    file entry or a list of them; any other non-empty value is used as text.
    ``None``, empty text and a missing tuple text element yield no text entry.
    """
    if user_turn is None:
        return []
    if not isinstance(user_turn, tuple):
        return [{"type": "text", "text": user_turn}] if user_turn else []

    # The text element is optional: file-only turns may arrive as a 1-tuple,
    # so indexing position 1 unconditionally would raise IndexError.
    file_items = user_turn[0] if user_turn else []
    user_text = user_turn[1] if len(user_turn) > 1 else None
    if not isinstance(file_items, list):
        file_items = [file_items]

    content = [_media_entry(_file_path(item)) for item in file_items]
    if user_text:
        content.append({"type": "text", "text": user_text})
    return content


def _build_conversation(message, history):
    """Assemble the chat-template conversation from past turns plus the new message.

    Args:
        message: Dict with optional ``"files"`` (paths or ``{"path": ...}``
            dicts) and ``"text"`` keys describing the new user message.
        history: Iterable of ``(user_turn, assistant_turn)`` pairs.

    Returns:
        A list of ``{"role": ..., "content": [...]}`` dicts ending with the
        new user message.
    """
    conversation = []
    for user_turn, assistant_turn in history:
        user_content = _history_user_content(user_turn)
        # Skip user turns that carry nothing so no empty message is emitted.
        if user_content:
            conversation.append({"role": "user", "content": user_content})
        if assistant_turn is not None:
            conversation.append(
                {"role": "assistant", "content": [{"type": "text", "text": assistant_turn}]}
            )

    file_paths = [_file_path(item) for item in (message.get("files") or [])]
    # Images are listed before videos, then the text, for the new message.
    images = [path for path in file_paths if not is_video_file(path)]
    videos = [path for path in file_paths if is_video_file(path)]
    user_content = [_media_entry(path) for path in images + videos]

    user_text = message.get("text")
    if user_text:
        user_content.append({"type": "text", "text": user_text})
    conversation.append({"role": "user", "content": user_content})
    return conversation


def bot_streaming(message, history, generation_args):
    """Generate a reply to *message* and yield the accumulated text as it streams.

    Relies on the module-level ``processor``, ``model`` and ``device`` that
    ``main`` sets before the UI is launched.

    Args:
        message: Dict with ``"files"`` and ``"text"`` for the new user message.
        history: Iterable of ``(user_turn, assistant_turn)`` pairs from earlier turns.
        generation_args: Extra keyword arguments forwarded to ``model.generate``.

    Yields:
        The full response text produced so far, growing with each new chunk.
    """
    conversation = _build_conversation(message, history)

    prompt = processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(conversation)
    inputs = processor(
        text=[prompt],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(device)

    streamer = TextIteratorStreamer(
        processor.tokenizer,
        skip_special_tokens=True,
        skip_prompt=True,
        clean_up_tokenization_spaces=False,
    )
    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        eos_token_id=processor.tokenizer.eos_token_id,
        **generation_args,
    )

    # generate() runs in a worker thread so this generator can relay text
    # from the streamer while generation is still in progress.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer

    # The streamer is exhausted, so wait for the worker to finish returning.
    thread.join()


def main(args):
    """Load the model and launch the Gradio chat interface.

    Args:
        args: Parsed CLI namespace providing ``model_path``, ``model_base``,
            ``device``, ``load_4bit``, ``load_8bit``, ``disable_flash_attention``,
            ``max_new_tokens``, ``temperature`` and ``repetition_penalty``.
    """
    global processor, model, device

    device = args.device
    disable_torch_init()

    use_flash_attn = not args.disable_flash_attention
    model_name = get_model_name_from_path(args.model_path)

    processor, model = load_pretrained_model(
        model_base=args.model_base,
        model_path=args.model_path,
        device_map=args.device,
        model_name=model_name,
        load_4bit=args.load_4bit,
        load_8bit=args.load_8bit,
        device=args.device,
        use_flash_attn=use_flash_attn,
    )

    chatbot = gr.Chatbot(scale=2)
    chat_input = gr.MultimodalTextbox(
        interactive=True,
        file_types=["image", "video"],
        placeholder="Enter message or upload file...",
        show_label=False,
    )

    generation_args = {
        "max_new_tokens": args.max_new_tokens,
        "temperature": args.temperature,
        # Sampling is enabled only for a positive temperature.
        "do_sample": args.temperature > 0,
        "repetition_penalty": args.repetition_penalty,
    }

    bot_streaming_with_args = partial(bot_streaming, generation_args=generation_args)

    with gr.Blocks(fill_height=True) as demo:
        gr.ChatInterface(
            fn=bot_streaming_with_args,
            title="Qwen2-VL-7B Instruct",
            stop_btn="Stop Generation",
            multimodal=True,
            textbox=chat_input,
            chatbot=chatbot,
        )

    demo.queue(api_open=False)
    demo.launch(show_api=False, share=False, server_name='0.0.0.0')


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default=None)
    parser.add_argument("--model-base", type=str, default="Qwen/Qwen2-VL-7B-Instruct")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--load-8bit", action="store_true")
    parser.add_argument("--load-4bit", action="store_true")
    # The dashed spelling matches the other flags; the original underscore
    # spelling is kept as an alias. Both store to ``disable_flash_attention``.
    parser.add_argument(
        "--disable-flash-attention",
        "--disable_flash_attention",
        action="store_true",
    )
    parser.add_argument("--temperature", type=float, default=0)
    parser.add_argument("--repetition-penalty", type=float, default=1.0)
    parser.add_argument("--max-new-tokens", type=int, default=1024)
    parser.add_argument("--debug", action="store_true")
    args = parser.parse_args()
    main(args)