File size: 1,694 Bytes
import json
import re
from pathlib import Path

import pandas as pd
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

def extract_caption(text_block):
    """Return the caption from the first line of *text_block* that has a caption marker.

    The marker ``CAPTION:`` is matched case-insensitively (``Caption:``,
    ``caption:``, ...). On the first line containing it, everything after the
    last marker on that line is returned with surrounding whitespace stripped.

    Args:
        text_block: Multi-line text to scan.

    Returns:
        The stripped caption text, or ``""`` if no line contains the marker.
        Scanning stops at the first marker line, so a marker followed only by
        whitespace also yields ``""``.
    """
    for line in text_block.splitlines():
        # Split with the same case-insensitivity used for detection; splitting
        # on the literal "CAPTION:" left mixed-case markers (e.g. "Caption:")
        # in the returned text.
        parts = re.split(r"caption:", line, flags=re.IGNORECASE)
        if len(parts) > 1:
            return parts[-1].strip()
    return ""
def load_captions_from_files(json_files):
    """Collect parallel lists of image paths and captions from JSON files.

    Each file is loaded as a JSON object whose items are iterated as
    ``(image_path, nested_list)`` pairs. The text block at ``nested_list[0][0]``
    is passed to ``extract_caption``.

    Args:
        json_files: Iterable of paths to JSON files, read as UTF-8.

    Returns:
        A ``(paths, captions)`` tuple of equal-length lists. Entries whose
        nested list (or its first element) is empty, and entries whose
        extracted caption is empty, are skipped.
    """
    paths, captions = [], []
    for file_path in tqdm(json_files, desc="Reading files"):
        with open(file_path, 'r', encoding='utf-8') as handle:
            records = json.load(handle)
        for image_path, nested in records.items():
            # Only index [0][0] when both levels are non-empty.
            if nested and nested[0]:
                caption = extract_caption(nested[0][0])
                if caption:
                    paths.append(image_path)
                    captions.append(caption)
    return paths, captions
def compute_and_save_embeddings(json_files, output_csv, model_name='all-MiniLM-L6-v2'):
    """Embed every caption found in *json_files* and write the vectors to a CSV.

    Captions are gathered with ``load_captions_from_files`` and encoded with a
    SentenceTransformer model. The CSV has an ``image_path`` column followed by
    the embedding columns produced by ``pd.DataFrame(embeddings)``.

    Args:
        json_files: Iterable of JSON file paths. It is materialized into a list
            up front so it can be both iterated and counted; this lets callers
            pass generators (e.g. ``Path.glob``) as well as lists.
        output_csv: Destination path for the CSV (written without an index).
        model_name: SentenceTransformer model identifier to load. Defaults to
            the model this function previously hard-coded.

    Returns:
        None. If no captions are found, a message is printed, the model is not
        loaded, and no file is written.
    """
    # A generator would be exhausted by the read loop and has no len(); the
    # final print used to raise TypeError after the CSV was already written.
    json_files = list(json_files)
    image_paths, captions = load_captions_from_files(json_files)
    if not captions:
        print("No valid captions found across input files.")
        return
    # Construct the (potentially expensive) model only once there is
    # something to encode.
    model = SentenceTransformer(model_name)
    embeddings = model.encode(captions, show_progress_bar=True)
    df = pd.DataFrame(embeddings)
    df.insert(0, "image_path", image_paths)
    df.to_csv(output_csv, index=False)
    print(f"Saved {len(df)} embeddings from {len(json_files)} files to {output_csv}")
# Example usage
if __name__ == "__main__":
    import glob
    # Collect multiple JSON files. Sort them because glob's order depends on
    # the filesystem, and file order determines the row order of the CSV.
    files = sorted(glob.glob("./MBD_text/*.json"))  # or provide manually
    compute_and_save_embeddings(files, "combined_caption_embeddings.csv")
|