| |
| |
| from accelerate.utils import gather_object |
|
|
# Compatibility shim: some downstream model code looks up
# `transformers.activations.PytorchGELUTanh`. When that name is absent but
# `GELUTanh` exists, expose the latter under the legacy name.
try:
    from transformers import activations

    # Only add the alias when it is missing, so an implementation that the
    # installed transformers version already ships is never overwritten.
    if not hasattr(activations, "PytorchGELUTanh"):
        activations.PytorchGELUTanh = activations.GELUTanh
except (ImportError, AttributeError):
    # AttributeError covers transformers builds that provide neither name;
    # without it the script would crash here instead of printing the note.
    print("Note: Unable to apply PytorchGELUTanh patch. If you encounter an ImportError, please check the transformers version.")
| |
|
|
| import os |
| import shutil |
| import json |
| import re |
| import numpy as np |
| from PIL import Image |
| from tqdm import tqdm |
| from datasets import load_dataset |
| from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor |
| from accelerate import Accelerator |
|
|
| |
try:
    from qwen_vl_utils import process_vision_info
except ImportError:
    print("Warning: Failed to import 'qwen_vl_utils.process_vision_info'.")

    def process_vision_info(messages):
        """Minimal stand-in for ``qwen_vl_utils.process_vision_info``.

        Collects the ``image`` entry of every ``{"type": "image", ...}`` part
        found in messages whose role is ``"user"``.

        Args:
            messages: Iterable of chat message dicts with ``role`` and
                ``content`` keys. ``content`` may be a list of part dicts or
                a plain string; strings carry no images and are skipped.

        Returns:
            A ``(images, None)`` tuple. ``images`` holds the collected values
            in message order; the second slot (videos) is always ``None``.
        """
        images = []
        for msg in messages:
            if msg.get('role') != 'user':
                continue
            parts = msg.get('content')
            # Text-only messages use a bare string as content; iterating it
            # character by character would raise on the dict lookups below.
            if not isinstance(parts, (list, tuple)):
                continue
            for part in parts:
                if (
                    isinstance(part, dict)
                    and part.get('type') == 'image'
                    and 'image' in part
                ):
                    images.append(part['image'])
        return images, None
|
|
# Remove Pillow's pixel-count ceiling so very large images can be opened.
Image.MAX_IMAGE_PIXELS = None


try:
    from rex_omni import RexOmniWrapper
except ImportError:
    print("Warning: 'from rex_omni import RexOmniWrapper' failed.")
    print("Using a dummy RexOmniWrapper (DummyRex) for testing only.")

    class DummyRex:
        """Placeholder detector used when ``rex_omni`` cannot be imported.

        It accepts any constructor arguments and, for every input image,
        reports a single fake box covering the central half of the image.
        """

        def __init__(self, *args, **kwargs):
            print("INFO: DUMMY: Using DummyRex detector.")

        def inference(self, images, task, categories, **kwargs):
            """Return one fake centered box per image.

            ``task``, ``categories`` and extra keyword arguments are accepted
            but ignored. A non-list ``images`` value is treated as a single
            image. Inputs that are not PIL images are assumed to be 800x600.
            """
            print("INFO: DUMMY: DummyRex returning fake center boxes.")

            def _center_box(img):
                # Fall back to a fixed size when the input has no PIL size.
                width, height = img.size if isinstance(img, Image.Image) else (800, 600)
                coords = [width * 0.25, height * 0.25, width * 0.75, height * 0.75]
                return {
                    "extracted_predictions": {
                        "anything": [{"type": "box", "coords": coords}]
                    }
                }

            batch = images if isinstance(images, list) else [images]
            return [_center_box(img) for img in batch]

    RexOmniWrapper = DummyRex
