Spaces:
Running
Running
| import json | |
| import os | |
| import shutil | |
| import re | |
| import numpy as np | |
| import pandas as pd | |
| import gradio as gr | |
| from urllib.parse import urlparse | |
| from collections import defaultdict | |
| from datetime import datetime, timedelta, timezone | |
| from typing import Literal, Tuple, Union | |
| from huggingface_hub import HfApi, HfFileSystem, hf_hub_url, get_hf_file_metadata | |
| from huggingface_hub import ModelCard | |
| from huggingface_hub.hf_api import ModelInfo | |
| from transformers import AutoConfig | |
| from transformers.models.auto.tokenization_auto import AutoTokenizer | |
| from src.envs import EVAL_REQUESTS_SUBGRAPH, EVAL_REQUESTS_CAUSALGRAPH | |
# Task identifiers that are searched for in circuit-track directory names
# (see verify_circuit_submission).
TASKS = ["ioi", "mcqa", "arithmetic-addition", "arithmetic-subtraction", "arc-easy", "arc-challenge"]
# Model identifiers that are searched for in circuit-track directory names
# (see verify_circuit_submission).
MODELS = ["gpt2", "qwen2.5", "gemma2", "llama3", "interpbench"]
class FeaturizerValidator:
    """Validate an uploaded Python file that should define a featurizer.

    Validation happens in stages: a static (AST) search for a class that
    lists the expected base class by name, a static search for classes whose
    base is written as ``<something>.Module`` (e.g. ``torch.nn.Module``), and
    finally a dynamic import that checks real inheritance with ``issubclass``.

    Attributes set by the instance:
        base_featurizer_class: class the uploaded featurizer must extend.
        featurizer_class_name: name of the featurizer subclass found by the
            last successful ``validate_uploaded_module`` lookup, else None.
        featurizer_module_class_name_1 / _2: names of the first and second
            ``*.Module`` subclasses found by the last call to
            ``find_featurizer_module_classes`` (None when absent).
    """

    def __init__(self, base_featurizer_class):
        self.base_featurizer_class = base_featurizer_class
        self.featurizer_class_name = None
        # torch.nn.Module
        self.module_value, self.module_attr = "torch", "Module"
        self.featurizer_module_class_name_1 = None
        self.featurizer_module_class_name_2 = None

    @staticmethod
    def _collect_class_defs(module_path: str):
        """Parse ``module_path`` without executing it and describe its classes.

        Returns:
            A list of ``(class_name, plain_bases, dotted_bases)`` tuples in
            ``ast.walk`` order. ``plain_bases`` holds bases written as a bare
            name (``Featurizer``); ``dotted_bases`` holds the final attribute
            of bases written with dots (``Module`` for ``torch.nn.Module``).

        Raises:
            OSError / SyntaxError: propagated to the caller, which turns them
            into a failure message.
        """
        # Imported here because the module does not import ``ast`` at file level.
        import ast

        with open(module_path, "r", encoding="utf-8") as file:
            tree = ast.parse(file.read(), filename=module_path)

        class_defs = []
        for node in ast.walk(tree):
            if not isinstance(node, ast.ClassDef):
                continue
            plain_bases = [base.id for base in node.bases if isinstance(base, ast.Name)]
            dotted_bases = [base.attr for base in node.bases if isinstance(base, ast.Attribute)]
            class_defs.append((node.name, plain_bases, dotted_bases))
        return class_defs

    def find_featurizer_subclass(self, module_path: str) -> Tuple[bool, Union[str, None], str]:
        """
        Finds the first class in the module that inherits from Featurizer.
        Args:
            module_path: Path to the uploaded Python file
        Returns:
            Tuple of (success, class_name, message); class_name is None on failure.
        """
        try:
            base_name = self.base_featurizer_class.__name__
            for class_name, plain_bases, _ in self._collect_class_defs(module_path):
                if base_name in plain_bases:
                    return True, class_name, f"Found class '{class_name}' that inherits from {base_name}"
            return False, None, f"No class inheriting from {base_name} found"
        except Exception as e:
            return False, None, f"Error during static analysis: {str(e)}"

    def find_featurizer_module_classes(self, module_path: str) -> Tuple[bool, str]:
        """
        Finds up to two classes whose base is written as ``<x>.Module``.
        Args:
            module_path: Path to the uploaded Python file
        Returns:
            Tuple of (found_at_least_one, message). The class names are also
            stored in ``featurizer_module_class_name_1`` / ``_2``.
        """
        # Reset first: results from a previously validated file must not leak
        # into this one when the validator instance is reused.
        self.featurizer_module_class_name_1 = None
        self.featurizer_module_class_name_2 = None
        try:
            module_classes = [
                class_name
                for class_name, _, dotted_bases in self._collect_class_defs(module_path)
                if self.module_attr in dotted_bases
            ]
        except Exception as e:
            return False, f"Error during static analysis: {e}"

        if not module_classes:
            return False, "Found no featurizer modules."
        self.featurizer_module_class_name_1 = module_classes[0]
        if len(module_classes) == 1:
            return True, f"Found one featurizer module: {self.featurizer_module_class_name_1}"
        self.featurizer_module_class_name_2 = module_classes[1]
        return True, f"Found two featurizer modules: {self.featurizer_module_class_name_1}, {self.featurizer_module_class_name_2}"

    def validate_uploaded_module(self, module_path: str) -> Tuple[bool, str]:
        """
        Validates an uploaded module to ensure it properly extends the Featurizer class.
        Args:
            module_path: Path to the uploaded Python file
        Returns:
            Tuple of (is_valid, message)
        """
        # First, find the name of the featurizer class we're verifying
        found, class_name, message = self.find_featurizer_subclass(module_path)
        if not found:
            return False, message
        self.featurizer_class_name = class_name
        print("Verified featurizer subclass.")

        # Second, find the name of the featurizer and inverse featurizer modules
        modules_found, modules_message = self.find_featurizer_module_classes(module_path)
        if not modules_found:
            return False, modules_message
        print(f"Verified featurizer module(s): {modules_message}")

        # Then, perform static code analysis on the featurizer class for basic safety
        inheritance_check, ast_message = self._verify_inheritance_with_ast(module_path, class_name)
        if not inheritance_check:
            return False, ast_message

        # Then, try to load and validate the featurizer class
        return self._verify_inheritance_with_import(module_path, class_name)
        # TODO: try directly loading featurizer module and inverse featurizer module?

    def _verify_inheritance_with_ast(self, module_path: str, class_name: str) -> Tuple[bool, str]:
        """Verify inheritance using AST without executing code.

        The first class named ``class_name`` decides the outcome.
        """
        try:
            base_name = self.base_featurizer_class.__name__
            for found_name, plain_bases, _ in self._collect_class_defs(module_path):
                if found_name != class_name:
                    continue
                if base_name in plain_bases:
                    return True, "Static analysis indicates proper inheritance"
                return False, f"Class '{class_name}' does not appear to inherit from {base_name}"
            return False, f"Class '{class_name}' not found in the uploaded module"
        except Exception as e:
            return False, f"Error during static analysis: {str(e)}"

    def _verify_inheritance_with_import(self, module_path: str, class_name: str) -> Tuple[bool, str]:
        """Import the module and verify inheritance using Python's introspection.

        NOTE(security): ``exec_module`` runs the uploaded file's top-level
        code in this process. The AST checks above only look at class
        headers; they do not make execution of untrusted code safe. Consider
        running this step in a sandbox.
        """
        # Imported here because the module does not import these at file level.
        import importlib.util
        import inspect

        try:
            # Dynamically import the module
            spec = importlib.util.spec_from_file_location("uploaded_module", module_path)
            if spec is None or spec.loader is None:
                return False, "Could not load the module specification"
            uploaded_module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(uploaded_module)

            # Get the class from the module
            if not hasattr(uploaded_module, class_name):
                return False, f"Class '{class_name}' not found in the uploaded module"
            uploaded_class = getattr(uploaded_module, class_name)

            # Check if it's a proper subclass
            if not inspect.isclass(uploaded_class):
                return False, f"'{class_name}' is not a class"
            if not issubclass(uploaded_class, self.base_featurizer_class):
                return False, f"'{class_name}' does not inherit from {self.base_featurizer_class.__name__}"

            # issubclass() can also be satisfied by virtual (ABC-registered)
            # subclasses, so additionally require the base in the real MRO.
            if self.base_featurizer_class not in inspect.getmro(uploaded_class):
                return False, f"{self.base_featurizer_class.__name__} not in the method resolution order"
            return True, f"Class '{class_name}' properly extends {self.base_featurizer_class.__name__}"
        except Exception as e:
            return False, f"Error during dynamic validation: {str(e)}"
