|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
|
import json |
|
|
import os |
|
|
import uuid |
|
|
from pathlib import Path |
|
|
|
|
|
import requests |
|
|
import torch |
|
|
import torch.multiprocessing as mp |
|
|
from janus.models import MultiModalityCausalLM, VLChatProcessor, VLMImageProcessor |
|
|
from PIL import Image |
|
|
from tqdm import tqdm |
|
|
|
|
|
from align_anything.utils.device_utils import set_device, torch_gc |
|
|
|
|
|
# Label value that tokenize_sample writes over every prompt position of a
# sample's labels. -100 is the default ``ignore_index`` of
# torch.nn.CrossEntropyLoss, so presumably the training loss skips those
# positions -- confirm against the trainer that consumes the output.
ignore_index = -100
|
|
|
|
|
|
|
|
def safe_torch_save(obj, file_path):
    """Save ``obj`` with ``torch.save``, creating the parent directory first.

    If creating the directory or writing the file fails, the failure is
    printed and the object is written to ``~/torch_cache/<file name>``
    instead.

    Args:
        obj: Object to serialize with ``torch.save``.
        file_path: Destination path (``str`` or ``os.PathLike``).

    Returns:
        The path the object was actually written to, as a string: the
        requested path on success, otherwise the fallback path.

    Raises:
        TypeError: If ``file_path`` is not path-like.
        Exception: Whatever the fallback write raises if it fails as well.
    """
    # Convert before entering the ``try`` so the fallback branch can always
    # rely on ``file_path.name``. An invalid path argument now fails right
    # here instead of surfacing as an unrelated AttributeError raised from
    # inside the error handler.
    file_path = Path(file_path)

    try:
        file_path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(obj, file_path)
        return str(file_path)
    except Exception as e:
        # Broad catch kept on purpose: any failure of the primary write
        # (directory creation or serialization to disk) triggers the fallback.
        print(f"❌ 保存失败: {e}")
        print(f"尝试保存到: {file_path}")

        backup_dir = Path.home() / "torch_cache"
        backup_dir.mkdir(parents=True, exist_ok=True)
        backup_path = backup_dir / file_path.name
        torch.save(obj, backup_path)
        print(f"✅ 已保存到备用位置: {backup_path}")
        return str(backup_path)
|
|
|
|
|
|
|
|
def load_image(image_path: str, timeout: float = 30.0):
    """Load an image from a local path or an HTTP(S) URL and convert it to RGB.

    Args:
        image_path: Filesystem path, or a URL starting with ``http://`` or
            ``https://``.
        timeout: Seconds to wait for the remote server; only used for URLs.

    Returns:
        The decoded image converted to ``'RGB'`` mode.

    Raises:
        Exception: If downloading, opening or decoding fails. The message
            names ``image_path`` and the original error is chained as
            ``__cause__``.
    """
    try:
        # Match the scheme explicitly so local paths that merely begin with
        # "http" (e.g. "httpdocs/a.png") are not treated as URLs.
        if image_path.startswith(('http://', 'https://')):
            from io import BytesIO  # only the URL branch needs it

            # A timeout keeps one dead host from hanging the worker, and
            # raise_for_status reports HTTP errors instead of handing an
            # error page to the image decoder.
            response = requests.get(image_path, timeout=timeout)
            response.raise_for_status()
            source = BytesIO(response.content)
        else:
            source = image_path

        # The context manager releases the underlying file once the RGB
        # copy has been produced.
        with Image.open(source) as image:
            return image.convert('RGB')
    except Exception as e:
        print(f'Error occurred when dealing with {image_path}: {e}')
        # Same exception type as before, but with a message and the cause
        # attached so callers that log ``str(e)`` get something useful.
        raise Exception(f'Failed to load image {image_path}: {e}') from e
|
|
|
|
|
|
|
|
def format_sample_janus(piece, vl_chat_processor):
    """Turn one raw dataset record into the dict consumed by tokenize_sample.

    Args:
        piece: Mapping with the keys ``'prompt'``, ``'source_image'`` and
            ``'image'``.
        vl_chat_processor: Accepted for signature compatibility; not used.

    Returns:
        A dict with ``'input_text'`` (the prompt), ``'source_image'`` (passed
        through unchanged) and ``'output_image'`` (the result of
        ``load_image`` on ``piece['image']``).
    """
    prompt_text = piece['prompt']
    source_ref = piece['source_image']
    target_image = load_image(piece['image'])
    return {
        'input_text': prompt_text,
        'source_image': source_ref,
        'output_image': target_image,
    }
