BrainExplore-demo / utils.py
Mattias Cosarinsky
update
39f341b
import json
from datasets import load_dataset
from PIL import Image
from huggingface_hub import hf_hub_download
from pathlib import Path
import os
# -------- CONFIG & STATE --------
# Hugging Face dataset repo holding the hypothesis images, the ROI PNGs,
# structure.json and best_across.json.
HF_REPO = "mcosarinsky/BrainExplore-Images"

# Parsed structure.json; populated by load_structure().
STRUCTURE_DATA = None

# ROI names (presumably the UI choices — used outside this chunk). "Across"
# is a pseudo-ROI whose data is resolved through best_across.json.
ROIS = [
    "Across",
    "EBA",
    "FBA-1",
    "FBA-2",
    "FFA-1",
    "FFA-2",
    "OFA",
    "OPA",
    "OWFA",
    "PPA",
    "RSC",
    "VWFA-1",
    "VWFA-2",
    "hV4",
]

# Parsed best_across.json; populated once by load_best_across().
BEST_ACROSS = None
def load_best_across():
    """Download ``best_across.json`` from ``HF_REPO`` once and cache it.

    The parsed JSON object is stored in the module-level ``BEST_ACROSS`` and
    returned; later calls return the cached value without downloading again.

    On any failure (download error, invalid JSON, or a top-level value that is
    not a JSON object) the error is printed and ``{}`` is cached instead. The
    Across helpers in this module call ``.get`` / ``in`` on ``BEST_ACROSS``
    directly, so it must end up as a dict, never ``None``.

    Returns:
        dict: the cached best-across mapping, or ``{}`` if loading failed.
    """
    global BEST_ACROSS
    if BEST_ACROSS is not None:
        return BEST_ACROSS
    try:
        path = hf_hub_download(repo_id=HF_REPO, filename="best_across.json", repo_type="dataset")
        # JSON is UTF-8 by definition; don't rely on the platform's locale
        # encoding (e.g. cp1252 on Windows), which can mis-decode non-ASCII.
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except Exception as e:
        # Best effort: fall back to an empty mapping instead of raising.
        print(f"Error loading best_across.json: {e}")
        data = {}
    if not isinstance(data, dict):
        print(f"Error loading best_across.json: expected a JSON object, got {type(data).__name__}")
        data = {}
    BEST_ACROSS = data
    return BEST_ACROSS