def is_model_on_hub(model_name: str, revision: str, token: Union[str, None] = None, trust_remote_code=False, test_tokenizer=False) -> Tuple[bool, Union[str, None], object]:
    """Check that ``model_name`` can be loaded from the hub with the Auto classes.

    Args:
        model_name: Repository id passed to ``AutoConfig.from_pretrained``.
        revision: Revision passed through to the loaders.
        token: Optional access token passed through to the loaders.
        trust_remote_code: Passed through to the loaders.
        test_tokenizer: When true, also try to load the tokenizer.

    Returns:
        A 3-tuple ``(ok, message, config)``. On success ``message`` is None
        and ``config`` is the loaded configuration. On failure ``config`` is
        None and ``message`` is a sentence fragment (it starts mid-sentence,
        so callers are expected to prefix it).
    """
    try:
        config = AutoConfig.from_pretrained(model_name, revision=revision, trust_remote_code=trust_remote_code, token=token)
    except ValueError:
        return (
            False,
            "needs to be launched with `trust_remote_code=True`. For safety reason, we do not allow these models to be automatically submitted to the leaderboard.",
            None
        )
    except Exception:
        return False, "was not found on hub!", None

    if not test_tokenizer:
        return True, None, config

    try:
        # Only loadability matters; the tokenizer object itself is discarded.
        AutoTokenizer.from_pretrained(model_name, revision=revision, trust_remote_code=trust_remote_code, token=token)
    except ValueError as e:
        return (
            False,
            f"uses a tokenizer which is not in a transformers release: {e}",
            None
        )
    except Exception:
        return (False, "'s tokenizer cannot be loaded. Is your tokenizer class in a stable transformers release, and correctly configured?", None)
    return True, None, config
