| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| import argparse |
| import json |
| import os |
| import uuid |
| from pathlib import Path |
|
|
| import requests |
| import torch |
| import torch.multiprocessing as mp |
| from janus.models import MultiModalityCausalLM, VLChatProcessor, VLMImageProcessor |
| from PIL import Image |
| from tqdm import tqdm |
|
|
| from align_anything.utils.device_utils import set_device, torch_gc |
|
|
# Label value used to mask prompt positions out of the loss (see
# tokenize_sample, which assigns it to every prompt-token label).
# -100 matches the conventional ignore_index of cross-entropy losses.
ignore_index = -100
|
|
|
|
def safe_torch_save(obj, file_path):
    """Save a torch object to *file_path*, creating parent dirs as needed.

    If the primary save fails for any reason, fall back to
    ``~/torch_cache/<filename>``. Returns the path actually written,
    as a string.
    """
    target = Path(file_path)
    try:
        # Make sure the destination directory exists before writing.
        target.parent.mkdir(parents=True, exist_ok=True)
        torch.save(obj, target)
        return str(target)
    except Exception as e:
        print(f"❌ 保存失败: {e}")
        print(f"尝试保存到: {target}")
        # Best-effort fallback: a cache directory under the user's home.
        fallback_dir = Path.home() / "torch_cache"
        fallback_dir.mkdir(parents=True, exist_ok=True)
        fallback_path = fallback_dir / target.name
        torch.save(obj, fallback_path)
        print(f"✅ 已保存到备用位置: {fallback_path}")
        return str(fallback_path)
|
|
|
|
def load_image(image_path: str):
    """Load an RGB PIL image from a local path or an http(s) URL.

    Args:
        image_path: Local filesystem path, or a URL starting with 'http'.

    Returns:
        A PIL Image converted to RGB.

    Raises:
        Whatever exception PIL/requests raised — re-raised unchanged so
        callers can see the real failure cause.
    """
    try:
        if image_path.startswith('http'):
            # Stream the response body straight into PIL.
            image = Image.open(requests.get(image_path, stream=True).raw).convert('RGB')
        else:
            image = Image.open(image_path).convert('RGB')
        return image
    except Exception as e:
        print(f'Error occurred when dealing with {image_path}: {e}')
        # Bug fix: `raise Exception` raised a bare, message-less Exception,
        # discarding the original error. Re-raise the caught exception instead.
        raise
|
|
|
|
def format_sample_janus(piece, vl_chat_processor):
    """Convert one raw dataset record into the intermediate sample format.

    Note: *vl_chat_processor* is accepted for interface compatibility
    with the caller but is not used here.
    """
    return {
        'input_text': piece['prompt'],
        'source_image': piece['source_image'],
        'output_image': load_image(piece['image']),
    }
|
|
|
|
def tokenize_sample(vl_chat_processor, vl_gpt, vl_image_processor, formatted_sample):
    """Build (input_ids, labels) for one image-generation SFT sample.

    The text prompt is prefixed with placeholder image-token spans,
    rendered through the chat SFT template, and tokenized; the target
    image is encoded into discrete codes by the VQ vision encoder and
    appended. Label positions covering the prompt are masked with
    ``ignore_index`` so the loss covers only the generated image codes.

    Returns:
        dict with CPU tensors 'input_ids' and 'labels', the pass-through
        'source_image', and a 'task' tag of 'generation'.
    """
    # Cleanup: removed leftover debug prints (including a placeholder-less
    # f-string) and the unused locals `output_img_tokens` / `xpp`.
    # Placeholder span for the conditioning input image: image tokens
    # followed by a padded slot, each bracketed by start/end tags.
    input_img_tokens = (
        vl_chat_processor.image_start_tag
        + vl_chat_processor.image_tag * vl_chat_processor.num_image_tokens
        + vl_chat_processor.image_end_tag
        + vl_chat_processor.image_start_tag
        + vl_chat_processor.pad_tag * vl_chat_processor.num_image_tokens
        + vl_chat_processor.image_end_tag
    )
    prompts = input_img_tokens + formatted_sample['input_text']

    conversation = [
        {'role': 'User', 'content': prompts},
        {'role': 'Assistant', 'content': ''},
    ]
    sft_format = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
        conversations=conversation,
        sft_format=vl_chat_processor.sft_format,
        system_prompt='',
    )

    # Open an image span at the end; the generated image codes follow it.
    prompt = sft_format + vl_chat_processor.image_start_tag
    input_ids = vl_chat_processor.tokenizer.encode(prompt)
    input_ids = torch.LongTensor(input_ids).to(vl_gpt.device)

    pixel_values = (
        vl_image_processor([formatted_sample['output_image']], return_tensors='pt')['pixel_values']
        .to(vl_gpt.device)
        .to(torch.bfloat16)
    )
    # VQ-encode the target image; only the code indices are used here.
    (
        quant,
        (vq_loss, commit_loss, entropy_loss),
        (perplexity, min_encodings, min_encoding_indices),
    ) = vl_gpt.gen_vision_model.encode(pixel_values)
    full_input_ids = torch.cat([input_ids, min_encoding_indices])
    labels = full_input_ids.clone()
    # Supervise only the image-code positions; mask all prompt tokens.
    labels[: len(input_ids)] = ignore_index

    return {
        'input_ids': full_input_ids.to('cpu'),
        'labels': labels.to('cpu'),
        'source_image': formatted_sample['source_image'],
        'task': 'generation',
    }
