| """ |
| srun -p INTERN2 --job-name='husky_multi_test' --gres=gpu:1 --cpus-per-task=8 --quotatype="auto" python -u demo/inference_new.py |
| """ |
|
|
| import abc |
| from typing import Optional |
|
|
| import os |
| import requests |
| from PIL import Image |
| from io import BytesIO |
|
|
| import torch |
| import torchvision.transforms as T |
| from peft import PeftModel |
| from torchvision.transforms.functional import InterpolationMode |
|
|
| from transformers import ( |
| LlamaTokenizer, |
| GenerationConfig, |
| StoppingCriteria, |
| StoppingCriteriaList, |
| ) |
|
|
| from robohusky.model.modeling_husky_embody2 import HuskyForConditionalGeneration |
|
|
| from robohusky.conversation import ( |
| conv_templates, |
| get_conv_template, |
| ) |
|
|
| from robohusky.video_transformers import ( |
| GroupNormalize, |
| GroupScale, |
| GroupCenterCrop, |
| Stack, |
| ToTorchFormatTensor, |
| get_index, |
| ) |
|
|
| from robohusky.compression import compress_module |
| from decord import VideoReader, cpu |
|
|
| |
|
|
# Sentinel label value for positions excluded from loss computation
# (-100 matches torch.nn.CrossEntropyLoss's default ignore_index).
IGNORE_INDEX = -100
DEFAULT_UNK_TOKEN = "<unk>"
# Special tokens that delimit where image features are spliced into the prompt.
DEFAULT_IMG_START_TOKEN = "<img>"
DEFAULT_IMG_END_TOKEN = "</img>"


# Special tokens that delimit where video features are spliced into the prompt.
DEFAULT_VIDEO_START_TOKEN = "<vid>"
DEFAULT_VIDEO_END_TOKEN = "</vid>"
|
|
def get_gpu_memory(max_gpus=None):
    """Return the currently free memory, in GiB, of each visible CUDA device.

    Args:
        max_gpus: Optional cap on how many devices to inspect; ``None``
            inspects every device reported by ``torch.cuda.device_count()``.

    Returns:
        A list of floats (one per inspected GPU): total device memory minus
        memory currently allocated by this process.
    """
    if max_gpus is None:
        num_gpus = torch.cuda.device_count()
    else:
        num_gpus = min(max_gpus, torch.cuda.device_count())

    free_per_gpu = []
    for gpu_id in range(num_gpus):
        # Switch the current device so memory_allocated() reports this GPU.
        with torch.cuda.device(gpu_id):
            props = torch.cuda.get_device_properties(torch.cuda.current_device())
            total_gib = props.total_memory / (1024 ** 3)
            allocated_gib = torch.cuda.memory_allocated() / (1024 ** 3)
            free_per_gpu.append(total_gib - allocated_gib)
    return free_per_gpu