def get_model_size(model_info: ModelInfo, precision: str):
    """Return the model size in billions of parameters, or 0 when unknown.

    The size is read from ``model_info.safetensors["total"]`` and rounded to
    three decimals. It is multiplied by 8 when ``precision`` is ``"GPTQ"`` or
    the model id contains ``"gptq"`` (case-insensitive).
    """
    try:
        billions = round(model_info.safetensors["total"] / 1e9, 3)
    except (AttributeError, TypeError):
        # Unknown model sizes are indicated as 0, see NUMERIC_INTERVALS in app.py
        return 0

    is_gptq = precision == "GPTQ" or "gptq" in model_info.modelId.lower()
    multiplier = 8 if is_gptq else 1
    return multiplier * billions
def get_model_arch(model_info: ModelInfo):
    """Return the ``"architectures"`` entry of the model's config.

    Returns ``"Unknown"`` when the entry is absent, and also when
    ``model_info`` carries no config at all (``config`` missing or None),
    instead of raising ``AttributeError``.
    """
    config = getattr(model_info, "config", None) or {}
    return config.get("architectures", "Unknown")
def already_submitted_models(requested_models_dir: str) -> tuple[set[str], defaultdict]:
    """Gather already submitted models to avoid duplicates.

    Only ``.json`` files located exactly one directory below
    ``requested_models_dir`` are read.

    Args:
        requested_models_dir: Root directory of the stored request files.

    Returns:
        A 2-tuple ``(submission_keys, users_to_submission_dates)``:
        ``submission_keys`` is a set of ``"{model}_{revision}_{track}"``
        strings; ``users_to_submission_dates`` maps the organisation part of
        ``model`` to the list of ``submitted_time`` values found for it.

    Raises:
        KeyError: if a request file lacks ``model``, ``revision`` or ``track``.
    """
    depth = 1
    file_names = []
    users_to_submission_dates = defaultdict(list)

    # Normalise so that a trailing separator does not skew the depth count.
    base_dir = os.path.normpath(requested_models_dir)
    base_depth = base_dir.count(os.sep)

    for root, _, files in os.walk(base_dir):
        if root.count(os.sep) - base_depth != depth:
            continue
        for file in files:
            if not file.endswith(".json"):
                continue
            with open(os.path.join(root, file), "r", encoding="utf-8") as f:
                info = json.load(f)
            file_names.append(f"{info['model']}_{info['revision']}_{info['track']}")

            # Select organisation
            if "/" not in info["model"] or "submitted_time" not in info:
                continue
            # maxsplit=1 keeps this safe if the name contains further slashes.
            organisation, _ = info["model"].split("/", 1)
            users_to_submission_dates[organisation].append(info["submitted_time"])

    return set(file_names), users_to_submission_dates