|
|
|
|
|
|
|
|
def tokenize_sample(vl_chat_processor, vl_gpt, vl_image_processor, formatted_sample):
    """Build ``input_ids`` and ``labels`` for one image-generation sample.

    The prompt consists of two image placeholder spans followed by the user
    text, wrapped by the processor's chat template and terminated with an
    image start tag. The target image is encoded with
    ``vl_gpt.gen_vision_model.encode`` and the resulting code indices are
    appended to the prompt ids. ``labels`` is a copy of the full sequence in
    which every prompt position is replaced by ``ignore_index``.

    Args:
        vl_chat_processor: Chat processor providing the image tags,
            ``num_image_tokens``, the tokenizer and the SFT template.
        vl_gpt: Model providing ``device`` and ``gen_vision_model``.
        vl_image_processor: Callable that turns a list of images into a
            mapping containing ``'pixel_values'``.
        formatted_sample: Dict with ``'input_text'``, ``'source_image'`` and
            ``'output_image'``.

    Returns:
        Dict with ``'input_ids'`` and ``'labels'`` (both moved to CPU),
        ``'source_image'`` (passed through) and ``'task'`` set to
        ``'generation'``.
    """
    start_tag = vl_chat_processor.image_start_tag
    end_tag = vl_chat_processor.image_end_tag
    num_image_tokens = vl_chat_processor.num_image_tokens

    # Two placeholder spans precede the text: one made of image tags and one
    # made of pad tags, each ``num_image_tokens`` long.
    # NOTE(review): presumably these are slots for the source image -- confirm
    # against the training-side collator.
    input_img_tokens = (
        start_tag + vl_chat_processor.image_tag * num_image_tokens + end_tag
        + start_tag + vl_chat_processor.pad_tag * num_image_tokens + end_tag
    )
    prompts = input_img_tokens + formatted_sample['input_text']

    conversation = [
        {'role': 'User', 'content': prompts},
        {'role': 'Assistant', 'content': ''},
    ]
    sft_format = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
        conversations=conversation,
        sft_format=vl_chat_processor.sft_format,
        system_prompt='',
    )

    # The trailing image start tag opens the span the image codes are
    # appended to.
    prompt = sft_format + start_tag
    input_ids = vl_chat_processor.tokenizer.encode(prompt)
    input_ids = torch.LongTensor(input_ids).to(vl_gpt.device)

    # Pure inference: without no_grad the encoder would record an autograd
    # graph for every sample even though only the integer indices are kept.
    with torch.no_grad():
        pixel_values = (
            vl_image_processor([formatted_sample['output_image']], return_tensors='pt')['pixel_values']
            .to(vl_gpt.device)
            .to(torch.bfloat16)
        )
        # Only the code indices are needed; the nested unpacking keeps the
        # same structural check on the encoder's return value as before.
        _, (_, _, _), (_, _, min_encoding_indices) = vl_gpt.gen_vision_model.encode(pixel_values)

    # Assumes the encoder returns a 1-D index tensor for the single image so
    # it can be concatenated with the 1-D prompt ids -- TODO confirm.
    full_input_ids = torch.cat([input_ids, min_encoding_indices])
    labels = full_input_ids.clone()
    labels[: len(input_ids)] = ignore_index

    return {
        'input_ids': full_input_ids.to('cpu'),
        'labels': labels.to('cpu'),
        'source_image': formatted_sample['source_image'],
        'task': 'generation',
    }
