# id-align/scripts/archived/data_info.py
# Uploaded by zooblastlbz using huggingface_hub (commit a9e1e1a, verified).
import json
import os
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
def load_data(json_path):
    """Load and return the parsed contents of a JSON file.

    Args:
        json_path: Path to the JSON file to read.

    Returns:
        The deserialized document, exactly as produced by ``json.load``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # JSON text is UTF-8 by specification; pin the encoding so the result does
    # not depend on the platform's locale default.
    with open(json_path, "r", encoding="utf-8") as f:
        return json.load(f)
def filter_data(data):
    """Return the entries of *data* that contain an "image" key.

    The kept entries appear in the same order as in *data*; the input itself
    is not modified.
    """
    with_image = []
    for entry in data:
        if "image" in entry:
            with_image.append(entry)
    return with_image
from multiprocessing import Pool
import functools
def calculate_image_dimension(item, images_folder):
    """Return the ``(width, height)`` of the image referenced by *item*.

    The image path is ``item["image"]`` resolved against *images_folder*.
    Any failure while opening or reading the image is reported on stdout and
    signalled to the caller as ``(None, None)`` instead of being raised.
    """
    image_path = os.path.join(images_folder, item["image"])
    try:
        with Image.open(image_path) as img:
            size = img.size
        w, h = size
    except Exception as e:
        # Best effort: one unreadable image must not abort the whole scan.
        print(f"Error opening {image_path}: {e}")
        return None, None
    return w, h
def calculate_image_dimensions_multiprocess(filtered_data, images_folder, num_processes=256):
    """Measure every image referenced by *filtered_data* using a process pool.

    Args:
        filtered_data: Sized sequence of items, each with an "image" entry
            holding a path relative to *images_folder*.
        images_folder: Directory the image paths are resolved against.
        num_processes: Upper bound on the number of worker processes. ``None``
            is forwarded to ``Pool`` unchanged.

    Returns:
        A ``(widths, heights)`` pair of lists covering the images that could
        be measured. Both lists are empty when *filtered_data* is empty or
        when no image could be read.
    """
    if len(filtered_data) == 0:
        # Nothing to measure, so there is no reason to start a pool.
        return [], []

    if num_processes is not None:
        # Never start more workers than there are images to inspect.
        num_processes = min(num_processes, len(filtered_data))

    measure = functools.partial(calculate_image_dimension, images_folder=images_folder)
    with Pool(num_processes) as p:
        dimensions = list(tqdm(p.imap(measure, filtered_data), total=len(filtered_data), desc="Calculating image dimensions"))

    # Images that failed to open are reported as (None, None); drop them.
    valid = [dim for dim in dimensions if dim[0] is not None]
    if not valid:
        # zip(*[]) yields nothing, so unpacking it into two names would raise.
        return [], []
    widths, heights = zip(*valid)
    return list(widths), list(heights)
def tokenize(text):
    """Split *text* on runs of whitespace and return the resulting pieces."""
    pieces = text.split()
    return pieces
def calculate_tokenized_lengths(data):
    """Return the token count of every conversation turn found in *data*.

    Each item's ``"conversations"`` list is walked in order and the
    ``"value"`` of every turn is tokenized; the counts are collected into one
    flat list.
    """
    token_counts = []
    for sample in tqdm(data, desc="Tokenizing conversations"):
        token_counts.extend(len(tokenize(turn["value"])) for turn in sample["conversations"])
    return token_counts
import argparse
# Bin width shared by the image-size and token-length histograms.
_HIST_BIN_WIDTH = 100
# Default output location; kept so existing invocations write to the same file.
_DEFAULT_OUTPUT_DIR = "/mnt/bn/vl-research/workspace/boli01/projects/LLaVA_Next/notebooks/sft_data"


def _dimension_bin_edges(values, bin_width=_HIST_BIN_WIDTH):
    """Return histogram bin edges spanning min(values) to max(values)."""
    lowest, highest = min(values), max(values)
    if lowest == highest:
        # A degenerate range still needs two edges to form a single bin.
        return [lowest, highest + 1]
    return np.arange(lowest, highest + bin_width, bin_width)


def _length_bin_edges(lengths, bin_width=_HIST_BIN_WIDTH):
    """Return bin edges starting at 0 that cover every value in *lengths*."""
    longest = max(lengths, default=0)
    # Always at least one bin, and enough bins that the last edge is >= longest
    # so no sample falls outside the histogram range.
    bin_count = max(1, int(np.ceil(longest / bin_width)))
    return np.arange(bin_count + 1) * bin_width


def _plot_image_size_histogram(fig, ax, widths, heights, name):
    """Draw the 2D width/height density histogram on *ax*."""
    bins = [_dimension_bin_edges(widths), _dimension_bin_edges(heights)]
    _, _, _, image = ax.hist2d(widths, heights, bins=bins, cmap=plt.cm.jet, density=True)
    fig.colorbar(image, ax=ax)
    ax.set_xlabel("Width")
    ax.set_ylabel("Height")
    ax.set_title(f"dist_{name}_2d_w_h\nMax width: {max(widths)}, Max height: {max(heights)}", fontsize=10)


def _plot_token_length_histogram(ax, lengths, name, bin_width=_HIST_BIN_WIDTH, label_every=8):
    """Draw the log-scaled token-length histogram on *ax*."""
    bin_edges = _length_bin_edges(lengths, bin_width)
    hist, _ = np.histogram(lengths, bins=bin_edges)
    # Each bar starts at its bin's left edge and spans the full bin width.
    ax.bar(bin_edges[:-1], hist, width=bin_width, align="edge", edgecolor="black", log=True)

    # Label only every `label_every`-th edge to keep the x-axis readable.
    ticks = bin_edges[::label_every]
    ax.set_xticks(ticks)
    ax.set_xticklabels([int(tick) for tick in ticks], rotation=90, fontsize=8)
    ax.set_xlim(bin_edges[0], bin_edges[-1])
    ax.set_xlabel("Tokenized Length")
    ax.set_ylabel("Count (log scale)")
    ax.set_title(f"dist_{name}_tokenized_length", fontsize=8)


def main():
    """Plot image-size and token-length distributions for a dataset JSON file.

    Command-line arguments:
        --json_path: Dataset JSON file (required).
        --images_folder: Folder the items' "image" paths are relative to.
        --output_dir: Directory that receives ``dist_<name>_combined.png``.

    The left panel (2D width/height histogram) is drawn only when at least one
    image could be measured; the right panel always shows the distribution of
    whitespace-token counts per conversation turn.
    """
    parser = argparse.ArgumentParser(description="Process data for LLaVA_Next project.")
    # Required: without it the script previously failed later with an
    # AttributeError on None instead of a usage message.
    parser.add_argument("--json_path", type=str, required=True, help="Path to the JSON file containing data.")
    parser.add_argument("--images_folder", type=str, default="/mnt/bn/vl-research/data/llava_data", help="Path to the folder containing images.")
    parser.add_argument("--output_dir", type=str, default=_DEFAULT_OUTPUT_DIR, help="Directory the combined plot is written to.")
    args = parser.parse_args()

    json_path = args.json_path
    llava_instruct_name = os.path.basename(json_path).replace(".json", "")
    images_folder = args.images_folder

    data = load_data(json_path)
    filtered_data = filter_data(data)

    widths, heights = [], []
    if len(filtered_data) != 0:
        print(f"Total data items: {len(data)}, Filtered data items: {len(filtered_data)}")
        widths, heights = calculate_image_dimensions_multiprocess(filtered_data, images_folder)
        if widths:
            print(f"Max width: {max(widths)}, Max height: {max(heights)}")
        else:
            print("No image could be measured; skipping the image size histogram.")

    tokenized_lengths = calculate_tokenized_lengths(data)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(24, 12))

    # The 2D histogram needs at least one measured image.
    if widths:
        _plot_image_size_histogram(fig, ax1, widths, heights, llava_instruct_name)

    _plot_token_length_histogram(ax2, tokenized_lengths, llava_instruct_name)

    plt.tight_layout()

    # Create the target directory up front so the computed plots are not lost
    # to a missing-directory error at save time.
    os.makedirs(args.output_dir, exist_ok=True)
    output_path = os.path.join(args.output_dir, f"dist_{llava_instruct_name}_combined.png")
    plt.savefig(output_path)
    print(f"Plots saved to {output_path}")
# Run the analysis only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()