def _format_time(earliest_time):
    """Format the time left until one week after ``earliest_time``.

    ``earliest_time`` must be timezone-aware (it is converted to UTC). The
    result is ``"HH:MM:SS"``, prefixed with ``"<n> days, "`` when at least
    one whole day remains.
    """
    unlock_at = earliest_time.tz_convert("UTC") + timedelta(weeks=1)
    remaining = unlock_at - pd.Timestamp.utcnow()

    # .seconds is the sub-day remainder; whole days are reported separately.
    hours, leftover = divmod(remaining.seconds, 3600)
    minutes, seconds = divmod(leftover, 60)
    clock = f"{hours:02}:{minutes:02}:{seconds:02}"

    if remaining.days > 0:
        return f"{remaining.days} days, {clock}"
    return clock
def get_evaluation_queue_df(save_path: str, cols: list) -> pd.DataFrame:
    """Load every evaluation request under ``save_path`` into one dataframe.

    ``.json`` files directly inside ``save_path`` are loaded, as are ``.json``
    files directly inside each of its sub-folders. Entries whose name starts
    with ``"."`` and all other files (e.g. README.md) are ignored.

    Args:
        save_path: Directory holding the request files.
        cols: Currently unused; kept so existing callers keep working.

    Returns:
        A dataframe with one row per request file (empty if none are found).
        Rows are loaded in sorted filename order so the result is
        deterministic.
    """

    def _load(path: str):
        with open(path, encoding="utf-8") as fp:
            return json.load(fp)

    all_evals = []
    for entry in sorted(os.listdir(save_path)):
        if entry.startswith("."):
            continue
        entry_path = os.path.join(save_path, entry)

        if os.path.isdir(entry_path):
            for sub_entry in sorted(os.listdir(entry_path)):
                sub_path = os.path.join(entry_path, sub_entry)
                # The file check must use the full path; a bare file name
                # would be resolved against the current working directory.
                if sub_entry.startswith(".") or not sub_entry.endswith(".json"):
                    continue
                if os.path.isfile(sub_path):
                    all_evals.append(_load(sub_path))
        elif entry.endswith(".json"):
            all_evals.append(_load(entry_path))

    return pd.DataFrame(all_evals)
def _recent_submit_times(queue, column, value, cutoff):
    """Return the UTC submit times of rows where ``column == value`` that are not older than ``cutoff``."""
    if column not in queue.columns or "submit_time" not in queue.columns:
        # Nothing recorded for this field, so there is nothing to count.
        return pd.Series(dtype="datetime64[ns, UTC]")
    # Convert a selected copy instead of assigning into a filtered slice of
    # the queue (which triggers pandas' chained-assignment warning).
    submit_times = pd.to_datetime(queue.loc[queue[column] == value, "submit_time"], utc=True)
    return submit_times[submit_times >= cutoff]


def check_rate_limit(track, user_name, contact_email):
    """Check whether a user may submit again (fewer than 2 submissions in the past week).

    Submissions are counted separately by ``user_name`` and by lower-cased
    ``contact_email``; the user-name count is checked first. The user names
    ``"atticusg"`` and ``"yiksiu"`` are exempt.

    Args:
        track: Track label; the subgraph queue is used when it contains
            ``"Circuit"``, otherwise the causal-graph queue.
        user_name: Name of the submitting user.
        contact_email: Contact e-mail of the submitting user.

    Returns:
        ``(True, None)`` when submitting is allowed, otherwise
        ``(False, time_left)`` where ``time_left`` is the formatted time until
        the earliest counted submission is one week old.
    """
    if "Circuit" in track:
        save_path = EVAL_REQUESTS_SUBGRAPH
    else:
        save_path = EVAL_REQUESTS_CAUSALGRAPH

    evaluation_queue = get_evaluation_queue_df(save_path, ["user_name", "contact_email"])
    if evaluation_queue.empty or user_name in ("atticusg", "yiksiu"):
        return True, None

    one_week_ago = pd.Timestamp.utcnow() - timedelta(weeks=1)
    checks = (("user_name", user_name), ("contact_email", contact_email.lower()))
    for column, value in checks:
        recent = _recent_submit_times(evaluation_queue, column, value, one_week_ago)
        if recent.shape[0] >= 2:
            return False, _format_time(recent.min())
    return True, None