|
|
|
|
|
|
|
|
def process_data(gpu, chunk, model_path, output_paths, cache_path):
    """Worker entry point: tokenize ``chunk`` on one GPU and cache each sample.

    Every successfully tokenized sample is written to its own ``.pt`` file
    via ``safe_torch_save`` and the file's path is appended to
    ``output_paths``. A failure on a single sample is printed and the sample
    is skipped. A failure outside the per-sample loop (for example while
    loading the model) is printed and re-raised.

    Args:
        gpu: GPU index passed to ``set_device``.
        chunk: Sequence of raw dataset records handled by this worker.
        model_path: Path given to the ``from_pretrained`` loaders.
        output_paths: Shared list that receives the paths of the cached
            sample files.
        cache_path: Directory for the per-sample cache files; falls back to
            ``~/torch_cache`` if it cannot be created.

    Raises:
        Exception: Any error raised outside the per-sample loop, so that the
            worker process ends with a non-zero exit code.
    """
    local_output_paths = []
    try:
        cache_path = os.path.abspath(cache_path)
        print(f'GPU {gpu}: 使用缓存路径: {cache_path}')

        if not os.path.exists(cache_path):
            try:
                os.makedirs(cache_path, exist_ok=True)
                print(f'GPU {gpu}: 创建缓存目录: {cache_path}')
            except Exception as e:
                print(f'GPU {gpu}: 创建缓存目录失败: {e}')
                cache_path = os.path.join(os.path.expanduser("~"), "torch_cache")
                os.makedirs(cache_path, exist_ok=True)
                print(f'GPU {gpu}: 使用备用缓存目录: {cache_path}')

        device = set_device(gpu)
        print(f'Initializing Model on {device}')

        vl_chat_processor = VLChatProcessor.from_pretrained(model_path, device=device)
        vl_gpt = MultiModalityCausalLM.from_pretrained(model_path, trust_remote_code=True).to(device)
        vl_gpt = vl_gpt.to(torch.bfloat16).eval()
        vl_image_processor = VLMImageProcessor.from_pretrained(model_path, device=device)

        print(f'Finished Initializing Model on {device}')

        for i, piece in enumerate(tqdm(chunk, desc=f'Processing on GPU {gpu}')):
            try:
                print(f'GPU {gpu}: Processing sample {i + 1}/{len(chunk)}')
                formatted_sample = format_sample_janus(piece, vl_chat_processor)
                sample = tokenize_sample(vl_chat_processor, vl_gpt, vl_image_processor, formatted_sample)

                # A random name per sample keeps workers sharing one cache
                # directory from overwriting each other's files.
                file_name = f"gpu_{gpu}_{str(uuid.uuid4())}.pt"
                file_path = os.path.join(cache_path, file_name)

                saved_path = safe_torch_save(sample, file_path)
                local_output_paths.append(saved_path)

                del sample
                torch_gc()
            except Exception as e:
                # Best effort per sample: report it and move on.
                print(f'GPU {gpu}: 处理样本 {i} 时出错: {e}')
                continue

        print(f'GPU {gpu}: Processed {len(local_output_paths)} samples successfully')
    except Exception as e:
        print(f'GPU {gpu}: process_data 函数出错: {e}')
        # Re-raise instead of swallowing: the uncaught exception makes the
        # worker exit with a non-zero code (and prints the traceback), which
        # is what the parent's exit-code check relies on.
        raise
    finally:
        # Publish whatever was cached, even if the worker failed part-way.
        if local_output_paths:
            output_paths.extend(local_output_paths)