|
|
def load_model(
    model_path, device, num_gpus, max_gpu_memory=None, load_8bit=False, lora_weights=None
):
    """Load the Husky model and its tokenizer onto the requested device(s).

    Args:
        model_path: Checkpoint directory (also used for the tokenizer).
        device: "cpu" or "cuda"; any other value raises ValueError.
        num_gpus: Integer GPU count, or the string "auto" to let HF shard.
        max_gpu_memory: Optional per-GPU memory cap (e.g. "13GiB") used for
            multi-GPU runs; when None, 85% of each GPU's free memory is used.
        load_8bit: If True, compress the model's weights after loading.
        lora_weights: Optional LoRA adapter path applied to the language model.

    Returns:
        (model, tokenizer) with the model switched to eval mode.

    Raises:
        ValueError: If ``device`` is neither "cpu" nor "cuda".
    """
    if device == "cpu":
        kwargs = {}
    elif device == "cuda":
        kwargs = {"torch_dtype": torch.float16}
        if num_gpus == "auto":
            kwargs["device_map"] = "auto"
        else:
            num_gpus = int(num_gpus)
            if num_gpus != 1:
                # Multi-GPU: shard the model and bound per-device memory.
                if max_gpu_memory is None:
                    kwargs["device_map"] = "sequential"
                    free_mem = get_gpu_memory(num_gpus)
                    # Leave ~15% headroom per device for activations.
                    kwargs["max_memory"] = {
                        gpu: str(int(free_mem[gpu] * 0.85)) + "GiB"
                        for gpu in range(num_gpus)
                    }
                else:
                    kwargs["device_map"] = "auto"
                    kwargs["max_memory"] = {
                        gpu: max_gpu_memory for gpu in range(num_gpus)
                    }
    else:
        raise ValueError(f"Invalid device: {device}")

    tokenizer = LlamaTokenizer.from_pretrained(model_path, use_fast=False)

    if lora_weights is None:
        model = HuskyForConditionalGeneration.from_pretrained(
            model_path, low_cpu_mem_usage=True, **kwargs
        )
    else:
        # With LoRA adapters, always let HF place submodules automatically.
        kwargs["device_map"] = "auto"
        model = HuskyForConditionalGeneration.from_pretrained(
            model_path, low_cpu_mem_usage=True, **kwargs
        )
        model.language_model = PeftModel.from_pretrained(
            model.language_model, lora_weights, **kwargs
        )

    if load_8bit:
        compress_module(model, device)

    # Single-GPU runs get an explicit .to(); sharded runs were already placed
    # via device_map. NOTE(review): the "mps" comparison is unreachable here
    # because only "cpu"/"cuda" survive the validation above.
    if (device == "cuda" and num_gpus == 1) or device == "mps":
        model.to(device)

    return model.eval(), tokenizer
|
|
def load_image(image_file, input_size=224):
    """Load an image from a local path or an http(s) URL and preprocess it.

    Applies the standard ImageNet evaluation transform: resize so the crop
    covers 87.5% of the short side, center-crop to ``input_size``, convert to
    a tensor, and normalize with ImageNet statistics.

    Args:
        image_file: Local filesystem path or http(s) URL of the image.
        input_size: Side length in pixels of the final square crop.

    Returns:
        A float tensor of shape (3, input_size, input_size).
    """
    # Fix: match explicit URL schemes only (the old `startswith('http')`
    # also matched local files named e.g. "httpcat.jpg", and the extra
    # `startswith('https')` check was redundant).
    if image_file.startswith(("http://", "https://")):
        # Fix: bounded timeout so a dead server cannot hang inference forever.
        response = requests.get(image_file, timeout=30)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')

    crop_pct = 224 / 256  # standard 0.875 evaluation crop ratio
    size = int(input_size / crop_pct)
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize(size, interpolation=InterpolationMode.BICUBIC),
        T.CenterCrop(input_size),
        T.ToTensor(),
        T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
    ])
    return transform(image)
|
|
def load_video(video_path, num_segments=8):
    """Sample evenly spaced frames from a video and preprocess them.

    Frames are scaled and center-cropped to 224x224, stacked, converted to a
    tensor, and normalized with CLIP statistics.

    Args:
        video_path: Path to a video file readable by decord.
        num_segments: Number of frames to sample across the clip.

    Returns:
        A stacked, normalized frame tensor — consumed as (T*C, H, W) by
        ``Chat.get_video_embedding``'s reshape (assumption based on that
        caller; confirm against the Stack/ToTorchFormatTensor transforms).
    """
    reader = VideoReader(video_path, ctx=cpu(0))
    frame_indices = get_index(len(reader), num_segments)

    crop_size = 224
    scale_size = 224
    # CLIP image-normalization statistics.
    input_mean = [0.48145466, 0.4578275, 0.40821073]
    input_std = [0.26862954, 0.26130258, 0.27577711]

    preprocess = T.Compose([
        GroupScale(int(scale_size), interpolation=InterpolationMode.BICUBIC),
        GroupCenterCrop(crop_size),
        Stack(),
        ToTorchFormatTensor(),
        GroupNormalize(input_mean, input_std),
    ])

    frames = [Image.fromarray(reader[idx].asnumpy()) for idx in frame_indices]
    return preprocess(frames)