def parse_huggingface_url(url: str):
    """
    Extracts repo_id, subfolder path and revision from a Hugging Face URL.

    Args:
        url: Either a full ``http(s)://`` URL or a bare repo id
            (``user/repo``). Surrounding whitespace is ignored.

    Returns:
        Always a 3-tuple ``(repo_id, folder_path, revision)``:
        - bare repo id: ``(url, None, "main")``
        - URL without ``user/repo`` in its path: ``(None, None, None)``
        - URL with ``/tree/<rev>/...`` or ``/blob/<rev>/...``: the path after
          the revision as ``folder_path`` (``""`` if nothing follows)
        - any other URL: ``(repo_id, None, "main")``
    """
    url = url.strip()

    # Handle cases where the input is already a repo_id (no URL). Return three
    # values here as well so callers can unpack every result the same way.
    if not url.startswith(("http://", "https://")):
        return url, None, "main"

    path_parts = urlparse(url).path.strip("/").split("/")

    # Extract repo_id (username/repo_name)
    if len(path_parts) < 2:
        return None, None, None  # Can't extract repo_id
    repo_id = f"{path_parts[0]}/{path_parts[1]}"

    folder_path = None
    revision = "main"

    # Look for the /tree/ or /blob/ marker only after the repo id, so an
    # owner or repo that is literally named "tree"/"blob" is not mistaken
    # for it.
    marker_idx = next(
        (idx for idx in range(2, len(path_parts)) if path_parts[idx] in ("tree", "blob")),
        None,
    )
    # A marker without a following revision segment carries no information.
    if marker_idx is not None and marker_idx + 1 < len(path_parts):
        revision = path_parts[marker_idx + 1]
        folder_path = "/".join(path_parts[marker_idx + 2:])  # Skip "tree/main" or "blob/main"

    return repo_id, folder_path, revision
def validate_directory_circuit(fs: HfFileSystem, repo_id: str, dirname: str, curr_tm: str, circuit_level:Literal['edge', 'node','neuron']='edge'):
    """Check the circuit files (``.json`` / ``.pt``) inside one task/model directory.

    Args:
        fs: File system used to list ``dirname``.
        repo_id: Repository id; listed file names are split on
            ``repo_id + "/"`` to obtain the in-repo file name.
        dirname: Directory to inspect.
        curr_tm: ``"<task>_<model>"`` label; must contain exactly one ``"_"``.
        circuit_level: Not used by this function.

    Returns:
        ``(errors, warnings)`` as two lists of strings. An empty directory
        yields a warning; 2 to 8 files yield an error; files over 50MB yield
        a warning each.
    """
    errors, warnings = [], []
    # Unpacking doubles as a format check: it raises ValueError unless
    # curr_tm has exactly one underscore.
    _task, _model = curr_tm.split("_")
    label = curr_tm.replace("_", "/")

    circuit_files = [
        listing["name"]
        for listing in fs.ls(dirname)
        if listing["name"].endswith((".json", ".pt"))
    ]

    # Detect whether multi-circuit or importances
    n_files = len(circuit_files)
    is_multiple_circuits = n_files > 1
    if n_files == 0:
        warnings.append(f"Directory present for {label} but is empty")
    elif is_multiple_circuits and n_files < 9:
        errors.append(f"Folder for {label} contains multiple circuits, but not enough. If you intended to submit importances, include only one circuit in the folder. Otherwise, please add the rest of the circuits.")

    skipped = 0
    for position, path in enumerate(circuit_files):
        in_repo_name = path.split(repo_id + "/")[1]
        metadata = get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename=in_repo_name))
        size_mb = metadata.size / (1024 * 1024)
        if size_mb > 50:
            warnings.append(f"Will skip file >50MB: {path}")
            skipped += 1
            continue
        # Stop once nine usable files of a multi-circuit folder were seen.
        if is_multiple_circuits and position - skipped >= 9:
            break

    return errors, warnings
