| """config.py. |
| |
| This module provides a config class to be used for both the parser as well as |
| for providing the model specific classes a way to access the parsed arguments. |
| """ |
| import argparse |
| import logging |
| import os |
| import sys |
| from enum import Enum |
| from pathlib import Path |
| from typing import List, Optional |
|
|
| import regex as re |
| import torch |
| import yaml |
| from datasets import load_dataset, load_from_disk |
|
|
| PROJECT_ROOT = Path(__file__).resolve().parents[2] |
| sys.path.append(str(PROJECT_ROOT)) |
|
|
|
|
class ModelSelection(str, Enum):
    """Enum that contains all possible model choices.

    Inherits from ``str`` so members compare equal to (and serialize as)
    their plain string values, which lets yaml/CLI strings like 'llava'
    be used interchangeably with the enum members.

    NOTE: member order is significant — it determines the order of the
    choices listed in the CLI help (see Config's `model_sel`).
    """
    LLAVA = 'llava'
    QWEN = 'qwen'
    CLIP = 'clip'
    GLAMM = 'glamm'
    JANUS = 'janus'
    BLIP2 = 'blip2'
    MOLMO = 'molmo'
    PALIGEMMA = 'paligemma'
    INTERNLM_XC = 'internlm-xcomposer'
    INTERNVL = 'internvl'
    MINICPM = 'minicpm'
    COGVLM = 'cogvlm'
    PIXTRAL = 'pixtral'
    AYA_VISION = 'aya-vision'
    PLM = 'plm'
