# This file may have been modified by Flash-VStream Authors ("Flash-VStream Modifications"). All Flash-VStream Modifications are Copyright 2024 Flash-VStream Authors.
# Based on https://github.com/haotian-liu/LLaVA.
"""
This file demonstrates an implementation of a multiprocess Real-time Long Video
Understanding System, together with a multiprocess logging module.

    main process: CLI server I/O, LLM inference
    process-1:    logger listener
    process-2:    frame generator (video stream simulator)
    process-3:    frame memory manager

Author: Haoji Zhang, Haotian Liu
(This code is based on https://github.com/haotian-liu/LLaVA)
"""
import argparse
import requests
import logging
import logging.handlers  # required for logging.handlers.QueueHandler in worker_configurer
import torch
import numpy as np
import time
import os

from flash_vstream.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from flash_vstream.conversation import conv_templates, SeparatorStyle
from flash_vstream.model.builder import load_pretrained_model
from flash_vstream.utils import disable_torch_init
from flash_vstream.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria

from torch.multiprocessing import Process, Queue, Manager
from transformers import TextStreamer
from decord import VideoReader
from datetime import datetime
from PIL import Image
from io import BytesIO


class _Metric:
    """Tracks the latest value, running average, and maximum of one metric."""

    def __init__(self):
        self._latest_value = None
        self._sum = 0.0
        self._max = 0.0
        self._count = 0

    def val(self):
        return self._latest_value

    def max(self):
        return self._max

    def avg(self):
        if self._count == 0:
            return float('nan')
        return self._sum / self._count

    def add(self, value):
        self._latest_value = value
        self._sum += value
        self._count += 1
        if value > self._max:
            self._max = value

    def __str__(self):
        # val/avg/max are methods, so they must be called here; formatting the
        # bound methods themselves would raise a TypeError.
        latest_formatted = f"{self.val():.6f}" if self.val() is not None else "None"
        average_formatted = f"{self.avg():.6f}"
        max_formatted = f"{self.max():.6f}"
        return f"{latest_formatted} ({average_formatted}, {max_formatted})"


class MetricMeter:
    """A keyed collection of _Metric objects."""

    def __init__(self):
        self._metrics = {}

    def add(self, key, value):
        if key not in self._metrics:
            self._metrics[key] = _Metric()
        self._metrics[key].add(value)

    def val(self, key):
        metric = self._metrics.get(key)
        if metric is None or metric.val() is None:
            raise ValueError(f"No values have been added for key '{key}'.")
        return metric.val()

    def avg(self, key):
        metric = self._metrics.get(key)
        if metric is None:
            raise ValueError(f"No values have been added for key '{key}'.")
        return metric.avg()

    def max(self, key):
        metric = self._metrics.get(key)
        if metric is None:
            raise ValueError(f"No values have been added for key '{key}'.")
        return metric.max()

    def __getitem__(self, key):
        metric = self._metrics.get(key)
        if metric is None:
            raise KeyError(f"The key '{key}' does not exist.")
        return str(metric)
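# Usage sketch for MetricMeter (illustrative values, not from a real run):
#
#   meter = MetricMeter()
#   meter.add('latency', 0.10)
#   meter.add('latency', 0.30)
#   meter.val('latency')   # -> 0.30   (latest)
#   meter.avg('latency')   # -> 0.20   (running mean)
#   meter.max('latency')   # -> 0.30   (peak)
#   meter['latency']       # -> "0.300000 (0.200000, 0.300000)"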


def load_image(image_file):
    # Helper retained from the LLaVA CLI this file is based on; it is not
    # called anywhere in the streaming demo below.
    if image_file.startswith('http://') or image_file.startswith('https://'):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    return image


def listener(queue, filename):
    ############## Start sub process-1: Listener #############
    import sys, traceback
    root = logging.getLogger()
    root.setLevel(logging.DEBUG)
    # h = logging.StreamHandler(sys.stdout)
    h = logging.FileHandler(filename)
    f = logging.Formatter('%(asctime)s %(processName)-10s %(name)s %(levelname)-8s %(message)s')
    h.setFormatter(f)
    root.addHandler(h)
    while True:
        try:
            record = queue.get()
            if record is None:  # None is a signal to finish
                break
            logger = logging.getLogger(record.name)
            logger.handle(record)  # No level or filter logic applied - just do it!
        except Exception:
            print('Whoops! Problem:', file=sys.stderr)
            traceback.print_exc(file=sys.stderr)


def worker_configurer(queue):
    h = logging.handlers.QueueHandler(queue)  # Just the one handler needed
    root = logging.getLogger()
    root.addHandler(h)
    root.setLevel(logging.DEBUG)
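# Logging topology (the standard multiprocess pattern from the Python logging
# cookbook): every worker process attaches a QueueHandler, so its LogRecords
# travel worker -> log_queue -> listener process, where the single FileHandler
# writes them out. Putting None on log_queue is the listener's shutdown
# signal; this demo instead terminates the listener together with the other
# workers at the end of main.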


def video_stream_simulator(video_file, frame_queue, log_queue, video_fps=1.0, play_speed=1.0):
    ############## Start sub process-2: Simulator #############
    worker_configurer(log_queue)
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)
    vr = VideoReader(video_file)
    sample_fps = round(vr.get_avg_fps() / video_fps)  # sampling stride in source frames
    frame_idx = [i for i in range(0, len(vr), sample_fps)]
    video = vr.get_batch(frame_idx).asnumpy()
    video = np.repeat(video, 6, axis=0)  # repeat each sampled frame 6x to simulate a longer stream
    length = video.shape[0]
    # With video_fps=1.0 and play_speed=1.0 one frame is pushed per second;
    # e.g. play_speed=2.0 halves the sleep to 0.5 s per frame.
    sleep_time = 1 / video_fps / play_speed
    time_meter = MetricMeter()
    logger.info(f'Simulator Process: start, length = {length}')
    try:
        for start in range(0, length):
            start_time = time.perf_counter()
            end = min(start + 1, length)
            video_clip = video[start:end]
            frame_queue.put(video_clip)
            if start > 0:
                time_meter.add('real_sleep', start_time - last_start)
                logger.info(f'Simulator: write {end - start} frames,\t{start} to {end},\treal_sleep={time_meter["real_sleep"]}')
            if end < length:
                time.sleep(sleep_time)
            last_start = start_time
        frame_queue.put(None)  # sentinel: tells the memory manager the stream has ended
    except Exception as e:
        print(f'Simulator Exception: {e}')
        time.sleep(0.1)
    logger.info(f'Simulator Process: end')


def frame_memory_manager(model, image_processor, frame_queue, log_queue):
    ############## Start sub process-3: Memory Manager #############
    worker_configurer(log_queue)
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)
    time_meter = MetricMeter()
    logger.info(f'MemManager Process: start')
    frame_cnt = 0
    while True:
        try:
            video_clip = frame_queue.get()
            start_time = time.perf_counter()
            if video_clip is None:  # sentinel from the simulator: stream finished
                logger.info(f'MemManager: Ooops, get None')
                break
            logger.info(f'MemManager: get {video_clip.shape[0]} frames from queue')
            image = image_processor.preprocess(video_clip, return_tensors='pt')['pixel_values']
            image = image.unsqueeze(0)
            image_tensor = image.to(model.device, dtype=torch.float16)
            # time_2 = time.perf_counter()
            logger.info(f'MemManager: Start embedding')
            with torch.inference_mode():
                model.embed_video_streaming(image_tensor)
            logger.info(f'MemManager: End embedding')
            end_time = time.perf_counter()
            if frame_cnt > 0:
                time_meter.add('memory_latency', end_time - start_time)
                logger.info(f'MemManager: embedded {video_clip.shape[0]} frames,\tidx={frame_cnt},\tmemory_latency={time_meter["memory_latency"]}')
            else:
                # the first clip's latency is printed but not accumulated (likely warmup)
                logger.info(f'MemManager: embedded {video_clip.shape[0]} frames,\tidx={frame_cnt},\tmemory_latency={end_time - start_time:.6f}, not logged')
            frame_cnt += video_clip.shape[0]
        except Exception as e:
            print(f'MemManager Exception: {e}')
            time.sleep(0.1)
    logger.info(f'MemManager Process: end')


def main(args):
    # torch.multiprocessing.log_to_stderr(logging.DEBUG)
    torch.multiprocessing.set_start_method('spawn', force=True)
    disable_torch_init()
    log_queue = Queue()
    frame_queue = Queue(maxsize=10)
    processes = []

    ############## Start listener process #############
    p1 = Process(target=listener, args=(log_queue, args.log_file))
    processes.append(p1)
    p1.start()

    ############## Start main process #############
    worker_configurer(log_queue)
    logger = logging.getLogger(__name__)
    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
    logger.info(f'Using conv_mode={args.conv_mode}')
    conv = conv_templates[args.conv_mode].copy()
    if "mpt" in model_name.lower():
        roles = ('user', 'assistant')
    else:
        roles = conv.roles

    with Manager() as manager:
        image_tensor = None
        model.use_video_streaming_mode = True
        model.video_embedding_memory = manager.list()
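        # manager.list() is a proxy list shared across processes: the memory
        # manager process (p3) writes frame embeddings into it via
        # embed_video_streaming, and the main process presumably reads it
        # inside model.generate when use_video_streaming_mode is set.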
        if args.video_max_frames is not None:
            model.config.video_max_frames = args.video_max_frames
        logger.info(f'Important: set model.config.video_max_frames = {model.config.video_max_frames}')
        logger.info(f'Important: set video_fps = {args.video_fps}')
        logger.info(f'Important: set play_speed = {args.play_speed}')

        ############## Start simulator process #############
        p2 = Process(target=video_stream_simulator,
                     args=(args.video_file, frame_queue, log_queue, args.video_fps, args.play_speed))
        processes.append(p2)
        p2.start()

        ############## Start memory manager process #############
        p3 = Process(target=frame_memory_manager,
                     args=(model, image_processor, frame_queue, log_queue))
        processes.append(p3)
        p3.start()

        # start QA server
        start_time = datetime.now()
        time_meter = MetricMeter()
        conv_cnt = 0
        while True:
            time.sleep(5)
            try:
                # inp = input(f"{roles[0]}: ")
                # For the demo, the interactive prompt is replaced by a fixed
                # question issued every 5 seconds.
                inp = "what is in the video?"
            except EOFError:
                inp = ""
            if not inp:
                print("exit...")
                break
            # Get the current time
            now = datetime.now()
            conv_start_time = time.perf_counter()
            # Format the current time as a string
            current_time = now.strftime("%H:%M:%S")
            duration = now.timestamp() - start_time.timestamp()
            # Print the current time
            print("\nCurrent Time:", current_time, "Run for:", duration)
            print(f"{roles[0]}: {inp}", end="\n")
            print(f"{roles[1]}: ", end="")

            # every conversation is a new conversation
            conv = conv_templates[args.conv_mode].copy()
            inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
            conv.append_message(conv.roles[0], inp)
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()
            input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
            stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
            keywords = [stop_str]
            stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
            streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

            llm_start_time = time.perf_counter()
            with torch.inference_mode():
                output_ids = model.generate(
                    input_ids,
                    images=image_tensor,
                    do_sample=True if args.temperature > 0 else False,
                    temperature=args.temperature,
                    max_new_tokens=args.max_new_tokens,
                    streamer=streamer,
                    use_cache=True,
                    stopping_criteria=[stopping_criteria]
                )
            llm_end_time = time.perf_counter()
            outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
            conv.messages[-1][-1] = outputs
            conv_end_time = time.perf_counter()
            if conv_cnt > 0:
                time_meter.add('conv_latency', conv_end_time - conv_start_time)
                time_meter.add('llm_latency', llm_end_time - llm_start_time)
                time_meter.add('real_sleep', conv_start_time - last_conv_start_time)
                logger.info(f'CliServer: idx={conv_cnt},\treal_sleep={time_meter["real_sleep"]},\tconv_latency={time_meter["conv_latency"]},\tllm_latency={time_meter["llm_latency"]}')
            else:
                logger.info(f'CliServer: idx={conv_cnt},\tconv_latency={conv_end_time - conv_start_time},\tllm_latency={llm_end_time - llm_start_time}')
            conv_cnt += 1
            last_conv_start_time = conv_start_time

        for p in processes:
            p.terminate()
        print("All processes finished.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--image-file", type=str, default=None)
    parser.add_argument("--video-file", type=str, default=None)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--conv-mode", type=str, default="vicuna_v1")
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--max-new-tokens", type=int, default=512)
    parser.add_argument("--load-8bit", action="store_true")
    parser.add_argument("--load-4bit", action="store_true")
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--log-file", type=str, default="tmp_cli.log")
    parser.add_argument("--use_1process", action="store_true")
    parser.add_argument("--video_max_frames", type=int, default=None)
    parser.add_argument("--video_fps", type=float, default=1.0)
    parser.add_argument("--play_speed", type=float, default=1.0)
    args = parser.parse_args()
    main(args)