def verify_circuit_submission(hf_repo, level, progress=gr.Progress()):
    """Validate a circuit-track submission hosted on the Hugging Face hub.

    Directories are matched to a task/model pair by name, then each matched
    directory is checked with ``validate_directory_circuit``. Directories
    whose name contains ``"abs"`` and ``"true"``/``"false"`` are descended
    into instead of being matched.

    Args:
        hf_repo: Hugging Face URL or bare repo id of the submission. A URL
            without a sub-folder (or a bare repo id) validates the repo root.
        level: Circuit level forwarded to ``validate_directory_circuit``.
        progress: Gradio progress tracker used to wrap the directory loop.

    Returns:
        ``(errors, warnings)`` as two lists of strings. At least two valid
        tasks and two valid models are required; fewer produce errors.
    """
    VALID_COMBINATIONS = [
        "ioi_gpt2", "ioi_qwen2.5", "ioi_gemma2", "ioi_llama3", "ioi_interpbench",
        "mcqa_qwen2.5", "mcqa_gemma2", "mcqa_llama3",
        "arithmetic-addition_llama3", "arithmetic-subtraction_llama3",
        "arc-easy_gemma2", "arc-easy_llama3",
        "arc-challenge_llama3"
    ]

    errors = []
    warnings = []
    directories_present = {tm: False for tm in VALID_COMBINATIONS}
    directories_valid = {tm: False for tm in VALID_COMBINATIONS}

    fs = HfFileSystem()

    try:
        parsed = parse_huggingface_url(hf_repo)
        repo_id, folder_path = parsed[0], parsed[1]
        # Tolerate a 2-value result (bare repo id) by assuming the default branch.
        revision = parsed[2] if len(parsed) > 2 else "main"
        if repo_id is None:
            raise ValueError("expected a URL of the form https://huggingface.co/<user>/<repo>")
        # folder_path is None/"" when the link points at the repo root;
        # concatenating it blindly would raise or leave a trailing slash.
        listing_path = f"{repo_id}/{folder_path}" if folder_path else repo_id
        files = fs.listdir(listing_path, revision=revision)
    except Exception as e:
        errors.append(f"Could not open Huggingface URL: {e}")
        return errors, warnings

    def _process_directory(files, current_path="", file_counts=0):
        """Recursively process directories, handling abs/True/False subdirectories"""
        for dirname in progress.tqdm(files, desc=f"Validating directories in {current_path or 'repo'}"):
            file_counts += 1
            if file_counts >= 30:
                warnings.append("Folder contains many files/directories; stopped at 30.")
                break

            circuit_dir = dirname["name"]
            dirname_proc = circuit_dir.lower().split("/")[-1]
            if not fs.isdir(circuit_dir):
                continue

            # Check if this directory contains "abs" and "True"/"False"
            if "abs" in dirname_proc and ("true" in dirname_proc or "false" in dirname_proc):
                try:
                    # Recurse into this directory
                    subdirs = fs.listdir(circuit_dir, revision=revision)
                    _process_directory(subdirs, circuit_dir, file_counts)
                except Exception as e:
                    warnings.append(f"Could not access subdirectory {circuit_dir}: {e}")
                continue

            # Look for task and model names in the directory name (first match wins)
            curr_task = next(
                (task for task in TASKS if dirname_proc.startswith(task) or f"_{task}" in dirname_proc),
                None,
            )
            curr_model = next(
                (model for model in MODELS if dirname_proc.startswith(model) or f"_{model}" in dirname_proc),
                None,
            )
            if curr_task is None or curr_model is None:
                continue

            curr_tm = f"{curr_task}_{curr_model}"
            if curr_tm not in directories_present:
                continue
            directories_present[curr_tm] = True

            # Parse circuits directory
            print(f"validating {circuit_dir}")
            vd_errors, vd_warnings = validate_directory_circuit(fs, repo_id, circuit_dir, curr_tm, level)
            errors.extend(vd_errors)
            warnings.extend(vd_warnings)

            if len(vd_errors) == 0:
                directories_valid[curr_tm] = True

    # Start the recursive processing
    _process_directory(files)

    task_set, model_set = set(), set()
    for tm, present in directories_present.items():
        if not present:
            continue
        if not directories_valid[tm]:
            warnings.append(f"Directory found for {tm.replace('_', '/')}, but circuits not valid or present")
            continue
        task, model = tm.split("_")
        task_set.add(task)
        model_set.add(model)

    if len(task_set) < 2:
        errors.append("At least 2 tasks are required")
    if len(model_set) < 2:
        errors.append("At least 2 models are required")

    no_tm_display = [tm.replace("_", "/") for tm in directories_valid if not directories_valid[tm]]
    if len(no_tm_display) > 0:
        warnings.append(f"No valid circuits or importance scores found for the following tasks/models: {*no_tm_display,}")

    return errors, warnings
