Spaces:
Sleeping
Sleeping
| import os | |
| from concurrent.futures import ThreadPoolExecutor, as_completed | |
| import ml_collections | |
| import numpy as np | |
| import pandas as pd | |
| import streamlit as st | |
| import torch | |
| from huggingface_hub import hf_hub_download | |
| import app_lib.multimodal as multimodal | |
| from app_lib.ckde import cKDE | |
| from app_lib.config import Config | |
| from app_lib.config import Constants as c | |
| from ibydmt.test import xSKIT | |
# Module-level NumPy random generator, created without an explicit seed.
# `_sample_random_subset` uses it to shuffle candidate concept indices.
rng = np.random.default_rng()
def _encode_concepts(model, concepts):
    """Embed concepts with the model's text encoder and L2-normalize them.

    Args:
        model: Object exposing ``encode_text``; the call must return a torch
            tensor.
        concepts: Concepts forwarded unchanged to ``model.encode_text``.

    Returns:
        NumPy array holding the text features, each scaled to unit Euclidean
        norm along the last dimension.
    """
    text_embeddings = model.encode_text(concepts)
    row_norms = torch.linalg.norm(text_embeddings, dim=-1, keepdim=True)
    # In-place division: the tensor handed back by the encoder is normalized
    # directly rather than copied.
    text_embeddings.div_(row_norms)
    return text_embeddings.cpu().numpy()
def _encode_image(model, image):
    """Embed an image with the model's image encoder and L2-normalize it.

    Args:
        model: Object exposing ``encode_image``; the call must return a torch
            tensor.
        image: Input forwarded unchanged to ``model.encode_image``.

    Returns:
        NumPy array holding the image features, scaled to unit Euclidean norm
        along the last dimension.
    """
    image_features = model.encode_image(image)
    # Use torch.linalg.norm, matching the text-encoding helpers in this file;
    # Tensor.norm / torch.norm is documented as deprecated.
    image_features /= torch.linalg.norm(image_features, dim=-1, keepdim=True)
    return image_features.cpu().numpy()
def _encode_class_name(model, class_name):
    """Embed the text prompt for a class name and L2-normalize it.

    The class name is wrapped into a single prompt before being passed, as a
    one-element list, to ``model.encode_text``.

    Args:
        model: Object exposing ``encode_text``; the call must return a torch
            tensor.
        class_name: Name interpolated into the prompt.

    Returns:
        NumPy array holding the prompt features, scaled to unit Euclidean norm
        along the last dimension.
    """
    prompt = f"A photo of a {class_name}"
    class_embedding = model.encode_text([prompt])
    class_embedding.div_(
        torch.linalg.norm(class_embedding, dim=-1, keepdim=True)
    )
    return class_embedding.cpu().numpy()
def _load_embedding(config):
    """Fetch the training embeddings for the configured dataset and backbone.

    The parquet file name is built from the lower-cased ``config.data.dataset``
    and the value of ``config.backbone_name()``; the file is obtained through
    ``hf_hub_download`` from the ``jacopoteneggi/IBYDMT`` dataset repository.

    Args:
        config: Configuration providing ``data.dataset`` and
            ``backbone_name()``.

    Returns:
        NumPy array built from the ``"embedding"`` column of the parquet file.
    """
    parquet_name = (
        f"{config.data.dataset.lower()}_train_{config.backbone_name()}.parquet"
    )
    local_path = hf_hub_download(
        repo_id="jacopoteneggi/IBYDMT",
        filename=parquet_name,
        repo_type="dataset",
    )
    embedding_column = pd.read_parquet(local_path)["embedding"]
    return np.array(embedding_column.values.tolist())
def _sample_random_subset(concept_idx, concepts, cardinality):
    """Draw a random subset of concept indices that excludes ``concept_idx``.

    Args:
        concept_idx: Index that must not appear in the subset.
        concepts: Sequence of concepts; only its length is used.
        cardinality: Upper bound on the subset size. Fewer indices are
            returned when fewer candidates are available.

    Returns:
        List of indices taken from the front of a random permutation (drawn
        with the module-level ``rng``) of all other concept indices.
    """
    candidates = [idx for idx in range(len(concepts)) if idx != concept_idx]
    shuffled = rng.permutation(candidates)
    return shuffled[:cardinality].tolist()