|
|
|
|
def process_data(gpu: int, chunk, model_path: str, output_paths, cache_path: str):
    """Worker entry point: tokenize `chunk` on GPU `gpu`, caching each sample to disk.

    Loads the Janus model/processors onto the given device, converts and
    tokenizes every record in `chunk`, saves each result as an individual
    .pt file under `cache_path`, and appends the saved paths to the shared
    `output_paths` list (a multiprocessing Manager list — see main()).
    Per-sample failures are logged and skipped; a failure anywhere else is
    logged with a traceback and the worker exits quietly.
    """
    try:
        cache_path = os.path.abspath(cache_path)
        print(f'GPU {gpu}: 使用缓存路径: {cache_path}')

        # Ensure the cache directory exists; fall back to ~/torch_cache on failure.
        if not os.path.exists(cache_path):
            try:
                os.makedirs(cache_path, exist_ok=True)
                print(f'GPU {gpu}: 创建缓存目录: {cache_path}')
            except Exception as e:
                print(f'GPU {gpu}: 创建缓存目录失败: {e}')
                cache_path = os.path.join(os.path.expanduser("~"), "torch_cache")
                os.makedirs(cache_path, exist_ok=True)
                print(f'GPU {gpu}: 使用备用缓存目录: {cache_path}')

        device = set_device(gpu)
        print(f'Initializing Model on {device}')

        # Model is cast to bfloat16 and put in eval mode for inference-only encoding.
        vl_chat_processor = VLChatProcessor.from_pretrained(model_path, device=device)
        vl_gpt = MultiModalityCausalLM.from_pretrained(model_path, trust_remote_code=True).to(device)
        vl_gpt = vl_gpt.to(torch.bfloat16).eval()
        vl_image_processor = VLMImageProcessor.from_pretrained(model_path, device=device)

        print(f'Finished Initializing Model on {device}')

        local_output_paths = []
        for i, piece in enumerate(tqdm(chunk, desc=f'Processing on GPU {gpu}')):
            try:
                print(f'GPU {gpu}: Processing sample {i + 1}/{len(chunk)}')
                formatted_sample = format_sample_janus(piece, vl_chat_processor)
                sample = tokenize_sample(vl_chat_processor, vl_gpt, vl_image_processor, formatted_sample)

                # uuid4 filename avoids collisions across workers sharing cache_path.
                file_name = f"gpu_{gpu}_{str(uuid.uuid4())}.pt"
                file_path = os.path.join(cache_path, file_name)

                saved_path = safe_torch_save(sample, file_path)
                local_output_paths.append(saved_path)

                # Free the sample and reclaim device memory between iterations.
                del sample
                torch_gc()

            except Exception as e:
                # Per-sample failures are non-fatal: log and move on.
                print(f'GPU {gpu}: 处理样本 {i} 时出错: {e}')
                continue

        # Publish results to the shared list in one call at the end.
        output_paths.extend(local_output_paths)
        print(f'GPU {gpu}: Processed {len(local_output_paths)} samples successfully')

    except Exception as e:
        print(f'GPU {gpu}: process_data 函数出错: {e}')
        import traceback
        traceback.print_exc()
