|
|
import json |
|
|
import os |
|
|
from PIL import Image |
|
|
from tqdm import tqdm |
|
|
import matplotlib.pyplot as plt |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
def load_data(json_path):
    """Read and parse a JSON file.

    Args:
        json_path: Path to the JSON file.

    Returns:
        The decoded JSON document. For the SFT files this script is run on it is
        presumably a list of record dicts (``main`` iterates it that way).
    """
    # JSON text is UTF-8 by specification (RFC 8259). Don't depend on the
    # platform's locale encoding, which fails on non-ASCII captions/prompts.
    with open(json_path, "r", encoding="utf-8") as f:
        return json.load(f)
|
|
|
|
|
|
|
|
def filter_data(data):
    """Select the records that reference an image.

    Args:
        data: Iterable of records; a record is kept when ``"image" in record``.

    Returns:
        A new list holding the matching records (same objects, original order).
        Text-only records are dropped.
    """
    kept = []
    for record in data:
        if "image" not in record:
            continue
        kept.append(record)
    return kept
|
|
|
|
|
|
|
|
from multiprocessing import Pool |
|
|
import functools |
|
|
|
|
|
|
|
|
def calculate_image_dimension(item, images_folder):
    """Return the ``(width, height)`` of the image referenced by *item*.

    Args:
        item: Record whose ``"image"`` value is a path relative to *images_folder*.
            A missing ``"image"`` key raises ``KeyError`` (the lookup is outside
            the error guard).
        images_folder: Directory the image path is joined onto.

    Returns:
        ``(width, height)`` as reported by ``Image.size``. Returns ``(None, None)``
        if opening/reading the image raises anything. The error is printed rather
        than raised, so one bad file does not abort a whole pool run.
    """
    image_path = os.path.join(images_folder, item["image"])
    try:
        with Image.open(image_path) as picture:
            width, height = picture.size
    except Exception as e:
        print(f"Error opening {image_path}: {e}")
        return None, None
    return width, height
|
|
|
|
|
|
|
|
def calculate_image_dimensions_multiprocess(filtered_data, images_folder, num_processes=256):
    """Measure every referenced image in parallel.

    Args:
        filtered_data: List of records that each carry an ``"image"`` key.
        images_folder: Directory the ``"image"`` paths are relative to.
        num_processes: Upper bound on worker processes. The pool is never larger
            than the number of records, so small inputs don't spawn idle workers.
            ``None`` keeps ``multiprocessing.Pool``'s own default.

    Returns:
        ``(widths, heights)``: two parallel lists in input order. Images for which
        ``calculate_image_dimension`` returned ``(None, None)`` are omitted. Both
        lists are empty when there is nothing to measure or every image failed.
    """
    if not filtered_data:
        return [], []

    pool_size = num_processes
    if pool_size is not None:
        # Cap at the number of items; values < 1 still reach Pool and raise there.
        pool_size = min(pool_size, len(filtered_data))

    worker = functools.partial(calculate_image_dimension, images_folder=images_folder)
    with Pool(pool_size) as pool:
        dimensions = list(
            tqdm(
                pool.imap(worker, filtered_data),
                total=len(filtered_data),
                desc="Calculating image dimensions",
            )
        )

    valid = [dim for dim in dimensions if dim[0] is not None]
    if not valid:
        # zip(*[]) yields nothing, so unpacking it into two names would raise ValueError.
        return [], []
    widths, heights = zip(*valid)
    return list(widths), list(heights)
|
|
|
|
|
|
|
|
def tokenize(text):
    """Split *text* into whitespace-delimited tokens.

    This is a crude stand-in for a model tokenizer. Runs of any whitespace count
    as a single separator, and leading/trailing whitespace produces no empty tokens.
    """
    tokens = text.split(None)
    return tokens
|
|
|
|
|
|
|
|
def calculate_tokenized_lengths(data):
    """Count tokens in every conversation turn of every record.

    Args:
        data: Records, each with a ``"conversations"`` list whose entries have a
            ``"value"`` string.

    Returns:
        A flat list with one ``len(tokenize(...))`` per turn, in record order and
        then turn order. A progress bar is shown while iterating the records.
    """
    progress = tqdm(data, desc="Tokenizing conversations")
    return [
        len(tokenize(turn["value"]))
        for record in progress
        for turn in record["conversations"]
    ]
|
|
|
|
|
|
|
|
import argparse |
|
|
|
|
|
|
|
|
def _histogram_edges(values, step=100):
    """Return bin edges starting at ``min(values)`` whose last edge is >= ``max(values)``."""
    low, high = min(values), max(values)
    if low == high:
        # np.arange would produce a single edge (zero bins); use one unit-wide bin.
        return [low, high + 1]
    return np.arange(low, high + step, step)