def validate_directory_causalgraph(fs: HfFileSystem, repo_id: str, dirname: str):
    """Check that a directory holds a featurizer / inverse featurizer / indices triplet.

    Files whose name contains ``"_featurizer"`` or ``"_indices"`` are
    considered. A triplet is three files sharing a stem and ending in
    ``_featurizer``, ``_inverse_featurizer`` and ``_indices``. Checking stops
    at the first complete triplet, whose stem must mention a submodule
    (``residual`` or ``attention``), ``layer`` and ``token``.

    Args:
        fs: File system used to list ``dirname``.
        repo_id: Repository id; listed file names are split on
            ``repo_id + "/"`` to obtain the in-repo file name.
        dirname: Directory to inspect.

    Returns:
        ``(errors, warnings)`` as two lists of strings.
    """
    errors = []
    warnings = []

    suffixes = ("_featurizer", "_inverse_featurizer", "_indices")
    # "_inverse_featurizer" also ends with "_featurizer", so the longer
    # suffix has to be tried first when working out a file's own suffix.
    suffix_match_order = ("_inverse_featurizer", "_featurizer", "_indices")

    files = fs.ls(dirname)
    files = [f["name"] for f in files if "_featurizer" in f["name"] or "_indices" in f["name"]]
    available = set(files)
    reported_stems = set()

    valid_triplet = False
    offset = 0
    for idx, file in enumerate(files):
        file_suffix = file.split(repo_id + "/")[1]
        file_url = hf_hub_url(repo_id=repo_id, filename=file_suffix)
        file_info = get_hf_file_metadata(file_url)
        file_size_mb = file_info.size / (1024 * 1024)
        if file_size_mb > 50:
            warnings.append(f"Will skip file >50MB: {file}")
            offset -= 1
            continue

        if idx + offset > 30:
            warnings.append("Many files in directory; stopping at 30")
            break

        this_suffix = next((s for s in suffix_match_order if file.endswith(s)), None)
        if this_suffix is None:
            continue
        prefix = file[: -len(this_suffix)]

        # Report only the first missing member, once per stem.
        missing = next((s for s in suffixes if prefix + s not in available), None)
        if missing is not None:
            if prefix not in reported_stems:
                reported_stems.add(prefix)
                warnings.append(f"For {prefix}, found a {this_suffix} file but no associated {missing}")
            continue

        valid_triplet = True
        # NOTE(review): prefix is the full listed path, so directory names
        # can satisfy these checks too; confirm whether only the file name
        # should count.
        location = prefix.lower()
        found_submodule = "residual" in location or "attention" in location
        found_layer = "layer" in location
        found_token = "token" in location
        if not found_submodule or not found_layer or not found_token:
            errors.append("Could not derive where featurizer should be applied from featurizer filenames.")
        break

    if not valid_triplet:
        errors.append("No valid featurizer/inverse featurizer/indices triplets.")

    return errors, warnings
