Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| import clip | |
| import pickle | |
def do_batch(batch, embeddings):
    """Encode one batch of preprocessed images and append the results.

    Args:
        batch: sequence of preprocessed images; they are stacked into a
            single array before being handed to the model.
        embeddings: list that is extended in place with the encoded
            features, converted to plain Python lists via ``tolist()``.

    Relies on the module-level ``model`` and ``device``.
    """
    stacked = np.stack(batch)
    image_batch = torch.tensor(stacked).to(device)
    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        features = model.encode_image(image_batch).float()
    embeddings.extend(features.cpu().numpy().tolist())
    # Flush right away so progress is visible when stdout is piped.
    print(f"{len(embeddings)} done", flush=True)
# Command line: <output pickle> <thumbs|no-thumbs>.
# Validated with explicit checks (not assert, which is stripped under -O) and
# before the model is loaded, so a bad invocation fails immediately.
if len(sys.argv) < 3:
    sys.exit(f"usage: {sys.argv[0]} OUTPUT.pkl (thumbs|no-thumbs) < filenames")
output_filename = sys.argv[1]
if not output_filename.endswith("pkl"):
    sys.exit("first argument is the output pickle")
if sys.argv[2] not in ("thumbs", "no-thumbs"):
    sys.exit("second argument either thumbs or no-thumbs")
do_thumbs = sys.argv[2] == "thumbs"

# Use CUDA when it is available, even though it's not really worth bothering
# with: 98% of the run time is preprocessing on the cpu.
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('RN50', device=device)

# Stop after this many images have been accepted.
limit = 1e9
# Number of preprocessed images encoded per call to do_batch().
batch_size = 100
def save(output_filename, embeddings, filenames):
    """Pickle the embeddings gathered so far together with their filenames.

    The pickled object is a dict with the keys ``"embeddings"`` (a NumPy
    array built from *embeddings*) and ``"filenames"``. When the module-level
    ``do_thumbs`` flag is set, the module-level ``thumbs`` list is stored
    under ``"thumbs"`` as well.

    The pickle is first written to a temporary file next to
    *output_filename* and then moved into place, so an interruption while
    writing cannot destroy a previously saved checkpoint.

    Args:
        output_filename: path of the pickle file to write.
        embeddings: sequence of embeddings, one per entry of *filenames*.
        filenames: image paths, in the same order as *embeddings*.

    Raises:
        AssertionError: if the sequences being saved differ in length.
    """
    embeddings = np.array(embeddings)
    assert len(embeddings) == len(filenames)
    print(f"processed {len(embeddings)} images")
    data = {"embeddings": embeddings, "filenames": filenames}
    if do_thumbs:
        assert len(embeddings) == len(thumbs)
        data["thumbs"] = thumbs

    tmp_filename = f"{output_filename}.tmp"
    try:
        with open(tmp_filename, "wb") as f:
            pickle.dump(data, f)
    except BaseException:
        # Do not leave a half-written temporary file behind.
        try:
            os.remove(tmp_filename)
        except OSError:
            pass
        raise
    # Swap the finished file into place in a single step.
    os.replace(tmp_filename, output_filename)
# Parallel accumulators: filenames and thumbs grow by one entry per accepted
# image, embeddings catches up each time a batch is encoded.
embeddings = []
filenames = []
thumbs = []

print("starting processing")

batch = []
batch_count = 0
# One image path per line on stdin; paths not ending in jpg/jpeg are ignored.
for filename in sys.stdin:
    filename = filename.rstrip()
    if not filename.lower().endswith(("jpg", "jpeg")):
        continue

    # Only the per-file work is guarded: an unreadable image is skipped,
    # while failures in do_batch()/save() are real errors and must not be
    # reported as a skipped file (that would leave the lists out of sync).
    try:
        with Image.open(filename) as source_image:
            rgb = source_image.convert("RGB")
        img = preprocess(rgb)
        thumb = None
        if do_thumbs:
            # thumbnail() shrinks the image in place, so it has to run
            # after preprocess() has consumed the full-size image.
            rgb.thumbnail((128, 128))
            thumb = np.array(rgb)
    except Exception:
        print(f"ERROR, skipping {filename}")
        sys.stdout.flush()
        continue

    # Record everything for this image together, before any checkpoint, so
    # save() always sees lists of matching length.
    batch.append(img)
    filenames.append(filename)
    if do_thumbs:
        thumbs.append(thumb)

    if len(batch) >= batch_size:
        do_batch(batch, embeddings)
        batch = []
        batch_count += 1
        # Checkpoint every 200 batches so a long run can be salvaged.
        if batch_count % 200 == 0:
            save(output_filename, embeddings, filenames)

    if len(filenames) >= limit:
        break

# remaining
if batch:
    do_batch(batch, embeddings)
save(output_filename, embeddings, filenames)