Spaces:
Build error
Build error
| import argparse | |
| import csv | |
| import os | |
| import jax.numpy as jnp | |
| from jax import jit | |
| from PIL import Image | |
| from tqdm import tqdm | |
| from config import MODEL_LIST | |
| from utils import load_model | |
def _list_jpg_files(root):
    """Return the entries of *root*, sorted, after checking they are all ``.jpg``.

    Sorting makes the row order of the output TSVs reproducible across runs
    (``os.listdir`` returns entries in arbitrary order).

    Raises:
        ValueError: if any entry of *root* does not end with ``".jpg"``.
    """
    files = sorted(os.listdir(root))
    # Explicit check rather than ``assert``, which is stripped under ``python -O``.
    bad = [name for name in files if not name.endswith(".jpg")]
    if bad:
        raise ValueError(f"expected only .jpg files in {root!r}, found: {bad}")
    return files


def _load_rgb_images(root, names):
    """Open ``root/<name>`` for each of *names* and return in-memory RGB copies.

    Each file is opened in a ``with`` block so its handle is released as soon
    as the converted copy has been made.
    """
    images = []
    for name in names:
        with Image.open(os.path.join(root, name)) as img:
            images.append(img.convert("RGB"))
    return images


def main(args):
    """Embed every ``.jpg`` in ``args.image_path`` with each model in ``MODEL_LIST``.

    For every model name, ``<args.out_path>/<model_name>.tsv`` is opened in
    append mode. Each image gets one tab-separated row: its file name, then
    its image embedding as comma-separated values.

    Args:
        args: namespace providing ``image_path`` (input directory),
            ``out_path`` (output directory, created if missing) and
            ``batch_size`` (images per forward pass; coerced to ``int``).

    Raises:
        ValueError: if ``batch_size`` < 1 or the input directory contains a
            non-``.jpg`` entry.
    """
    root = args.image_path
    # Coerce defensively: a CLI-supplied value arrives as ``str`` unless the
    # parser declares ``type=int``.
    batch_size = int(args.batch_size)
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    files = _list_jpg_files(root)
    os.makedirs(args.out_path, exist_ok=True)

    for model_name in MODEL_LIST:
        model, processor = load_model(f"koclip/{model_name}")
        out_file = os.path.join(args.out_path, f"{model_name}.tsv")
        # newline="" is what the csv module requires for files it writes;
        # utf-8 so non-ASCII file names are written identically on every platform.
        with open(out_file, "a", newline="", encoding="utf-8") as out:
            writer = csv.writer(out, delimiter="\t")
            with tqdm(total=len(files)) as pbar:
                for start in range(0, len(files), batch_size):
                    image_ids = files[start : start + batch_size]
                    images = _load_rgb_images(root, image_ids)
                    # The last batch may be short: advance by what was loaded.
                    pbar.update(len(image_ids))
                    try:
                        # Empty dummy caption; presumably the processor/model
                        # call needs some text input — confirm.
                        inputs = processor(
                            text=[""], images=images, return_tensors="jax", padding=True
                        )
                    except Exception as err:
                        # Best effort, as before: report the offending batch and
                        # stop this model, but without swallowing
                        # KeyboardInterrupt/SystemExit or hiding the error.
                        print(f"preprocessing failed for {image_ids}: {err!r}")
                        break
                    # Move axis 1 to the end (presumably NCHW -> NHWC for the
                    # Flax vision tower; TODO confirm against processor output).
                    inputs["pixel_values"] = jnp.transpose(
                        inputs["pixel_values"], axes=[0, 2, 3, 1]
                    )
                    features = model(**inputs).image_embeds
                    for image_id, feature in zip(image_ids, features):
                        writer.writerow([image_id, ",".join(map(str, feature))])
                    # Persist each batch as it completes, so a long run that is
                    # killed mid-way keeps the rows already computed.
                    out.flush()
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Extract KoCLIP image embeddings for a directory of .jpg files."
    )
    # type=int is required: without it a CLI value reaches range() as a str.
    parser.add_argument(
        "--batch_size", type=int, default=16, help="images per forward pass"
    )
    parser.add_argument(
        "--image_path", default="images", help="directory containing .jpg images"
    )
    parser.add_argument(
        "--out_path",
        default="features",
        help="directory that receives one <model_name>.tsv per model",
    )
    args = parser.parse_args()
    main(args)