|
|
class StoppingCriteriaSub(StoppingCriteria):
    """Halts generation once the output ends with any configured stop sequence."""

    def __init__(self, stops, encounters=1):
        """
        Args:
            stops: Iterable of 1-D LongTensors, one tokenized stop phrase each.
            encounters: Unused; accepted only for call-site compatibility.
        """
        super().__init__()
        self.stops = stops

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
        """Return True when the first sequence in the batch ends with a stop phrase."""
        generated = input_ids[0]
        for stop_ids in self.stops:
            width = len(stop_ids)
            # A stop phrase longer than the whole sequence cannot match yet.
            if input_ids.shape[-1] < width:
                continue
            stop_ids = stop_ids.to(input_ids.device)
            if torch.all(stop_ids == generated[-width:]).item():
                return True
        return False
|
|
|
|
|
|
@torch.inference_mode()
def generate_stream(
    model, tokenizer, image_processor, params, device
):
    """Run one generation pass for a text, image, or video prompt.

    Args:
        model: Husky model exposing ``.generate`` and ``.language_model``.
        tokenizer: Tokenizer used for both encoding and decoding.
        image_processor: Unused; kept for interface compatibility.
        params: Dict with "prompt" plus optional "images"/"videos" (paths),
            "temperature" (default 0.7), "max_new_tokens" (default 1024).
        device: Device to place the input tensors on.

    Returns:
        List of decoded output strings (batch of one).
    """
    prompt = params["prompt"]
    images = params.get("images", None)
    videos = params.get("videos", None)
    temperature = float(params.get("temperature", 0.7))
    max_new_tokens = int(params.get("max_new_tokens", 1024))

    # Stop as soon as the model starts a new dialogue turn.
    stop_words = ["Human: ", "Assistant: ", "###", "\n\n"]
    stop_words_ids = [
        tokenizer(stop_word, return_tensors='pt')['input_ids'].squeeze()
        for stop_word in stop_words
    ]
    stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

    # Fix: stopping_criteria used to be passed into GenerationConfig, where
    # it is ignored; it must be an argument of generate() itself.
    generation_config = GenerationConfig(
        bos_token_id=1,
        do_sample=True,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
    )

    pixel_values = None
    if images is not None:
        pixel_values = load_image(images).to(device)
        image_query = DEFAULT_IMG_START_TOKEN + DEFAULT_IMG_END_TOKEN
        prompt = prompt.replace("<image>", image_query)
    elif videos is not None:
        pixel_values = load_video(videos).to(device)
        video_query = DEFAULT_VIDEO_START_TOKEN + DEFAULT_VIDEO_END_TOKEN
        prompt = prompt.replace("<video>", video_query)

    model_inputs = tokenizer([prompt], return_tensors="pt")
    model_inputs.pop("token_type_ids", None)  # not accepted by the model
    # Fix: move token tensors to `device` too — previously only pixel_values
    # was moved, which mismatches the model's device on CUDA.
    model_inputs = {k: v.to(device) for k, v in model_inputs.items()}

    if pixel_values is not None:
        model_inputs["pixel_values"] = pixel_values
        generation_output = model.generate(
            **model_inputs,
            generation_config=generation_config,
            stopping_criteria=stopping_criteria,
            return_dict_in_generate=True,
            output_scores=True
        )
    else:
        # Text-only prompts bypass the vision tower entirely.
        generation_output = model.language_model.generate(
            **model_inputs,
            generation_config=generation_config,
            stopping_criteria=stopping_criteria,
            return_dict_in_generate=True,
            output_scores=True
        )

    preds = generation_output.sequences
    outputs = tokenizer.batch_decode(preds, skip_special_tokens=True)
    return outputs