def _test(testing_config, z, concept_idx, concepts, cardinality, sampler, classifier):
    """Run ``testing_config.r`` xSKIT tests for one concept on random subsets.

    Each repetition draws a fresh random subset of the other concepts, runs
    ``xSKIT.test`` with it, and records the outcome.

    Args:
        testing_config: Configuration read for ``r``, ``significance_level``
            and ``tau_max``; also passed to ``xSKIT``.
        z: Passed through to ``xSKIT.test`` as its first argument.
        concept_idx: Index (into ``concepts``) of the concept under test.
        concepts: Sequence of all concepts.
        cardinality: Maximum size of each random conditioning subset.
        sampler: Object exposing ``sample(z, cond_idx, m=...)`` returning a
            pair; only the second element is used.
        classifier: Array whose transpose is right-multiplied onto samples to
            produce predictions.

    Returns:
        Dict with the tested ``"concept"`` and per-repetition lists
        ``"rejected"``, ``"tau"``, ``"wealth"`` and ``"subset"``.
    """

    def sample_conditional(z, cond_idx, m):
        _unused, samples = sampler.sample(z, cond_idx, m=m)
        return samples

    def predict(h):
        return (h @ classifier.T).squeeze()

    history = {"rejected": [], "tau": [], "wealth": [], "subset": []}
    for _ in range(testing_config.r):
        subset_idx = _sample_random_subset(concept_idx, concepts, cardinality)
        subset_names = [concepts[idx] for idx in subset_idx]

        tester = xSKIT(testing_config)
        was_rejected, stop_time = tester.test(
            z,
            concept_idx,
            subset_idx,
            sample_conditional,
            predict,
            interrupt_on="max_wealth",
            max_wealth=3 * 1 / testing_config.significance_level,
        )

        # Extend the recorded wealth to ``tau_max`` entries by repeating its
        # last value, so every repetition yields the same length.
        trajectory = tester.wealth._wealth
        missing = testing_config.tau_max - len(trajectory)
        padded = trajectory + [trajectory[-1]] * missing

        history["rejected"].append(was_rejected)
        history["tau"].append(stop_time)
        history["wealth"].append(padded)
        history["subset"].append(subset_names)

    return {"concept": concepts[concept_idx], **history}
def get_testing_config(**kwargs):
    """Build the testing configuration and store it in the Streamlit session.

    A new ``ml_collections.ConfigDict`` is created, assigned to
    ``st.session_state.testing_config`` and filled with the fields below. Any
    field can be overridden through a keyword argument of the same name.

    Fields and their defaults:
        significance_level=0.05, wealth="ons", bet="tanh", kernel="rbf",
        kernel_scale_method="quantile", kernel_scale=0.5, tau_max=200, r=10.

    Returns:
        The populated configuration (the same object stored in the session
        state).
    """
    defaults = {
        "significance_level": 0.05,
        "wealth": "ons",
        "bet": "tanh",
        "kernel": "rbf",
        "kernel_scale_method": "quantile",
        "kernel_scale": 0.5,
        "tau_max": 200,
        "r": 10,
    }

    testing_config = ml_collections.ConfigDict()
    st.session_state.testing_config = testing_config
    for field_name, fallback in defaults.items():
        setattr(testing_config, field_name, kwargs.get(field_name, fallback))
    return testing_config
def load_precomputed_results(image_name):
    """Load the stored results associated with an image file name.

    The text of ``image_name`` before its first ``"."`` selects the file
    ``assets/results/<stem>.npy`` (relative to the working directory).

    Args:
        image_name: Image file name, e.g. ``"ace.jpg"``.

    Returns:
        The Python object stored in the ``.npy`` file, unwrapped with
        ``.item()``.
    """
    stem = image_name.partition(".")[0]
    results_path = os.path.join("assets", "results", f"{stem}.npy")
    # The file holds a pickled object, hence allow_pickle=True.
    stored = np.load(results_path, allow_pickle=True)
    return stored.item()