# -------- 1. JSON STRUCTURE HELPERS --------
def load_structure(repo_id, filename="structure.json"):
    """Download a structure JSON file from a Hub dataset repo and cache it.

    The parsed object is stored in the module-level ``STRUCTURE_DATA`` and
    returned. Once cached it is returned as-is on every call. Note that the
    cache ignores ``repo_id`` and ``filename``, so a later call with different
    arguments still gets the first result.

    Args:
        repo_id: Hugging Face dataset repository id.
        filename: file to fetch from that repository.

    Returns:
        dict: the parsed JSON object, or ``{}`` if the download/parse failed
        or the file does not hold a JSON object. Failures are printed and are
        NOT cached, so a later call retries the download.
    """
    global STRUCTURE_DATA
    if STRUCTURE_DATA is not None:
        return STRUCTURE_DATA
    try:
        local_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            repo_type="dataset"
        )
        # JSON is UTF-8 by definition; don't depend on the locale's default encoding.
        with open(local_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except Exception as e:
        print(f"CRITICAL ERROR: Could not load {filename} via hf_hub_download: {e}")
        return {}
    # Callers index the result with dict .get(); reject any other JSON shape.
    if not isinstance(data, dict):
        print(f"CRITICAL ERROR: {filename} does not contain a JSON object (got {type(data).__name__})")
        return {}
    STRUCTURE_DATA = data
    print(f"Successfully loaded {filename}.")
    return STRUCTURE_DATA
def get_hypotheses_for_selection(model, roi):
    """Return the sorted hypothesis names available for a model/ROI pair.

    * ``roi`` equal to "Across" (case-insensitive): the keys of
      ``BEST_ACROSS[model]`` (loaded from best_across.json).
    * any other ROI: the entries of ``STRUCTURE_DATA[model][roi]`` (loaded
      from structure.json).

    ``model`` is expected to already be the repo-side model name; no mapping
    is done here.

    Returns:
        list: sorted hypothesis names. An empty/``None`` ROI, an unknown model
        or ROI, or data that failed to load all yield ``[]``.
    """
    # A cleared selection arrives as a falsy ROI; treat it as "nothing
    # selected" (as load_roi_image does) instead of crashing on .lower().
    if not roi:
        return []
    if roi.lower() == "across":
        # sorted() over a dict yields its sorted keys.
        return sorted((BEST_ACROSS or {}).get(model, {}))
    return sorted((STRUCTURE_DATA or {}).get(model, {}).get(roi, []))
# -------- 2. IMAGE PROCESSING HELPERS --------
def pad_image(pil_img, padding_x=10, padding_y=0, color=(255, 255, 255)):
    """Return a copy of *pil_img* with a solid border added around it.

    Args:
        pil_img: source PIL image; it is pasted unchanged at
            ``(padding_x, padding_y)`` onto a new canvas of the same mode.
        padding_x: pixels added to both the left and the right edge.
        padding_y: pixels added to both the top and the bottom edge.
        color: border fill colour (white by default). An RGB(A) tuple is
            translated automatically for single-/dual-band image modes.

    Returns:
        A new PIL image of size
        ``(width + 2 * padding_x, height + 2 * padding_y)``.
    """
    width, height = pil_img.size
    fill = color
    # Image.new() needs a scalar fill for single-band modes (L, 1, I, F) and
    # one value per band otherwise, so a 3-tuple raises TypeError for e.g.
    # grayscale images. Convert the colour into the image's own mode. Palette
    # modes are skipped because Image.new() already maps RGB tuples for them.
    if (
        isinstance(color, tuple)
        and len(color) >= 3
        and pil_img.mode not in ("P", "PA")
        and len(pil_img.getbands()) < 3
    ):
        fill = Image.new("RGB", (1, 1), color[:3]).convert(pil_img.mode).getpixel((0, 0))
    canvas = Image.new(pil_img.mode, (width + 2 * padding_x, height + 2 * padding_y), fill)
    canvas.paste(pil_img, (padding_x, padding_y))
    return canvas
def _iter_named_images(data_dir):
    """Yield ``(file_name, image)`` for every image under *data_dir* in HF_REPO."""
    ds = load_dataset(HF_REPO, data_dir=data_dir, split="train")
    for item in ds:
        img = item["image"]
        yield Path(img.filename).name, img


def load_images_and_views(model, roi, hyp):
    """Load the FID and brain-view images for one model/ROI/hypothesis selection.

    Normal ROI: every image in ``{model}/{roi}/{hyp}`` of ``HF_REPO`` whose file
    name contains "fid" (case-insensitive) is padded and returned as a FID
    image. Otherwise, names containing "brain" are returned as brain views.

    Across ROI: the hypothesis is looked up in ``BEST_ACROSS[model]``. The
    entry's "roi" selects the source folder, and only the files named by its
    "fid_file" / "brain_file" keys are returned (at most one of each).

    Returns:
        tuple[list, list]: ``(fid_images, brain_images)``. Both are empty when
        the ROI or hypothesis is not selected, or when the Across entry is
        missing or has no "roi".
    """
    # An incomplete selection must not crash on .lower() or build a folder
    # path that covers the whole ROI directory.
    if not roi or not hyp:
        return [], []

    if roi.lower() == "across":
        entry = (BEST_ACROSS or {}).get(model, {}).get(hyp)
        if not entry:
            return [], []
        source_roi = entry.get("roi")
        if not source_roi:
            return [], []
        fid_file = entry.get("fid_file")
        brain_file = entry.get("brain_file")

        fid_image = None
        brain_image = None
        for name, img in _iter_named_images(f"{model}/{source_roi}/{hyp}"):
            if fid_file and name == fid_file:
                fid_image = pad_image(img, padding_x=10, padding_y=0)
            elif brain_file and name == brain_file:
                brain_image = img
        return (
            [fid_image] if fid_image is not None else [],
            [brain_image] if brain_image is not None else [],
        )

    fid_images, brain_images = [], []
    for name, img in _iter_named_images(f"{model}/{roi}/{hyp}"):
        lowered = name.lower()
        if "fid" in lowered:
            fid_images.append(pad_image(img, padding_x=10, padding_y=0))
        elif "brain" in lowered:
            brain_images.append(img)
    return fid_images, brain_images
def load_roi_image(roi_name, override_roi=None):
    """Download ``ROI/<name>.png`` from HF_REPO and return it as a PIL image.

    Args:
        roi_name: ROI selected by the caller. A falsy value returns ``None``
            without downloading anything.
        override_roi: if truthy, the ROI whose PNG is loaded instead of
            ``roi_name`` (used for the "Across" pseudo-ROI).

    Returns:
        The decoded image, or ``None`` if the download or decode failed (the
        error is printed).
    """
    if not roi_name:
        return None
    actual_roi = override_roi or roi_name
    file_name = f"ROI/{actual_roi}.png"
    try:
        local_path = hf_hub_download(
            repo_id=HF_REPO,
            filename=file_name,
            repo_type="dataset"
        )
        image = Image.open(local_path)
        # Image.open() is lazy. Force the decode here so a corrupt or truncated
        # PNG is reported by the handler below instead of failing later in
        # the caller, and so Pillow can release the underlying file handle.
        image.load()
    except Exception as e:
        print(f"Error loading ROI image {actual_roi} via hf_hub_download. Path used: {file_name}. Details: {e}")
        return None
    return image
def load_roi_for_selection(model, roi, hyp):
    """Load the ROI image for the current model/ROI/hypothesis selection.

    For the "Across" pseudo-ROI (case-insensitive), the image of the ROI
    recorded in ``BEST_ACROSS[model][hyp]["roi"]`` is loaded instead. If that
    entry is missing, ``load_roi_image`` falls back to the literal ROI name.

    Returns:
        Whatever ``load_roi_image`` returns (a PIL image or ``None``).
    """
    override_roi = None
    # Check `roi` first: load_roi_image() already treats an empty/None ROI as
    # "nothing selected", so don't crash on .lower() before reaching it.
    if roi and roi.lower() == "across":
        entry = (BEST_ACROSS or {}).get(model, {}).get(hyp) or {}
        override_roi = entry.get("roi")
    return load_roi_image(roi_name=roi, override_roi=override_roi)
# Populate the module-level caches (STRUCTURE_DATA, BEST_ACROSS) once at
# import time. Note this performs Hub downloads as an import side effect.
# Both loaders catch exceptions, print them and return {} instead of raising.
# load_best_across() must run before the Across helpers are used: they call
# .get / `in` on BEST_ACROSS directly and would fail while it is still None.
load_structure(HF_REPO)
load_best_across()