import os from io import BytesIO import requests import torchvision.transforms as T from PIL import Image from FlowSteeringWorm.llava.conversation import conv_templates from FlowSteeringWorm.llava.model import * from transformers import AutoTokenizer from transformers import CLIPVisionModel, CLIPImageProcessor transform = T.ToPILImage() import torch import numpy as np import warnings warnings.filterwarnings("ignore") torch.manual_seed(42) from transformers import logging logging.set_verbosity_error() SEED = 10 if torch.cuda.is_available(): torch.cuda.manual_seed_all(SEED) np.random.seed(SEED) TEMPERATURE = 0.1 MAX_NEW_TOKENS = 1024 CONTEXT_LEN = 2048 DEFAULT_IMAGE_TOKEN = "" DEFAULT_IMAGE_PATCH_TOKEN = "" DEFAULT_IM_START_TOKEN = "" DEFAULT_IM_END_TOKEN = "" from utils.encryption import encrypt_data, decrypt_data class UnNormalize(object): def __init__(self, mean, std): self.mean = mean self.std = std def __call__(self, tensor): """ Args: tensor (Tensor): Tensor image of size (C, H, W) to be normalized. Returns: Tensor: Normalized image. """ tensor = tensor.clone() for t, m, s in zip(tensor, self.mean, self.std): t.mul_(s).add_(m) # The normalize code -> t.sub_(m).div_(s) return tensor class Normalize(object): def __init__(self, mean, std): self.mean = mean self.std = std def __call__(self, tensor): """ Args: tensor (Tensor): Tensor image of size (C, H, W) to be normalized. Returns: Tensor: Normalized image. """ tensor = tensor.clone() for t, m, s in zip(tensor, self.mean, self.std): t.sub_(m).div_(s) return tensor def load_image(image_file): if image_file.startswith('http') or image_file.startswith('https'): response = requests.get(image_file) image = Image.open(BytesIO(response.content)).convert('RGB') else: image = Image.open(image_file).convert('RGB') return image def generate_stream(model, prompt, tokenizer, input_ids, images=None): temperature = TEMPERATURE max_new_tokens = MAX_NEW_TOKENS context_len = CONTEXT_LEN max_src_len = context_len - max_new_tokens - 8 input_ids = input_ids[-max_src_len:] stop_idx = 2 ori_prompt = prompt image_args = {"images": images} output_ids = list(input_ids) pred_ids = [] max_src_len = context_len - max_new_tokens - 8 input_ids = input_ids[-max_src_len:] past_key_values = None for i in range(max_new_tokens): if i == 0 and past_key_values is None: out = model( torch.as_tensor([input_ids]).cuda(), use_cache=True, output_hidden_states=True, **image_args, ) logits = out.logits past_key_values = out.past_key_values else: attention_mask = torch.ones( 1, past_key_values[0][0].shape[-2] + 1, device="cuda" ) out = model( input_ids=torch.as_tensor([[token]], device="cuda"), use_cache=True, attention_mask=attention_mask, past_key_values=past_key_values, output_hidden_states=True, ) logits = out.logits past_key_values = out.past_key_values # yield out last_token_logits = logits[0][-1] if temperature < 1e-4: token = int(torch.argmax(last_token_logits)) else: probs = torch.softmax(last_token_logits / temperature, dim=-1) token = int(torch.multinomial(probs, num_samples=1)) output_ids.append(token) pred_ids.append(token) if stop_idx is not None and token == stop_idx: stopped = True elif token == tokenizer.eos_token_id: stopped = True else: stopped = False if i != 0 and i % 1024 == 0 or i == max_new_tokens - 1 or stopped: cur_out = tokenizer.decode(pred_ids, skip_special_tokens=True) pos = -1 # cur_out.rfind(stop_str) if pos != -1: cur_out = cur_out[:pos] stopped = True output = ori_prompt + cur_out # print('output', output) ret = { "text": output, "error_code": 0, } yield cur_out if stopped: break if past_key_values is not None: del past_key_values def run_result(X, prompt, initial_query, query_list, model, tokenizer, unnorm, image_processor): device = 'cuda' X = load_image(X) print("Image: ") # load the image X = image_processor.preprocess(X, return_tensors='pt')['pixel_values'][0].unsqueeze(0).half().cuda() # Generate the output with initial query input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device=device) res = generate_stream(model, prompt, tokenizer, input_ids[0].tolist(), X) for response1 in res: outputs1 = response1 print(f'Query 1:') print(initial_query) print(f'Response 1:') print(outputs1.strip()) print('********') ALLResponses = [] ALLResponses.append(outputs1.strip()) # Generate the outputs with further queries for idx, query in enumerate(query_list): if idx == 0: # Update current prompt with the initial prompt and first output new_prompt = prompt + outputs1 + "\n###Human: " + query + "\n###Assistant:" else: # Update current prompt with the previous prompt and latest output new_prompt = ( new_prompt + outputs + "\n###Human: " + query + "\n###Assistant:" ) input_ids = tokenizer.encode(new_prompt, return_tensors="pt").cuda() # Generate the response using the updated prompt res = generate_stream(model, new_prompt, tokenizer, input_ids[0].tolist(), X) for response in res: outputs = response # Print the current query and response print(f"Query {idx + 2}:") print(query) print(f"Response {idx + 2}:") print(outputs.strip()) print("********") ALLResponses.append(outputs.strip()) return ALLResponses def Turn_On_LLaVa(): # Load the LLaVa model DEFAULT_IMAGE_TOKEN = "" DEFAULT_IMAGE_PATCH_TOKEN = "" DEFAULT_IM_START_TOKEN = "" DEFAULT_IM_END_TOKEN = "" torch.cuda.set_device(0) device = torch.device('cuda') print('Current Device :', torch.cuda.current_device()) MODEL_NAME = "FlowSteering/llava/llava_weights/" # PATH to the LLaVA weights model_name = os.path.expanduser(MODEL_NAME) tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) dtypePerDevice = torch.float16 model = LlavaLlamaForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, torch_dtype=dtypePerDevice, use_cache=True) model.to(device=device, dtype=dtypePerDevice) image_processor = CLIPImageProcessor.from_pretrained(model.config.mm_vision_tower) mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) if mm_use_im_start_end: tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) vision_tower = model.get_model().vision_tower[0] vision_tower = CLIPVisionModel.from_pretrained(vision_tower.config._name_or_path, torch_dtype=dtypePerDevice, low_cpu_mem_usage=True) model.to(device=device, dtype=dtypePerDevice) model.get_model().vision_tower[0] = vision_tower vision_tower.to(device=device, dtype=dtypePerDevice) vision_config = vision_tower.config vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0] vision_config.use_im_start_end = mm_use_im_start_end if mm_use_im_start_end: vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids( [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]) return model, image_processor, tokenizer, device def load_param(MODEL_NAME, model, tokenizer, initial_query): model_name = os.path.expanduser(MODEL_NAME) image_processor = CLIPImageProcessor.from_pretrained(model.config.mm_vision_tower) mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) if mm_use_im_start_end: tokenizer.add_tokens( [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True ) vision_tower = model.get_model().vision_tower[0] vision_tower = CLIPVisionModel.from_pretrained( vision_tower.config._name_or_path, torch_dtype=torch.float16, low_cpu_mem_usage=True, ).cuda() model.get_model().vision_tower[0] = vision_tower if vision_tower.device.type == "meta": vision_tower = CLIPVisionModel.from_pretrained( vision_tower.config._name_or_path, torch_dtype=torch.float16, low_cpu_mem_usage=True, ).cuda() model.get_model().vision_tower[0] = vision_tower else: vision_tower.to(device="cuda", dtype=torch.float16) vision_config = vision_tower.config vision_config.im_patch_token = tokenizer.convert_tokens_to_ids( [DEFAULT_IMAGE_PATCH_TOKEN] )[0] vision_config.use_im_start_end = mm_use_im_start_end if mm_use_im_start_end: ( vision_config.im_start_token, vision_config.im_end_token, ) = tokenizer.convert_tokens_to_ids( [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN] ) image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2 unnorm = UnNormalize(image_processor.image_mean, image_processor.image_std) norm = Normalize(image_processor.image_mean, image_processor.image_std) embeds = model.model.embed_tokens.cuda() projector = model.model.mm_projector.cuda() for param in vision_tower.parameters(): param.requires_grad = False for param in model.parameters(): param.requires_grad = False for param in projector.parameters(): param.requires_grad = False for param in embeds.parameters(): param.requires_grad = False for param in model.model.parameters(): param.requires_grad = False qs = initial_query if mm_use_im_start_end: qs = ( qs + "\n" + DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN ) else: qs = qs + "\n" + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len if "v1" in model_name.lower(): conv_mode = "llava_v1" elif "mpt" in model_name.lower(): conv_mode = "mpt_multimodal" else: conv_mode = "multimodal" if conv_mode is not None and conv_mode != conv_mode: print( "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format( conv_mode, conv_mode, conv_mode ) ) else: conv_mode = conv_mode conv = conv_templates[conv_mode].copy() conv.append_message(conv.roles[0], qs) conv.append_message(conv.roles[1], None) prompt = conv.get_prompt() inputs = tokenizer([prompt]) input_ids = torch.as_tensor(inputs.input_ids).cuda() return ( tokenizer, image_processor, vision_tower, unnorm, norm, embeds, projector, prompt, input_ids, ) def Run_LLaVa(X, prompt, initial_query, query_list, model, tokenizer, unnorm, image_processor): reply = run_result(X, prompt, initial_query, query_list, model, tokenizer, unnorm, image_processor) return reply