|
|
|
|
| def _strip_tags(text, tag_name): |
| |
| if not isinstance(text, str): |
| text = str(text) |
| text = re.sub(rf'<{tag_name}>', '', text, flags=re.IGNORECASE) |
| text = re.sub(rf'</{tag_name}>', '', text, flags=re.IGNORECASE) |
| return text.strip() |
|
|
|
|
| |
|
|
| def _crop_and_expand_box(image, box, padding_pixels=20): |
| |
| x0, y0, x1, y1 = [int(c) for c in box] |
| img_w, img_h = image.size |
| x0_new = max(0, x0 - padding_pixels) |
| y0_new = max(0, y0 - padding_pixels) |
| x1_new = min(img_w, x1 + padding_pixels) |
| y1_new = min(img_h, y1 + padding_pixels) |
| return image.crop((x0_new, y0_new, x1_new, y1_new)) |
|
|
|
|
| |
| def _query_qwen_vl_BATCH(crop_images_list, model, processor, accelerator): |
| """ |
| Use Qwen-VL to query cropped image patches in batch and return a list of JSON strings. |
| """ |
| if not crop_images_list: |
| return [] |
|
|
| prompt = """This is an object from a CLEVR scene. Analyze the primary object in the image. |
| Respond *strictly* with a JSON list (containing one dictionary) in the following format: |
| [ |
| {"object": "object_name", "attributes": ["attr1", "attr2"]} |
| ] |
| - "object": The shape of the object (e.g., "sphere", "cube", "cylinder", "cone"). |
| - "attributes": A list of visual attributes (e.g., ["blue", "large", "metal", "shiny", "rubber"]). |
| Provide only the JSON list:""" |
|
|
| |
| template_messages = [ |
| { |
| "role": "user", |
| "content": [ |
| {"type": "image", "image": "placeholder.jpg"}, |
| {"type": "text", "text": prompt}, |
| ], |
| } |
| ] |
|
|
| try: |
| |
| chat_prompt_text = processor.apply_chat_template( |
| template_messages, tokenize=False, add_generation_prompt=True |
| ) |
|
|
| num_crops = len(crop_images_list) |
| batch_text = [chat_prompt_text] * num_crops |
| batch_images = crop_images_list |
|
|
| unwrapped_model = accelerator.unwrap_model(model) |
|
|
| |
| inputs = processor( |
| text=batch_text, |
| images=batch_images, |
| padding=True, |
| return_tensors="pt", |
| ).to(unwrapped_model.device) |
|
|
| |
| generated_ids = unwrapped_model.generate(**inputs, max_new_tokens=256, do_sample=False) |
|
|
| generated_ids_trimmed = [ |
| out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) |
| ] |
|
|
| |
| output_texts_list = processor.batch_decode( |
| generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False |
| ) |
|
|
| return output_texts_list |
|
|
| except Exception as e: |
| print(f"Qwen-VL batched inference failed: {e}") |
| |
| return ["[]"] * len(crop_images_list) |
|
|
|
|
| def _parse_qwen_json(response_text): |
| |
| try: |
| match = re.search(r'```json\s*(\[.*\])\s*```', response_text, re.DOTALL) |
| if match: |
| json_str = match.group(1) |
| return json.loads(json_str) |
| match = re.search(r'(\[.*\])', response_text, re.DOTALL) |
| if match: |
| json_str = match.group(1) |
| return json.loads(json_str) |
| return [] |
| except json.JSONDecodeError: |
| print(f"Failed to parse JSON: {response_text}") |
| return [] |
| except Exception as e: |
| print(f"Unknown error occurred while parsing JSON: {e}") |
| return [] |
|
|
|
|
def _load_and_preprocess_data(base_output_dir, image_output_dir):
    """Download the dataset, export its images, and write the job list.

    For every sample of ``MMInstruction/Clevr_CoGenT_TrainA_R1`` (train
    split) the image is saved as a JPEG under *image_output_dir* (existing
    files are left untouched) and a record with the prompt, the answer, the
    hint and the image path is collected. ``<think>``/``<answer>`` markers
    are stripped from the hint and the answer.

    Args:
        base_output_dir: Directory that receives ``job_list.json``.
        image_output_dir: Directory that receives the exported images.

    Returns:
        The list of job records that was written to ``job_list.json``, or
        ``[]`` (without writing a file) when the dataset cannot be loaded.
    """
    print("Loading 'MMInstruction/Clevr_CoGenT_TrainA_R1'...")
    try:
        dataset = load_dataset("MMInstruction/Clevr_CoGenT_TrainA_R1", split='train')
    except Exception as e:
        print(f"Failed to load dataset 'MMInstruction/Clevr_CoGenT_TrainA_R1': {e}")
        return []

    print(f"Loaded {len(dataset)} samples.")

    records = []
    print("Preprocessing data (saving images and parsing text)...")
    for idx, sample in enumerate(tqdm(dataset, desc="Preprocessing progress")):
        question = sample['problem']
        reasoning = _strip_tags(sample['thinking'], 'think')
        solution = _strip_tags(sample['solution'], 'answer')

        picture = sample['image']
        if not isinstance(picture, Image.Image):
            print(f"Warning: sample {idx} is not a PIL image, skipped.")
            continue

        target_path = os.path.join(
            image_output_dir, f"clevr_cogent_trainA_r1_{idx:07d}.jpg"
        )

        try:
            os.makedirs(os.path.dirname(target_path), exist_ok=True)
            # Skip re-encoding when the file is already on disk.
            if not os.path.exists(target_path):
                picture.convert("RGB").save(target_path, "JPEG")
        except Exception as e:
            print(f"Warning: failed to save image for sample {idx}, skipped. Error: {e}")
            continue

        records.append({
            "prompt": question,
            "answer": solution,
            "hint": reasoning,
            "destination_image_path": target_path,
        })

    print(f"Successfully preprocessed {len(records)} items.")

    job_list_path = os.path.join(base_output_dir, "job_list.json")
    with open(job_list_path, 'w', encoding='utf-8') as handle:
        json.dump(records, handle)

    print(f"Job list saved to: {job_list_path}")
    return records