def test(
    testing_config,
    image,
    class_name,
    concepts,
    cardinality,
    dataset_name,
    model_name,
    device=c.DEVICE,
    with_streamlit=True,
):
    """Test every concept for an image/class pair and collect the outcomes.

    The image and the concepts are encoded with the selected model, their
    product is used as the input to the tests, and ``_test`` is run once per
    concept in a thread pool.

    Args:
        testing_config: Configuration read for ``r``, ``tau_max`` and
            ``significance_level``; also forwarded to ``_test``.
        image: Image forwarded to the model's image encoder.
        class_name: Class name encoded into the classifier vector.
        concepts: List of concepts to test.
        cardinality: Maximum size of the random conditioning subsets.
        dataset_name: Stored in ``config.data.dataset``.
        model_name: Stored in ``config.data.backbone``.
        device: Device passed to ``multimodal.get_model``.
        with_streamlit: When true, show Streamlit spinners and a progress bar.

    Returns:
        Dict with ``"significance_level"``, ``"concepts"``, and arrays
        ``"rejected"`` and ``"tau"`` of shape ``(r, len(concepts))`` and
        ``"wealth"`` of shape ``(r, tau_max, len(concepts))``. ``"tau"`` is
        divided by ``tau_max``.
    """
    # Local import keeps this block self-contained.
    from contextlib import nullcontext

    n_concepts = len(concepts)

    def spinner(text):
        # A no-op context when running without the Streamlit UI.
        return st.spinner(text) if with_streamlit else nullcontext()

    def progress_text(completed):
        return (
            f"Testing concepts (can take up to a minute) [{completed} /"
            f" {n_concepts} completed]"
        )

    config = Config()
    config.data.dataset = dataset_name
    config.data.backbone = model_name

    with spinner("Loading model"):
        model = multimodal.get_model(config, device=device)
    with spinner("Encoding concepts"):
        cbm = _encode_concepts(model, concepts)
    with spinner("Encoding image"):
        h = _encode_image(model, image)

    z = h @ cbm.T
    z = z.squeeze()

    progress_bar = None
    if with_streamlit:
        progress_bar = st.progress(0, text=progress_text(0))
        # The bar has n_concepts + 1 steps; the first is consumed up front.
        progress_bar.progress(1 / (n_concepts + 1), text=progress_text(0))

    embedding = _load_embedding(config)
    semantics = embedding @ cbm.T
    sampler = cKDE(embedding, semantics)

    classifier = _encode_class_name(model, class_name)

    rejected = np.empty((testing_config.r, n_concepts))
    tau = np.empty((testing_config.r, n_concepts))
    wealth = np.empty((testing_config.r, testing_config.tau_max, n_concepts))

    with ThreadPoolExecutor() as executor:
        # Remember which concept index each future belongs to. Looking the
        # index up by concept name would send duplicate concepts to the same
        # column and leave the other columns uninitialized.
        future_to_idx = {
            executor.submit(
                _test,
                testing_config,
                z,
                concept_idx,
                concepts,
                cardinality,
                sampler,
                classifier,
            ): concept_idx
            for concept_idx in range(n_concepts)
        }

        for completed, future in enumerate(as_completed(future_to_idx), start=1):
            concept_idx = future_to_idx[future]
            outcome = future.result()

            rejected[:, concept_idx] = np.array(outcome["rejected"])
            tau[:, concept_idx] = np.array(outcome["tau"])
            wealth[:, :, concept_idx] = np.array(outcome["wealth"])

            if progress_bar is not None:
                progress_bar.progress(
                    (completed + 1) / (n_concepts + 1),
                    text=progress_text(completed),
                )

    tau /= testing_config.tau_max

    return {
        "significance_level": testing_config.significance_level,
        "concepts": concepts,
        "rejected": rejected,
        "tau": tau,
        "wealth": wealth,
    }