def verify_causal_variable_submission(hf_repo, progress=gr.Progress()):
    """Validate a causal-variable-track submission hosted on the Hugging Face hub.

    Each directory in the listed folder is matched by name to a task, one of
    that task's variables, and a model class name (case-insensitive). Matched
    directories for a valid task/model pair are checked with
    ``validate_directory_causalgraph``. ``.py`` files in the listed folder
    are counted to warn when custom featurizer code is missing.

    Args:
        hf_repo: Hugging Face URL or bare repo id of the submission. A URL
            without a sub-folder (or a bare repo id) validates the repo root.
        progress: Gradio progress tracker used to wrap the file loop.

    Returns:
        ``(errors, warnings)`` as two lists of strings. An error is reported
        when no directory is valid for any task/model pair.
    """
    CV_TASK_VARIABLES = {"ioi_task": ["output_token", "output_position"],
                         "4_answer_MCQA": ["answer_pointer", "answer"],
                         "ARC_easy": ["answer_pointer", "answer"],
                         "arithmetic": ["ones_carry"],
                         "ravel_task": ["Country", "Continent", "Language"]}
    # Ordered sequences (not sets) so matching and message order are
    # deterministic across runs.
    CV_MODELS = ("GPT2LMHeadModel", "Qwen2ForCausalLM", "Gemma2ForCausalLM", "LlamaForCausalLM")
    # create pairs of valid task/model combinations
    CV_VALID_TASK_MODELS = (("ioi_task", "GPT2LMHeadModel"),
                            ("ioi_task", "Qwen2ForCausalLM"),
                            ("ioi_task", "Gemma2ForCausalLM"),
                            ("ioi_task", "LlamaForCausalLM"),
                            ("4_answer_MCQA", "Qwen2ForCausalLM"),
                            ("4_answer_MCQA", "Gemma2ForCausalLM"),
                            ("4_answer_MCQA", "LlamaForCausalLM"),
                            ("ARC_easy", "Gemma2ForCausalLM"),
                            ("ARC_easy", "LlamaForCausalLM"),
                            ("arithmetic", "Gemma2ForCausalLM"),
                            ("arithmetic", "LlamaForCausalLM"),
                            ("ravel_task", "Gemma2ForCausalLM"),
                            ("ravel_task", "LlamaForCausalLM"))

    errors = []
    warnings = []

    num_py_files = 0
    directories_present = {tm: False for tm in CV_VALID_TASK_MODELS}
    directories_valid = {tm: False for tm in CV_VALID_TASK_MODELS}
    variables_valid = {}

    fs = HfFileSystem()

    try:
        parsed = parse_huggingface_url(hf_repo)
        repo_id, folder_path = parsed[0], parsed[1]
        # Tolerate a 2-value result (bare repo id) by assuming the default branch.
        revision = parsed[2] if len(parsed) > 2 else "main"
        if repo_id is None:
            raise ValueError("expected a URL of the form https://huggingface.co/<user>/<repo>")
        # folder_path is None/"" when the link points at the repo root;
        # concatenating it blindly would raise or leave a trailing slash.
        listing_path = f"{repo_id}/{folder_path}" if folder_path else repo_id
        files = fs.listdir(listing_path, revision=revision)
    except Exception as e:
        errors.append(f"Could not open Huggingface URL: {e}")
        return errors, warnings

    def _mentions(dirname_proc, name):
        """True when the lower-cased directory name starts with, or contains "_" + name."""
        key = name.lower()
        return dirname_proc.startswith(key) or f"_{key}" in dirname_proc

    file_counts = 0
    for file in progress.tqdm(files, desc="Validating files in repo"):
        filename = file["name"]
        file_counts += 1
        if file_counts >= 30:
            warnings.append("Folder contains many files/directories; stopped at 30.")
            break

        if filename.endswith(".py"):
            num_py_files += 1

        causalgraph_dir = filename
        dirname_proc = causalgraph_dir.lower().split("/")[-1]
        if not fs.isdir(causalgraph_dir):
            continue

        curr_task = None
        curr_model = None
        curr_variable = None

        # Look for task names in filename
        for task, task_variables in CV_TASK_VARIABLES.items():
            if not _mentions(dirname_proc, task):
                continue
            curr_task = task
            if curr_task not in variables_valid:
                variables_valid[curr_task] = {v: False for v in task_variables}
            # First matching variable wins, so more specific names (e.g.
            # "answer_pointer") must be listed before their prefixes.
            # NOTE(review): "_answer" also occurs inside "4_answer_mcqa", so
            # the "answer" variable matches any MCQA directory; confirm this
            # is intended.
            for variable in task_variables:
                if _mentions(dirname_proc, variable) or f"_{variable.lower().replace('_', '-')}" in dirname_proc:
                    curr_variable = variable
                    break

        # Look for model names in filename
        for model in CV_MODELS:
            if _mentions(dirname_proc, model):
                curr_model = model

        if curr_task is None or curr_model is None or curr_variable is None:
            continue
        curr_tm = (curr_task, curr_model)
        if curr_tm not in directories_present:
            continue
        directories_present[curr_tm] = True

        print(f"validating {causalgraph_dir}")
        vd_errors, vd_warnings = validate_directory_causalgraph(fs, repo_id, causalgraph_dir)
        errors.extend(vd_errors)
        warnings.extend(vd_warnings)

        if len(vd_errors) == 0:
            directories_valid[curr_tm] = True
            variables_valid[curr_task][curr_variable] = True

    if num_py_files == 0:
        warnings.append("No featurizer.py or token_position.py files detected in root of provided repo. We will load from the code used for baseline evaluations.")
    elif num_py_files == 1:
        warnings.append("Either featurizer.py or token_position.py files missing in root of provided repo. We will load from the code used for baseline evaluations.")

    task_set, model_set = set(), set()
    for tm, present in directories_present.items():
        if present and not directories_valid[tm]:
            warnings.append(f"Directory found for {tm[0]}/{tm[1]}, but contents not valid")

    for (task, model), is_valid in directories_valid.items():
        if is_valid:
            task_set.add(task)
            model_set.add(model)

    if len(task_set) == 0 or len(model_set) == 0:
        errors.append("No valid directories found for any task/model.")

    for task, variable_status in variables_valid.items():
        found_variable_display = [v for v in variable_status if variable_status[v]]
        no_variable_display = [v for v in variable_status if not variable_status[v]]
        if no_variable_display:
            warnings.append(f"For {task}, found variables {*found_variable_display,}, but not variables {*no_variable_display,}")

    return errors, warnings