|
|
|
|
class Config:
    """Config class for both yaml and cli arguments.

    Values come from three places, applied in increasing precedence:
    a yaml config file (``--config``), then command-line flags, then the
    keyword arguments passed to ``__init__`` (which become argparse
    defaults and therefore flow through the CLI pass).
    """

    def __init__(self,
                 architecture: Optional[str] = None,
                 model_path: Optional[str] = None,
                 module: Optional[str] = None,
                 prompt: Optional[str] = None) -> None:
        """Verifies the passed arguments while populating config fields.

        Args:
            architecture (Optional[str]): The model architecture to use.
            model_path (Optional[str]): The specific model path to use.
            module (Optional[str]): The specific module to extract embeddings
                from.
            prompt (Optional[str]): The prompt to use for models that require
                it.

        Raises:
            ValueError: If any required argument is missing or invalid.
        """
        parser = argparse.ArgumentParser()
        parser.add_argument(
            '-c',
            '--config',
            type=str,
            help='Path to a yaml config file providing any of the other options'
        )

        model_sel = [model.value for model in ModelSelection]
        parser.add_argument(
            '-a',
            '--architecture',
            type=ModelSelection,
            choices=list(ModelSelection),
            metavar=f'{model_sel}',
            default=architecture,
            help='The model architecture family to extract the embeddings from'
        )
        parser.add_argument(
            '-m',
            '--model-path',
            type=str,
            default=model_path,
            help='The specific model path to extract the embeddings from'
        )
        # Boolean flags default to None (not False) so that "flag absent on
        # the CLI" can be told apart from "flag disabled"; this lets a yaml
        # value survive the CLI-override pass below.
        parser.add_argument(
            '-d',
            '--debug',
            default=None,
            action='store_true',
            help='Print out debug statements'
        )
        parser.add_argument(
            '-l',
            '--log-named-modules',
            default=None,
            action='store_true',
            help='Logs the named modules for the specified model'
        )
        parser.add_argument(
            '-i',
            '--input-dir',
            type=str,
            help='The specified input directory to read data from'
        )
        parser.add_argument(
            '-o',
            '--output-db',
            type=str,
            help=(
                'The specified output database to save the tensors to, '
                'defaults to embeddings.db'
            )
        )
        parser.add_argument(
            '--device',
            type=str,
            default='cuda',
            choices=['cuda', 'cpu'],
            help='Specify the device to send tensors and the model to'
        )
        parser.add_argument(
            '--download-path',
            type=str,
            help='The path where downloaded models should be stored'
        )
        parser.add_argument(
            '--pooling-method',
            type=str,
            default=None,
            choices=['mean', 'max'],
            help='The type of pooling to use for the output embeddings'
        )

        # parse_known_args so unrelated flags (e.g. from a test runner or a
        # wrapping script) do not abort the program.
        args = parser.parse_known_args()[0]

        # All recognized config fields: every CLI flag plus the yaml-only
        # keys that have no CLI equivalent.
        config_keys = list(vars(args)) + [
            'model', 'prompt', 'modules', 'forward', 'dataset'
        ]

        # yaml values are applied first ...
        if args.config:
            with open(args.config, 'r') as file:
                data = yaml.safe_load(file)
            for key in config_keys:
                if key in data:
                    setattr(self, key, data[key])

        # ... then CLI values (and constructor defaults) override them.
        # Anything the parser left as None is skipped so yaml wins there.
        for key, value in args._get_kwargs():
            if value is not None:
                setattr(self, key, value)

        self.debug = bool(getattr(self, 'debug', False))
        logging.getLogger().setLevel(
            logging.DEBUG if self.debug else logging.INFO)

        # Required fields. Raise ValueError (not assert) so the check still
        # runs under `python -O` and matches the documented contract above.
        if any(getattr(self, attr, None) is None
               for attr in ('architecture', 'model_path')):
            raise ValueError(
                'Fields `architecture` and `model_path` in yaml config must exist, '
                'otherwise, --architecture and --model-path must be set'
            )

        # yaml supplies the architecture as a plain string; normalize it to
        # the enum (CLI values arrive already converted by argparse).
        if not isinstance(self.architecture, ModelSelection):
            if self.architecture not in model_sel:
                raise ValueError(
                    f'Architecture {self.architecture} not supported, '
                    f'use one of {model_sel}'
                )
            self.architecture = ModelSelection(self.architecture)

        # yaml represents these sections as lists of single-key mappings;
        # flatten each into one dict.
        if hasattr(self, 'model'):
            self.model = self._merge_mappings(self.model)
        if hasattr(self, 'forward'):
            self.forward = self._merge_mappings(self.forward)

        self.log_named_modules = bool(
            getattr(self, 'log_named_modules', False))
        if self.log_named_modules:
            # Only listing modules was requested; skip input validation.
            return

        # A module passed to the constructor overrides yaml/CLI modules.
        if module is not None:
            self.modules = [module]
        if getattr(self, 'modules', None) is None:
            raise ValueError('Must declare at least one module.')
        self.set_modules(self.modules)

        if hasattr(self, 'dataset') and hasattr(self, 'input_dir'):
            raise ValueError(
                'Only one of `dataset` or `input_dir` can be set, '
                'not both. Please choose one.'
            )

        self.image_paths = []
        if hasattr(self, 'dataset'):
            self.dataset = self._load_configured_dataset(
                self._merge_mappings(self.dataset))
        else:
            self.dataset = None
            self.set_image_paths(getattr(self, 'input_dir', None))

        if prompt is not None:
            self.prompt = prompt

        if not (self.dataset or self.has_images() or hasattr(self, 'prompt')):
            raise ValueError(
                'Input directory was either not provided or empty '
                'and no prompt was provided'
            )

        if 'cuda' in self.device and not torch.cuda.is_available():
            raise ValueError('Device set to cuda but no GPU found for this machine')

        self.device = torch.device(self.device)
        self.DB_TABLE_NAME = 'tensors'
        self.NO_IMG_PROMPT = 'No image prompt'

        if not hasattr(self, 'output_db'):
            self.output_db = 'embeddings.db'

    @staticmethod
    def _merge_mappings(mappings: List[dict]) -> dict:
        """Flattens a yaml-style list of single-key dicts into one dict.

        Args:
            mappings (List[dict]): The mappings to merge. Later entries win
                on duplicate keys.

        Returns:
            dict: The merged mapping.
        """
        merged: dict = {}
        for mapping in mappings:
            merged.update(mapping)
        return merged

    def _load_configured_dataset(self, ds_mapping: dict):
        """Loads the dataset described by the flattened `dataset` yaml section.

        Also populates ``self.image_paths`` when an image directory is
        configured.

        Args:
            ds_mapping (dict): The flattened `dataset` yaml section.

        Returns:
            The loaded (and optionally split and path-remapped) dataset.

        Raises:
            ValueError: If not exactly one dataset source is configured.
        """
        dataset_path = ds_mapping.get('dataset_path', None)
        local_dataset_path = ds_mapping.get('local_dataset_path', None)

        # Exactly one source must be given (xor).
        if bool(dataset_path) == bool(local_dataset_path):
            raise ValueError(
                'One of `dataset_path` (for hosted datasets) or '
                '`local_dataset_path` (for local datasets) must be set.'
            )

        dataset_split = ds_mapping.get('dataset_split', None)
        if dataset_path:
            logging.debug(
                f'Loading dataset from {dataset_path} with split={dataset_split}...')
            dataset = load_dataset(dataset_path)
        else:
            logging.debug(
                f'Loading dataset from {local_dataset_path} with split={dataset_split}...')
            dataset = load_from_disk(local_dataset_path)

        if dataset_split:
            dataset = dataset[dataset_split]

        img_dir = ds_mapping.get('image_dataset_path', None)
        if img_dir:
            logging.debug(f'Locating image dataset from {img_dir}...')
            # Rewrite relative image names into absolute paths under img_dir.
            # NOTE(review): assumes the dataset has an 'image' column only
            # when `image_dataset_path` is configured — confirm with callers.
            dataset = dataset.map(
                lambda row: {'image': os.path.join(img_dir, row['image'])})
            self.image_paths = dataset['image']

        return dataset

    def has_images(self) -> bool:
        """Returns a boolean for whether or not there are images to process.

        Returns:
            bool: True when a loaded dataset has an `image` column, or when
                the input directory contained at least one image.
        """
        if self.dataset:
            return 'image' in self.dataset.column_names
        return len(self.image_paths) > 0

    def matches_module(self, module_name: str) -> bool:
        """Returns whether the given module name matches one of the regexes.

        Args:
            module_name (str): The module name to match.

        Returns:
            bool: Whether the given module name fully matches any of the
                config's compiled module regexes.
        """
        return any(
            pattern.fullmatch(module_name) for pattern in self.modules)

    def set_prompt(self, prompt: str) -> None:
        """Sets the prompt for the specific config.

        Args:
            prompt (str): Prompt to set.
        """
        self.prompt = prompt

    def set_modules(self, to_match_modules: List[str]) -> None:
        """Compiles and sets the module regexes for the specific config.

        Args:
            to_match_modules (List[str]): The module regexes to match.
        """
        self.modules = [re.compile(module) for module in to_match_modules]

    def set_image_paths(self, input_dir: Optional[str]) -> None:
        """Sets the image paths by walking the input directory.

        Args:
            input_dir (Optional[str]): The input directory; when None,
                `image_paths` is left untouched.
        """
        if input_dir is None:
            return

        # Set for O(1) extension membership checks.
        image_exts = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff'}
        self.image_paths = [
            os.path.join(root, file_name)
            for root, _, files in os.walk(input_dir)
            for file_name in files
            if os.path.splitext(file_name)[1].lower() in image_exts
        ]
|
|