|
|
|
|
def main():
    """Run the detection + description pipeline on every Accelerate process.

    Steps:
      1. The main process exports the dataset images and writes
         ``job_list.json``; all processes then wait at a barrier.
      2. Every process reads the job list and takes a contiguous slice
         (the last process also takes the remainder).
      3. Each process loads the detector and Qwen-VL, detects boxes in
         batches of images, crops every box, asks Qwen-VL to describe each
         crop and records the parsed result with its bounding box.
      4. Results are gathered and the main process writes the final JSON.

    Failures while loading an image, running a batch, or processing a single
    image are printed and skipped so the rest of the run can continue.
    """
    import traceback

    accelerator = Accelerator()

    MODEL_CONFIGS = {
        "rex_path": "IDEA-Research/Rex-Omni",
        "qwen_path": "Qwen/Qwen2.5-VL-32B-Instruct-AWQ"
    }
    OUTPUT_DIR = "/path/to/data/clevr_cogent_output"
    IMAGE_OUTPUT_DIR = os.path.join(OUTPUT_DIR, "images")
    JSON_OUTPUT_DIR = os.path.join(OUTPUT_DIR, "json")

    os.makedirs(IMAGE_OUTPUT_DIR, exist_ok=True)
    os.makedirs(JSON_OUTPUT_DIR, exist_ok=True)

    # --- Step 1: preprocessing on the main process only -------------------
    job_list_path = os.path.join(OUTPUT_DIR, "job_list.json")
    if accelerator.is_main_process:
        print("Main process [Pre-processing]: loading and preprocessing data...")
        _load_and_preprocess_data(OUTPUT_DIR, IMAGE_OUTPUT_DIR)

    # Other processes must not read job_list.json before it is written.
    accelerator.wait_for_everyone()

    # --- Step 2: every process loads the job list and takes its slice -----
    if not accelerator.is_main_process:
        print(f"Process {accelerator.process_index}: loading job_list.json...")
    try:
        with open(job_list_path, 'r', encoding='utf-8') as f:
            all_jobs = json.load(f)
    except Exception as e:
        print(f"Process {accelerator.process_index} failed to load job_list.json: {e}")
        return
    total_jobs = len(all_jobs)
    num_processes = accelerator.num_processes
    jobs_per_process = total_jobs // num_processes
    start_index = accelerator.process_index * jobs_per_process
    end_index = (accelerator.process_index + 1) * jobs_per_process
    if accelerator.is_last_process:
        # The last process also handles the remainder of the division.
        end_index = total_jobs
    my_jobs = all_jobs[start_index:end_index]
    print(f"[Process {accelerator.process_index}]:"
          f" assigned {len(my_jobs)} jobs (indices from {start_index} to {end_index}).")

    # --- Step 3: load models ----------------------------------------------
    # NOTE(review): returning early below while other processes continue to
    # the gather step may leave them waiting — confirm against Accelerate's
    # collective semantics.
    try:
        try:
            from transformers import activations
            # Add the legacy alias only when it is missing. AttributeError is
            # caught too, otherwise a transformers build without GELUTanh
            # would be reported as a model-loading failure and abort the run.
            if not hasattr(activations, "PytorchGELUTanh"):
                activations.PytorchGELUTanh = activations.GELUTanh
        except (ImportError, AttributeError):
            pass

        print(f"[Process {accelerator.process_index}]: loading RexOmni...")
        rex_model = RexOmniWrapper(
            model_path=MODEL_CONFIGS['rex_path'],
            backend="transformers",
            max_tokens=2048,
            temperature=0.0,
        )

        print(f"[Process {accelerator.process_index}]: loading Qwen-VL...")
        qwen_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            MODEL_CONFIGS['qwen_path'],
            torch_dtype="float16",
            device_map="cuda",
            attn_implementation="flash_attention_2"
        )
        qwen_processor = AutoProcessor.from_pretrained(MODEL_CONFIGS['qwen_path'])

        qwen_model, rex_model = accelerator.prepare(qwen_model, rex_model)

        print(f"[Process {accelerator.process_index}]: models loaded.")
    except Exception as e:
        print(f"[Process {accelerator.process_index}]: failed to load models: {e}")
        traceback.print_exc()
        return

    # --- Step 4: batched processing ---------------------------------------
    REX_BATCH_SIZE = 16

    print(f"[Process {accelerator.process_index}]:"
          f" starting to process {len(my_jobs)} jobs with Rex batch size {REX_BATCH_SIZE}.")

    processed_metadata_list = []

    for i in tqdm(range(0, len(my_jobs), REX_BATCH_SIZE),
                  desc=f"Worker {accelerator.process_index} batch progress",
                  disable=not accelerator.is_main_process):

        batch_jobs = my_jobs[i: i + REX_BATCH_SIZE]
        batch_images = []
        batch_image_paths = []
        valid_jobs_in_batch = []

        for job in batch_jobs:
            # Bound before the try so the error message can always print it.
            img_path = job.get('destination_image_path')
            try:
                # The context manager closes the file handle; convert()
                # returns an independent in-memory image.
                with Image.open(img_path) as raw_image:
                    rgb_image = raw_image.convert("RGB")
                batch_image_paths.append(img_path)
                batch_images.append(rgb_image)
                valid_jobs_in_batch.append(job)
            except Exception as e:
                print(f"[Process {accelerator.process_index}]:"
                      f" failed to load image {img_path}: {e}, skipping this image in this batch.")

        if not batch_images:
            continue

        try:
            all_rex_results = rex_model.inference(
                images=batch_images,
                task="detection",
                categories=["anything"]
            )
        except Exception as e:
            print(f"[Process {accelerator.process_index}]: "
                  f"Error while processing batch {i // REX_BATCH_SIZE} (images {batch_image_paths}): {e}")
            traceback.print_exc()
            continue

        # Results are paired with inputs by position, so the counts must match.
        if len(all_rex_results) != len(valid_jobs_in_batch):
            print(f"[Process {accelerator.process_index}]: Warning: RexOmni "
                  f"returned {len(all_rex_results)} results, but "
                  f"{len(valid_jobs_in_batch)} inputs were provided. Skipping this batch.")
            continue

        for job, image, rex_result in zip(valid_jobs_in_batch, batch_images, all_rex_results):
            # Failures are isolated per image so one bad result does not
            # discard the remaining images of the batch.
            try:
                predictions = rex_result["extracted_predictions"]
                detected_boxes = predictions.get("anything", [])

                visual_facts = []
                crops_to_process = []
                box_coords_list = []

                # Collect one crop per well-formed box annotation.
                for annotation in detected_boxes:
                    if annotation.get("type") == "box" and len(annotation.get("coords", [])) == 4:
                        coords = annotation["coords"]
                        crops_to_process.append(_crop_and_expand_box(image, coords))
                        box_coords_list.append(coords)

                if crops_to_process:
                    # All crops of this image are described in one Qwen-VL call.
                    json_str_list = _query_qwen_vl_BATCH(
                        crops_to_process, qwen_model, qwen_processor, accelerator
                    )

                    for json_str, coords in zip(json_str_list, box_coords_list):
                        json_obj_list = _parse_qwen_json(json_str)
                        if json_obj_list:
                            try:
                                obj_dict = json_obj_list[0]
                                obj_dict["bounding_box"] = [round(c, 2) for c in coords]
                                visual_facts.append(obj_dict)
                            except (IndexError, TypeError, KeyError) as e:
                                print(f"[Process {accelerator.process_index}]: "
                                      f"Error while parsing batched result: {e} | JSON: {json_str}")

                processed_metadata_list.append({
                    "question": job['prompt'],
                    "answer": job['answer'],
                    "question_wo_prompt": job['prompt'],
                    "hint": job['hint'],
                    "image": job['destination_image_path'],
                    "visual_fact": visual_facts
                })
            except Exception as e:
                print(f"[Process {accelerator.process_index}]: "
                      f"Error while processing image {job.get('destination_image_path')} "
                      f"in batch {i // REX_BATCH_SIZE}: {e}")
                traceback.print_exc()

    print(f"[Process {accelerator.process_index}]:"
          f" process finished, handled {len(processed_metadata_list)} items.")

    # --- Step 5: gather and save ------------------------------------------
    print(f"[Process {accelerator.process_index}]: gathering results...")
    all_results_list_of_lists = gather_object(processed_metadata_list)

    if accelerator.is_main_process:
        print("Main process [Saving]: aggregating and saving all results...")

        # The gathered value is saved as-is; the "Type of first item" print
        # below shows whether entries are records or nested lists.
        final_metadata_list = all_results_list_of_lists

        json_filename = os.path.join(JSON_OUTPUT_DIR, "clevr_cogent_trainA_r1_processed.json")

        print(f"Total number of aggregated items: {len(final_metadata_list)}")
        if len(final_metadata_list) > 0:
            print(f"Type of first item: {type(final_metadata_list[0])}")

        print(f"\nSaving {len(final_metadata_list)} metadata entries to {json_filename}...")
        with open(json_filename, 'w', encoding='utf-8') as f:
            json.dump(final_metadata_list, f, indent=4, ensure_ascii=False)

        print(f"\n--- Processing completed! ---")
        print(f"All image files have been saved in: '{IMAGE_OUTPUT_DIR}'")
        print(f"Final JSON file has been saved in: '{JSON_OUTPUT_DIR}'")


if __name__ == "__main__":
    main()
|
|