|
|
class Chat:
    """Interactive multimodal chat wrapper around the Husky model.

    Holds the model, tokenizer, conversation template and generation settings,
    and provides helpers to embed images/videos and to run one dialogue turn.
    """

    def __init__(
        self,
        model_path,
        device,
        num_gpus=1,
        load_8bit=False,
        temperature=0.3,
        max_new_tokens=512,
        lora_path=None,
    ):
        """Load model/tokenizer and prepare generation defaults.

        Args:
            model_path: HF-style checkpoint directory.
            device: "cuda" or "cpu" (forwarded to load_model).
            num_gpus: GPU count or "auto".
            load_8bit: Whether to compress model weights after loading.
            temperature: Sampling temperature for generation.
            max_new_tokens: Cap on generated tokens per turn.
            lora_path: Optional LoRA adapter weights for the language model.
        """
        model, tokenizer = load_model(
            model_path, device, num_gpus, load_8bit=load_8bit, lora_weights=lora_path
        )

        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.dtype = model.dtype

        # Stop generation as soon as the model begins a new dialogue turn.
        stop_words = ["Human: ", "Assistant: ", "###", "\n\n"]
        stop_words_ids = [
            tokenizer(stop_word, return_tensors='pt')['input_ids'].squeeze()
            for stop_word in stop_words
        ]
        self.stopping_criteria = StoppingCriteriaList(
            [StoppingCriteriaSub(stops=stop_words_ids)]
        )

        self.conv = get_conv_template("husky")

        # Placeholders marking where image/video features go in the prompt.
        self.image_query = DEFAULT_IMG_START_TOKEN + DEFAULT_IMG_END_TOKEN
        self.video_query = DEFAULT_VIDEO_START_TOKEN + DEFAULT_VIDEO_END_TOKEN

        self.generation_config = GenerationConfig(
            bos_token_id=1,
            do_sample=True,
            top_k=20,
            top_p=0.9,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
        )

    def ask(self, text, conv, modal_type="image"):
        """Append a user turn to `conv` and return the rendered prompt list.

        The media placeholder is prepended only on the first turn of a
        conversation; subsequent turns are plain text.

        Args:
            text: The user's message.
            conv: Conversation object accumulating the dialogue.
            modal_type: One of "text", "image", "video".

        Returns:
            A single-element list containing the full prompt string.
        """
        assert modal_type in ["text", "image", "video"]
        conversations = []

        if len(conv.messages) > 0 or modal_type == "text":
            conv.append_message(conv.roles[0], text)
        elif modal_type == "image":
            conv.append_message(conv.roles[0], self.image_query + "\n" + text)
        else:
            conv.append_message(conv.roles[0], self.video_query + "\n" + text)

        # Open the assistant's slot so get_prompt() ends at the reply cue.
        conv.append_message(conv.roles[1], None)
        conversations.append(conv.get_prompt())
        return conversations

    @torch.no_grad()
    def get_image_embedding(self, image_file):
        """Encode one image file into language-model input features."""
        pixel_values = load_image(image_file)
        pixel_values = pixel_values.unsqueeze(0).to(self.device, dtype=self.dtype)
        return self.model.extract_feature(pixel_values)

    @torch.no_grad()
    def get_video_embedding(self, video_file):
        """Encode a video file into language-model input features.

        load_video returns frames stacked along the channel axis (T*C, H, W);
        reshape to (1, C, T, H, W) before feature extraction.
        """
        pixel_values = load_video(video_file)
        TC, H, W = pixel_values.shape
        pixel_values = pixel_values.reshape(TC // 3, 3, H, W).transpose(0, 1)
        pixel_values = pixel_values.unsqueeze(0).to(self.device, dtype=self.dtype)
        assert len(pixel_values.shape) == 5
        return self.model.extract_feature(pixel_values)

    @torch.no_grad()
    def answer(self, conversations, language_model_inputs, modal_type="image"):
        """Generate the assistant reply for the given prompt(s).

        Args:
            conversations: List with one rendered prompt string (from ask()).
            language_model_inputs: Precomputed vision features, or None for text.
            modal_type: One of "text", "image", "video".

        Returns:
            The decoded reply string (prompt echo stripped for text mode).
        """
        model_inputs = self.tokenizer(
            conversations,
            return_tensors="pt",
        )
        model_inputs.pop("token_type_ids", None)  # not accepted by the model

        input_ids = model_inputs["input_ids"].to(self.device)
        attention_mask = model_inputs["attention_mask"].to(self.device)

        if modal_type == "text":
            generation_output = self.model.language_model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                generation_config=self.generation_config,
                stopping_criteria=self.stopping_criteria,
                return_dict_in_generate=True,
                output_scores=True
            )
        else:
            # Vision features arrive precomputed via get_*_embedding, so no
            # raw pixels are passed. (The old pop of "pixel_values" from the
            # tokenizer output was dead code: tokenizers never produce it.)
            generation_output = self.model.generate(
                pixel_values=None,
                input_ids=input_ids,
                attention_mask=attention_mask,
                language_model_inputs=language_model_inputs,
                generation_config=self.generation_config,
                stopping_criteria=self.stopping_criteria,
                return_dict_in_generate=True,
                output_scores=True
            )

        preds = generation_output.sequences
        outputs = self.tokenizer.batch_decode(preds, skip_special_tokens=True)[0]

        if modal_type == "text":
            # Strip the echoed prompt; each "</s>" in the prompt shrinks by
            # ~3 characters when special tokens are skipped during decoding.
            skip_echo_len = len(conversations[0]) - conversations[0].count("</s>") * 3
            outputs = outputs[skip_echo_len:].strip()

        return outputs