|
|
|
|
|
|
|
|
def main():
    """Tokenize a JSON dataset in parallel and save the merged result.

    Command-line arguments:
        --input_path: JSON file whose root is a list of records.
        --output_path: Destination of the merged ``torch.save`` file.
        --model_path: Model directory handed to the workers.
        --cache_dir: Directory for per-sample cache files (default ``.cache``).
        --num_processes: Maximum number of worker processes (default 16).
        --num_gpus: Number of GPUs the workers are spread over (default 8).

    The records are dealt round-robin to the workers, worker ``k`` runs on
    GPU ``k % num_gpus``, each worker caches one file per sample, and the
    parent loads those files, saves them as one list and deletes the cache
    files.

    Raises:
        FileNotFoundError: If ``--input_path`` does not exist.
        TypeError: If the JSON root is not a list.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_path', type=str, required=True)
    parser.add_argument('--output_path', type=str, required=True)
    parser.add_argument('--model_path', type=str, required=True)
    parser.add_argument('--cache_dir', type=str, default='.cache')
    parser.add_argument('--num_processes', type=int, default=16)
    parser.add_argument('--num_gpus', type=int, default=8)

    args = parser.parse_args()

    # Reject counts that would otherwise end in ZeroDivisionError
    # (worker_id % 0) or in a run that silently does nothing.
    if args.num_processes < 1 or args.num_gpus < 1:
        parser.error('--num_processes and --num_gpus must be at least 1')

    input_path = args.input_path
    output_path = args.output_path
    model_path = args.model_path
    cache_path = os.path.abspath(args.cache_dir)
    num_processes = args.num_processes
    num_gpus = args.num_gpus

    print(f"输入路径: {input_path}")
    print(f"输出路径: {output_path}")
    print(f"模型路径: {model_path}")
    print(f"缓存路径: {cache_path}")
    print(f"进程数: {num_processes}")
    print(f"GPU数: {num_gpus}")

    # Cache directory, with a fallback under the home directory.
    try:
        if not os.path.exists(cache_path):
            os.makedirs(cache_path, exist_ok=True)
            print(f"✅ 创建缓存目录: {cache_path}")
        else:
            print(f"✅ 缓存目录已存在: {cache_path}")
    except Exception as e:
        print(f"❌ 创建缓存目录失败: {e}")
        cache_path = os.path.join(os.path.expanduser("~"), "torch_cache")
        os.makedirs(cache_path, exist_ok=True)
        print(f"✅ 使用备用缓存目录: {cache_path}")

    output_dir = os.path.dirname(os.path.abspath(output_path))
    if not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)
        print(f"✅ 创建输出目录: {output_dir}")

    if not os.path.exists(input_path):
        raise FileNotFoundError(f"输入文件不存在: {input_path}")

    # Explicit encoding: prompts may contain non-ASCII text and the platform
    # default encoding is not guaranteed to be UTF-8.
    with open(input_path, encoding='utf-8') as f:
        input_data = json.load(f)

    if not isinstance(input_data, list):
        raise TypeError(
            f'Expected a JSON list of samples in {input_path}, got {type(input_data).__name__}'
        )

    try:
        mp.set_start_method('spawn', force=True)
    except RuntimeError:
        # Kept from the original: carry on with whatever start method is set.
        pass

    target = input_data
    print(f'Full Length: {len(target)}')

    if len(target) == 0:
        print("❌ 输入数据为空")
        return

    # Never start more workers than there are samples: each worker loads a
    # full model, so a worker with an empty chunk is pure waste.
    num_workers = min(num_processes, len(target))
    chunks = [target[i::num_workers] for i in range(num_workers)]
    print(f"数据分块: {[len(chunk) for chunk in chunks]}")

    # The context manager shuts the manager process down once the paths have
    # been copied out of the shared list.
    with mp.Manager() as manager:
        shared_paths = manager.list()

        processes = []
        for worker_id, chunk in enumerate(chunks):
            gpu = worker_id % num_gpus
            print(f"启动进程 {worker_id}, 使用GPU {gpu}, 处理 {len(chunk)} 个样本")

            p = mp.Process(
                target=process_data,
                args=(gpu, chunk, model_path, shared_paths, cache_path),
            )
            p.start()
            processes.append(p)

        for i, p in enumerate(processes):
            print(f"等待进程 {i} 完成...")
            p.join()
            if p.exitcode != 0:
                print(f"⚠️ 进程 {i} 退出码: {p.exitcode}")

        output_paths = list(shared_paths)

    print(f"收集到 {len(output_paths)} 个输出文件")

    if len(output_paths) == 0:
        print("❌ 没有成功处理的样本")
        return

    all_data = []
    failed_loads = 0
    for path in tqdm(output_paths, desc="加载处理后的数据"):
        try:
            # SECURITY: weights_only=False unpickles arbitrary objects. This
            # is acceptable only because these files were written by this
            # run's own workers; never point it at untrusted files.
            data = torch.load(path, weights_only=False)
            all_data.append(data)
        except Exception as e:
            print(f"❌ 加载文件失败 {path}: {e}")
            failed_loads += 1

    if failed_loads > 0:
        print(f"⚠️ {failed_loads} 个文件加载失败")

    print(f'Effective Length: {len(all_data)}')

    if len(all_data) == 0:
        print("❌ 没有有效数据可保存")
        return

    try:
        torch.save(all_data, output_path)
        print(f"✅ 成功保存到: {output_path}")
    except Exception as e:
        print(f"❌ 保存最终结果失败: {e}")
        # If this write fails too, the exception propagates before the
        # cleanup below, so the per-sample cache files are kept.
        backup_path = os.path.join(os.path.dirname(output_path), f"backup_{os.path.basename(output_path)}")
        torch.save(all_data, backup_path)
        print(f"✅ 已保存到备用位置: {backup_path}")

    print("清理临时文件...")
    for path in output_paths:
        try:
            if os.path.exists(path):
                os.remove(path)
        except Exception as e:
            print(f"清理文件失败 {path}: {e}")
|
|
|
|
|
|
|
|
# Entry-point guard. main() selects the 'spawn' start method, and spawned
# worker processes import this module again, so the pipeline must only be
# started when the file is executed as a script.
if __name__ == '__main__':
    main()