def _plot_image_dimensions(fig, ax, widths, heights, name):
    """Draw a density-normalised 2-D histogram of image width vs. height on *ax*."""
    _, _, _, mesh = ax.hist2d(
        widths,
        heights,
        bins=[_histogram_edges(widths), _histogram_edges(heights)],
        cmap="jet",
        density=True,
    )
    fig.colorbar(mesh, ax=ax)
    ax.set_xlabel("Width")
    ax.set_ylabel("Height")
    ax.set_title(f"dist_{name}_2d_w_h\nMax width: {max(widths)}, Max height: {max(heights)}", fontsize=10)


def _plot_tokenized_lengths(ax, lengths, name, bin_width=100, tick_every=8):
    """Draw a log-scale bar histogram of per-turn token counts on *ax*."""
    # Enough bins that the last edge lies strictly above max(lengths). Previously the
    # stop was max+10, so any lengths beyond the final multiple of the bin width were
    # dropped, and short datasets got a single edge (zero bins).
    n_bins = max(lengths) // bin_width + 1
    counts, edges = np.histogram(lengths, bins=np.arange(n_bins + 1) * bin_width)
    # Each bar starts at its bin's left edge and fills 70% of the bin. The former
    # fixed width of 7 dated from 10-wide bins and rendered as slivers here.
    ax.bar(edges[:-1], counts, width=0.7 * bin_width, align="edge", edgecolor="black", log=True)
    ticks = edges[::tick_every]
    ax.set_xticks(ticks)
    ax.set_xticklabels([int(tick) for tick in ticks], rotation=90, fontsize=8)
    ax.set_xlim(edges[0], edges[-1])
    ax.set_xlabel("Tokenized Length")
    ax.set_ylabel("Count (log scale)")
    ax.set_title(f"dist_{name}_tokenized_length", fontsize=8)


def main():
    """Command-line entry point: summarise an SFT JSON file and save distribution plots.

    Loads ``--json_path`` and measures the images of records that have an
    ``"image"`` key (relative to ``--images_folder``). It also counts whitespace
    tokens of every conversation turn across all records. The resulting
    ``dist_<name>_combined.png`` is written to ``--output_dir``, where ``<name>``
    is the JSON file's base name without a trailing ``.json``.
    """
    parser = argparse.ArgumentParser(description="Process data for LLaVA_Next project.")
    # Required: every later step depends on it, and None would crash on basename().
    parser.add_argument("--json_path", type=str, required=True, help="Path to the JSON file containing data.")
    parser.add_argument("--images_folder", type=str, default="/mnt/bn/vl-research/data/llava_data", help="Path to the folder containing images.")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="/mnt/bn/vl-research/workspace/boli01/projects/LLaVA_Next/notebooks/sft_data",
        help="Directory the combined plot is written to.",
    )
    args = parser.parse_args()

    json_path = args.json_path
    images_folder = args.images_folder
    # Strip only a trailing ".json"; str.replace would also remove it mid-name.
    llava_instruct_name = os.path.basename(json_path)
    if llava_instruct_name.endswith(".json"):
        llava_instruct_name = llava_instruct_name[: -len(".json")]

    data = load_data(json_path)
    filtered_data = filter_data(data)

    widths, heights = [], []
    if filtered_data:
        print(f"Total data items: {len(data)}, Filtered data items: {len(filtered_data)}")
        widths, heights = calculate_image_dimensions_multiprocess(filtered_data, images_folder)
        if widths:
            print(f"Max width: {max(widths)}, Max height: {max(heights)}")
        else:
            print("No image could be opened; skipping the width/height plot.")

    tokenized_lengths = calculate_tokenized_lengths(data)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(24, 12))
    # Each panel is drawn only when it has data; max()/min() on empty lists would raise.
    if widths:
        _plot_image_dimensions(fig, ax1, widths, heights, llava_instruct_name)
    if tokenized_lengths:
        _plot_tokenized_lengths(ax2, tokenized_lengths, llava_instruct_name)

    plt.tight_layout()
    os.makedirs(args.output_dir, exist_ok=True)
    output_path = os.path.join(args.output_dir, f"dist_{llava_instruct_name}_combined.png")
    plt.savefig(output_path)
    plt.close(fig)
    print(f"Plots saved to {output_path}")
|
|
|
|
|
|
|
|
# Run the CLI only when executed as a script, not on import. This also matters
# for the multiprocessing Pool used above: under the "spawn" start method, worker
# processes re-import the main module and must not re-run main().
if __name__ == "__main__":
    main()
|
|
|