|
|
if __name__ == '__main__':
    # Interactive REPL: type a media file path to switch modality, plain text
    # to chat, "clear" to reset the dialogue, "stop" to exit.
    model_path = "./"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    chat = Chat(model_path, device=device, num_gpus=1, max_new_tokens=1024, load_8bit=False)

    vision_feature = None
    image_state = False
    video_state = False

    image_exts = ('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')
    video_exts = ('.mp4', '.mkv', '.avi', '.wmv', '.iso', ".webm")

    while True:
        query = input("\n")
        if query.lower().endswith(image_exts):
            if os.path.exists(query):
                print("received.")
                vision_feature = chat.get_image_embedding(query)
                chat.conv = get_conv_template("husky").copy()
                image_state = True
                # Fix: the modality flags must be mutually exclusive —
                # otherwise a later video load still answered in image mode.
                video_state = False
            continue
        if query.lower().endswith(video_exts):
            if os.path.exists(query):
                print("received.")
                vision_feature = chat.get_video_embedding(query)
                chat.conv = get_conv_template("husky").copy()
                video_state = True
                image_state = False  # see mutual-exclusion fix above
            continue

        if query == "stop":
            break
        if query == "clear" or query == "" or query == "\n":
            chat.conv = get_conv_template("husky").copy()
            image_state = False
            video_state = False
            # Fix: also drop the stale vision features with the history.
            vision_feature = None
            os.system("clear")
            print("欢迎使用 husky-13b-zh 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
            continue

        if image_state:
            modal_type = "image"
        elif video_state:
            modal_type = "video"
        else:
            modal_type = "text"

        conversations = chat.ask(text=query, conv=chat.conv, modal_type=modal_type)
        outputs = chat.answer(conversations, vision_feature, modal_type=modal_type)
        # Record the assistant turn so later turns keep the full context.
        chat.conv.messages[-1][1] = outputs.strip()

        print(f"Husky: \n{outputs}")
|