dia-gov's picture
Upload 300 files
fff4338 verified
import os
from io import BytesIO
import requests
import torchvision.transforms as T
from PIL import Image
from FlowSteeringWorm.llava.conversation import conv_templates
from FlowSteeringWorm.llava.model import *
from transformers import AutoTokenizer
from transformers import CLIPVisionModel, CLIPImageProcessor
transform = T.ToPILImage()
import torch
import numpy as np
import warnings
warnings.filterwarnings("ignore")
torch.manual_seed(42)
from transformers import logging
logging.set_verbosity_error()
SEED = 10
if torch.cuda.is_available():
torch.cuda.manual_seed_all(SEED)
np.random.seed(SEED)
TEMPERATURE = 0.1
MAX_NEW_TOKENS = 1024
CONTEXT_LEN = 2048
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
from utils.encryption import encrypt_data, decrypt_data
class UnNormalize(object):
def __init__(self, mean, std):
self.mean = mean
self.std = std
def __call__(self, tensor):
"""
Args:
tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
Returns:
Tensor: Normalized image.
"""
tensor = tensor.clone()
for t, m, s in zip(tensor, self.mean, self.std):
t.mul_(s).add_(m)
# The normalize code -> t.sub_(m).div_(s)
return tensor
class Normalize(object):
def __init__(self, mean, std):
self.mean = mean
self.std = std
def __call__(self, tensor):
"""
Args:
tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
Returns:
Tensor: Normalized image.
"""
tensor = tensor.clone()
for t, m, s in zip(tensor, self.mean, self.std):
t.sub_(m).div_(s)
return tensor
def load_image(image_file):
if image_file.startswith('http') or image_file.startswith('https'):
response = requests.get(image_file)
image = Image.open(BytesIO(response.content)).convert('RGB')
else:
image = Image.open(image_file).convert('RGB')
return image
def generate_stream(model, prompt, tokenizer, input_ids, images=None):
temperature = TEMPERATURE
max_new_tokens = MAX_NEW_TOKENS
context_len = CONTEXT_LEN
max_src_len = context_len - max_new_tokens - 8
input_ids = input_ids[-max_src_len:]
stop_idx = 2
ori_prompt = prompt
image_args = {"images": images}
output_ids = list(input_ids)
pred_ids = []
max_src_len = context_len - max_new_tokens - 8
input_ids = input_ids[-max_src_len:]
past_key_values = None
for i in range(max_new_tokens):
if i == 0 and past_key_values is None:
out = model(
torch.as_tensor([input_ids]).cuda(),
use_cache=True,
output_hidden_states=True,
**image_args,
)
logits = out.logits
past_key_values = out.past_key_values
else:
attention_mask = torch.ones(
1, past_key_values[0][0].shape[-2] + 1, device="cuda"
)
out = model(
input_ids=torch.as_tensor([[token]], device="cuda"),
use_cache=True,
attention_mask=attention_mask,
past_key_values=past_key_values,
output_hidden_states=True,
)
logits = out.logits
past_key_values = out.past_key_values
# yield out
last_token_logits = logits[0][-1]
if temperature < 1e-4:
token = int(torch.argmax(last_token_logits))
else:
probs = torch.softmax(last_token_logits / temperature, dim=-1)
token = int(torch.multinomial(probs, num_samples=1))
output_ids.append(token)
pred_ids.append(token)
if stop_idx is not None and token == stop_idx:
stopped = True
elif token == tokenizer.eos_token_id:
stopped = True
else:
stopped = False
if i != 0 and i % 1024 == 0 or i == max_new_tokens - 1 or stopped:
cur_out = tokenizer.decode(pred_ids, skip_special_tokens=True)
pos = -1 # cur_out.rfind(stop_str)
if pos != -1:
cur_out = cur_out[:pos]
stopped = True
output = ori_prompt + cur_out
# print('output', output)
ret = {
"text": output,
"error_code": 0,
}
yield cur_out
if stopped:
break
if past_key_values is not None:
del past_key_values
def run_result(X, prompt, initial_query, query_list, model, tokenizer, unnorm, image_processor):
device = 'cuda'
X = load_image(X)
print("Image: ")
# load the image
X = image_processor.preprocess(X, return_tensors='pt')['pixel_values'][0].unsqueeze(0).half().cuda()
# Generate the output with initial query
input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device=device)
res = generate_stream(model, prompt, tokenizer, input_ids[0].tolist(), X)
for response1 in res:
outputs1 = response1
print(f'Query 1:')
print(initial_query)
print(f'Response 1:')
print(outputs1.strip())
print('********')
ALLResponses = []
ALLResponses.append(outputs1.strip())
# Generate the outputs with further queries
for idx, query in enumerate(query_list):
if idx == 0:
# Update current prompt with the initial prompt and first output
new_prompt = prompt + outputs1 + "\n###Human: " + query + "\n###Assistant:"
else:
# Update current prompt with the previous prompt and latest output
new_prompt = (
new_prompt + outputs + "\n###Human: " + query + "\n###Assistant:"
)
input_ids = tokenizer.encode(new_prompt, return_tensors="pt").cuda()
# Generate the response using the updated prompt
res = generate_stream(model, new_prompt, tokenizer, input_ids[0].tolist(), X)
for response in res:
outputs = response
# Print the current query and response
print(f"Query {idx + 2}:")
print(query)
print(f"Response {idx + 2}:")
print(outputs.strip())
print("********")
ALLResponses.append(outputs.strip())
return ALLResponses
def Turn_On_LLaVa(): # Load the LLaVa model
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
torch.cuda.set_device(0)
device = torch.device('cuda')
print('Current Device :', torch.cuda.current_device())
MODEL_NAME = "FlowSteering/llava/llava_weights/" # PATH to the LLaVA weights
model_name = os.path.expanduser(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
dtypePerDevice = torch.float16
model = LlavaLlamaForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, torch_dtype=dtypePerDevice,
use_cache=True)
model.to(device=device, dtype=dtypePerDevice)
image_processor = CLIPImageProcessor.from_pretrained(model.config.mm_vision_tower)
mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
if mm_use_im_start_end:
tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
vision_tower = model.get_model().vision_tower[0]
vision_tower = CLIPVisionModel.from_pretrained(vision_tower.config._name_or_path, torch_dtype=dtypePerDevice,
low_cpu_mem_usage=True)
model.to(device=device, dtype=dtypePerDevice)
model.get_model().vision_tower[0] = vision_tower
vision_tower.to(device=device, dtype=dtypePerDevice)
vision_config = vision_tower.config
vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
vision_config.use_im_start_end = mm_use_im_start_end
if mm_use_im_start_end:
vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids(
[DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
return model, image_processor, tokenizer, device
def load_param(MODEL_NAME, model, tokenizer, initial_query):
model_name = os.path.expanduser(MODEL_NAME)
image_processor = CLIPImageProcessor.from_pretrained(model.config.mm_vision_tower)
mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
if mm_use_im_start_end:
tokenizer.add_tokens(
[DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
)
vision_tower = model.get_model().vision_tower[0]
vision_tower = CLIPVisionModel.from_pretrained(
vision_tower.config._name_or_path,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
).cuda()
model.get_model().vision_tower[0] = vision_tower
if vision_tower.device.type == "meta":
vision_tower = CLIPVisionModel.from_pretrained(
vision_tower.config._name_or_path,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
).cuda()
model.get_model().vision_tower[0] = vision_tower
else:
vision_tower.to(device="cuda", dtype=torch.float16)
vision_config = vision_tower.config
vision_config.im_patch_token = tokenizer.convert_tokens_to_ids(
[DEFAULT_IMAGE_PATCH_TOKEN]
)[0]
vision_config.use_im_start_end = mm_use_im_start_end
if mm_use_im_start_end:
(
vision_config.im_start_token,
vision_config.im_end_token,
) = tokenizer.convert_tokens_to_ids(
[DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]
)
image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2
unnorm = UnNormalize(image_processor.image_mean, image_processor.image_std)
norm = Normalize(image_processor.image_mean, image_processor.image_std)
embeds = model.model.embed_tokens.cuda()
projector = model.model.mm_projector.cuda()
for param in vision_tower.parameters():
param.requires_grad = False
for param in model.parameters():
param.requires_grad = False
for param in projector.parameters():
param.requires_grad = False
for param in embeds.parameters():
param.requires_grad = False
for param in model.model.parameters():
param.requires_grad = False
qs = initial_query
if mm_use_im_start_end:
qs = (
qs
+ "\n"
+ DEFAULT_IM_START_TOKEN
+ DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
+ DEFAULT_IM_END_TOKEN
)
else:
qs = qs + "\n" + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
if "v1" in model_name.lower():
conv_mode = "llava_v1"
elif "mpt" in model_name.lower():
conv_mode = "mpt_multimodal"
else:
conv_mode = "multimodal"
if conv_mode is not None and conv_mode != conv_mode:
print(
"[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
conv_mode, conv_mode, conv_mode
)
)
else:
conv_mode = conv_mode
conv = conv_templates[conv_mode].copy()
conv.append_message(conv.roles[0], qs)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
inputs = tokenizer([prompt])
input_ids = torch.as_tensor(inputs.input_ids).cuda()
return (
tokenizer,
image_processor,
vision_tower,
unnorm,
norm,
embeds,
projector,
prompt,
input_ids,
)
def Run_LLaVa(X, prompt, initial_query, query_list, model, tokenizer, unnorm, image_processor):
reply = run_result(X, prompt, initial_query, query_list, model, tokenizer, unnorm, image_processor)
return reply