|
|
|
|
def main():
    """Parse CLI args, fan samples out to GPU worker processes, and merge results.

    Pipeline: load the input JSON list, split it round-robin into
    `num_processes` chunks, run `process_data` in spawned subprocesses
    (each pinned to a GPU via index % num_gpus), then gather the cached
    per-sample .pt files into one output file and delete the cache files.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_path', type=str, required=True)
    parser.add_argument('--output_path', type=str, required=True)
    parser.add_argument('--model_path', type=str, required=True)
    parser.add_argument('--cache_dir', type=str, default='.cache')
    parser.add_argument('--num_processes', type=int, default=16)
    parser.add_argument('--num_gpus', type=int, default=8)

    args = parser.parse_args()

    input_path = args.input_path
    output_path = args.output_path
    model_path = args.model_path
    cache_path = os.path.abspath(args.cache_dir)

    print(f"输入路径: {input_path}")
    print(f"输出路径: {output_path}")
    print(f"模型路径: {model_path}")
    print(f"缓存路径: {cache_path}")
    print(f"进程数: {args.num_processes}")
    print(f"GPU数: {args.num_gpus}")

    # Ensure the cache directory exists; fall back to ~/torch_cache on failure.
    try:
        if not os.path.exists(cache_path):
            os.makedirs(cache_path, exist_ok=True)
            print(f"✅ 创建缓存目录: {cache_path}")
        else:
            print(f"✅ 缓存目录已存在: {cache_path}")
    except Exception as e:
        print(f"❌ 创建缓存目录失败: {e}")
        cache_path = os.path.join(os.path.expanduser("~"), "torch_cache")
        os.makedirs(cache_path, exist_ok=True)
        print(f"✅ 使用备用缓存目录: {cache_path}")

    # Make sure the final output's parent directory exists.
    output_dir = os.path.dirname(os.path.abspath(output_path))
    if not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)
        print(f"✅ 创建输出目录: {output_dir}")

    if not os.path.exists(input_path):
        raise FileNotFoundError(f"输入文件不存在: {input_path}")

    with open(input_path) as f:
        input_data = json.load(f)

    num_processes = args.num_processes
    num_gpus = args.num_gpus

    # CUDA requires the 'spawn' start method; tolerate it being set already.
    try:
        mp.set_start_method('spawn', force=True)
    except RuntimeError:
        pass

    # Manager list is shared across processes to collect saved file paths.
    output_paths = mp.Manager().list()

    target = input_data
    print(f'Full Length: {len(target)}')

    if len(target) == 0:
        print("❌ 输入数据为空")
        return

    # Round-robin split so chunk sizes differ by at most one sample.
    chunks = [target[i::num_processes] for i in range(num_processes)]
    print(f"数据分块: {[len(chunk) for chunk in chunks]}")

    processes = []
    # Fix: loop variable was named `id`, shadowing the builtin.
    for proc_idx in range(num_processes):
        gpu = proc_idx % num_gpus
        print(f"启动进程 {proc_idx}, 使用GPU {gpu}, 处理 {len(chunks[proc_idx])} 个样本")

        p = mp.Process(
            target=process_data,
            args=(gpu, chunks[proc_idx], model_path, output_paths, cache_path)
        )
        p.start()
        processes.append(p)

    # Wait for every worker, surfacing non-zero exit codes.
    for i, p in enumerate(processes):
        print(f"等待进程 {i} 完成...")
        p.join()
        if p.exitcode != 0:
            print(f"⚠️ 进程 {i} 退出码: {p.exitcode}")

    output_paths = list(output_paths)
    print(f"收集到 {len(output_paths)} 个输出文件")

    if len(output_paths) == 0:
        print("❌ 没有成功处理的样本")
        return

    # Reload every per-sample .pt file, tolerating individual bad files.
    all_data = []
    failed_loads = 0
    for path in tqdm(output_paths, desc="加载处理后的数据"):
        try:
            data = torch.load(path, weights_only=False)
            all_data.append(data)
        except Exception as e:
            print(f"❌ 加载文件失败 {path}: {e}")
            failed_loads += 1

    if failed_loads > 0:
        print(f"⚠️ {failed_loads} 个文件加载失败")

    torch.set_printoptions(threshold=torch.inf)
    print(f'Effective Length: {len(all_data)}')

    if len(all_data) == 0:
        print("❌ 没有有效数据可保存")
        return

    # Save the merged dataset; fall back to a sibling backup file on failure.
    try:
        torch.save(all_data, output_path)
        print(f"✅ 成功保存到: {output_path}")
    except Exception as e:
        print(f"❌ 保存最终结果失败: {e}")
        backup_path = os.path.join(os.path.dirname(output_path), f"backup_{os.path.basename(output_path)}")
        torch.save(all_data, backup_path)
        print(f"✅ 已保存到备用位置: {backup_path}")

    # Remove the per-sample cache files now that they are merged.
    print("清理临时文件...")
    for path in output_paths:
        try:
            if os.path.exists(path):
                os.remove(path)
        except Exception as e:
            print(f"清理文件失败 {path}: {e}")
|
|
|
# Entry-point guard: required so 'spawn'-started worker processes can
# re-import this module without re-running main().
if __name__ == '__main__':
    main()