|
|
"""base.py. |
|
|
|
|
|
Provides the common definitions shared by all models: the ModelInput typed
dictionary and the ModelBase abstract base class.
|
|
""" |
|
|
import io |
|
|
import logging |
|
|
import os |
|
|
import sqlite3 |
|
|
from abc import ABC, abstractmethod |
|
|
from collections.abc import Iterator |
|
|
from typing import Callable, List, Optional, TypedDict |
|
|
|
|
|
import torch |
|
|
import tqdm |
|
|
from PIL import Image |
|
|
from transformers import AutoProcessor |
|
|
from transformers.feature_extraction_utils import BatchFeature |
|
|
|
|
|
from .config import Config |
|
|
|
|
|
|
|
|
class ModelInput(TypedDict):
    """Definition for the general model input dictionary.

    One instance describes a single (image, prompt) item that is run through
    the model in its own forward pass. `label` and `row_id` may be None or
    missing, so consumers should read them with `.get()`.
    """

    # Image file path, an already-loaded PIL image, or the config's
    # NO_IMG_PROMPT placeholder string when the item has no image.
    image: str | Image.Image

    # The raw prompt text, before the processor's chat template is applied.
    prompt: str

    # Optional label for the item, e.g. taken from a dataset `label` column.
    label: Optional[str]

    # Processor output; passed to the model's forward call as keyword args.
    data: BatchFeature

    # Optional identifier of the dataset row this item came from; it is
    # stored in the database's `image_id` column.
    row_id: Optional[str]
|
|
|
|
|
|
|
|
class ModelBase(ABC):
    """Provides an abstract base class for everything to implement.

    Subclasses implement `_load_specific_model`, which must set `self.model`
    (the rest of this class reads `self.model`). The base class loads the
    processor, builds one input item per (image, prompt) pair, registers
    forward hooks on the configured modules and writes every captured output
    tensor to an SQLite database.
    """

    def __init__(self, config: Config) -> None:
        """Initialization of the model base class.

        When `config.log_named_modules` is set, the model's named modules are
        written to a text file under `logs/` and the process exits instead of
        finishing initialization.

        Args:
            config (Config): Parsed config.

        Raises:
            SystemExit: When `config.log_named_modules` is set.
        """
        self.model_path = config.model_path
        self.config = config

        if self.config.log_named_modules:
            self._log_named_modules()
            # The `exit` builtin is injected by the `site` module for
            # interactive use; raising SystemExit is the portable equivalent.
            raise SystemExit(0)

        logging.debug(
            f'Loading model {self.config.architecture.value}; {self.model_path}'
        )
        self._load_specific_model()
        self._init_processor()

    def _log_named_modules(self) -> None:
        """Logs the named modules based on the loaded model.

        Module names are written one per line to `logs/<model_path>.txt`. If
        that file already exists, the model is not loaded and the cached file
        is left untouched.
        """
        file_path = 'logs/' + self.model_path + '.txt'

        if os.path.isfile(file_path):
            logging.debug(f'Named modules are cached in {file_path}')
            return

        self._load_specific_model()

        # `model_path` may contain slashes, so create every intermediate
        # directory; exist_ok avoids a check-then-create race.
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        with open(file_path, 'w', encoding='utf-8') as output_file:
            output_file.writelines(
                f'{name}\n' for name, _ in self.model.named_modules()
            )

    @abstractmethod
    def _load_specific_model(self) -> None:
        """Abstract method that loads the specific model into `self.model`."""
        pass

    def _init_processor(self) -> None:
        """Initialize the self.processor by loading from the path."""
        self.processor = AutoProcessor.from_pretrained(self.model_path)

    def _generate_state_hook(self,
                             name: str,
                             model_input: ModelInput
                             ) -> Callable[[torch.nn.Module, tuple, torch.Tensor], None]:
        """Generates a forward hook that stores a module's output in the database.

        Args:
            name (str): The module name; stored in the `layer` column.
            model_input (ModelInput): The input dictionary
                containing the image path, prompt, label (if applicable) and
                the data itself.

        Raises:
            FileNotFoundError: If the item's image is a path that does not exist.

        Returns:
            hook function: The hook function to register on the module.
        """
        image_path, prompt = model_input['image'], model_input['prompt']
        label = model_input.get('label', None)
        row_id = model_input.get('row_id', None)

        if isinstance(image_path, str) and image_path != self.config.NO_IMG_PROMPT:
            image_path = os.path.abspath(image_path)
            # Explicit check instead of `assert`, which is stripped under -O.
            if not os.path.exists(image_path):
                raise FileNotFoundError(f'Input image does not exist: {image_path}')

        # Resolve the optional pooling method once. Guarding against None
        # avoids `hasattr(output, None)`, which raises TypeError.
        pooling_method = getattr(self.config, 'pooling_method', None)

        def generate_states_hook(module: torch.nn.Module, inputs: tuple, output: torch.Tensor) -> None:
            """Hook handle function that saves the module output to the database.

            The (optionally pooled) tensor is serialized with `torch.save` and
            inserted into the table on the connection opened by
            `_initialize_db`.

            Args:
                module (torch.nn.Module): The module the hook is registered on.
                inputs (tuple): The positional inputs of the forward call (unused).
                output (torch.Tensor): The module output to save.
            """
            if not isinstance(output, torch.Tensor):
                logging.warning(f'Output type of {str(type(module))} is not a tensor, skipped.')
                return

            if pooling_method and hasattr(output, pooling_method):
                final_output = getattr(output, pooling_method)(dim=1)
                # Reductions such as `max`/`min` with `dim` return a
                # (values, indices) named tuple; keep only the values.
                if not isinstance(final_output, torch.Tensor):
                    final_output = final_output.values
            else:
                final_output = output

            output_dim = final_output.shape[-1]
            tensor_blob = io.BytesIO()
            torch.save(final_output, tensor_blob)

            cursor = self.connection.cursor()
            cursor.execute(f"""
                INSERT INTO {self.config.DB_TABLE_NAME}
                (name, architecture, image_path, image_id, prompt, label, layer, pooling_method, tensor_dim, tensor)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
            """, (
                self.model_path,
                self.config.architecture.value,
                image_path if isinstance(image_path, str) else None,
                row_id,
                prompt,
                label,
                name,
                pooling_method,
                output_dim,
                tensor_blob.getvalue(),
            ))
            self.connection.commit()

            logging.debug(
                f'Ran hook and saved tensor for {image_path} using prompt '
                f'{prompt} on layer {name}.'
            )

        return generate_states_hook

    def _register_module_hooks(self,
                               model_input: ModelInput
                               ) -> List[torch.utils.hooks.RemovableHandle]:
        """Register the generated hook function to the modules in the config.

        Each hook captures the item's image path and prompt so they can be
        written to the database alongside the module output.

        Args:
            model_input (ModelInput): The input dictionary
                containing the image path, prompt, label (if applicable) and
                the data itself.

        Raises:
            RuntimeError: If no module matched the configuration, so no hook
                was registered.

        Returns:
            List[torch.utils.hooks.RemovableHandle]: A list of handles that one
                can remove after the forward pass.
        """
        logging.debug(
            f'Registering module hook for {model_input["image"]} using prompt "{model_input["prompt"]}"'
        )

        hooks: List[torch.utils.hooks.RemovableHandle] = []
        for name, module in self.model.named_modules():
            if self.config.matches_module(name):
                hooks.append(module.register_forward_hook(
                    self._generate_state_hook(name, model_input)
                ))
                logging.debug(f'Registered hook to {name}')

        if not hooks:
            raise RuntimeError(
                'No hooks were registered. Double-check the configured modules.'
            )

        return hooks

    def _forward(self, data: BatchFeature) -> None:
        """Given some input data, performs a single forward pass.

        Subclasses may override this method, while _hook_and_eval
        should be left intact.

        Args:
            data (BatchFeature): The processor output for a single item.
        """
        data = data.to(self.config.device)
        with torch.no_grad():
            _ = self.model(**data)
        logging.debug('Completed forward pass...')

    def _hook_and_eval(self, model_input: ModelInput) -> None:
        """Registers the hooks, performs a single forward pass and removes the hooks.

        Args:
            model_input (ModelInput): The given input dictionary.
        """
        logging.debug('Starting forward pass')
        self.model.eval()

        hooks = self._register_module_hooks(model_input)
        try:
            self._forward(model_input['data'])
        finally:
            # Always detach the hooks; otherwise a failed forward pass would
            # leave them registered and they would fire for later items.
            for hook in hooks:
                hook.remove()
            logging.debug('Unregistered all hooks..')

    def _initialize_db(self) -> None:
        """Opens the output database and creates the table if it does not exist."""
        self.connection = sqlite3.connect(self.config.output_db)
        logging.debug(f'Database created at {self.config.output_db}')

        cursor = self.connection.cursor()
        cursor.execute(
            f"""
            CREATE TABLE IF NOT EXISTS {self.config.DB_TABLE_NAME} (
                id INTEGER PRIMARY KEY,
                name TEXT NOT NULL,
                architecture TEXT NOT NULL,
                timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
                image_path TEXT NULL,
                image_id INTEGER NULL,
                prompt TEXT NOT NULL,
                label TEXT NULL,
                layer TEXT NOT NULL,
                pooling_method TEXT NULL,
                tensor_dim INTEGER NOT NULL,
                tensor BLOB NOT NULL
            );
            """
        )

    def _cleanup(self) -> None:
        """Cleans up by closing the database connection."""
        self.connection.close()

    def _generate_processor_output(self, prompt: str, img_path: str | Image.Image | None) -> BatchFeature:
        """Generate the processor outputs from the prompt and image path.

        Args:
            prompt (str): The generated prompt string with the input text and
                the image labels.
            img_path (str | Image.Image | None): The input image path, an
                already-loaded image, or a falsy value for text-only input.

        Returns:
            BatchFeature: The processor output for this image and prompt.
        """
        data = {
            'text': prompt,
            'return_tensors': 'pt'
        }

        if img_path:
            if isinstance(img_path, str):
                # `convert` returns a loaded copy, so the file can be closed
                # immediately instead of leaking the handle.
                with Image.open(img_path) as img:
                    data['images'] = [img.convert('RGB')]
            else:
                data['images'] = [img_path.convert('RGB')]

        return self.processor(**data)

    def _generate_prompt(self, prompt: str, add_generation_prompt: bool = True, has_images: bool = False) -> str:
        """Generates the prompt string with the input messages.

        TODO: move `add_generation_prompt` to the config.
        [Note from Martin] I'd argue that we should keep it as a parameter here
        since in gradio we want to hack these parameters a bit.

        Args:
            prompt (str): The input prompt string; skipped when empty.
            add_generation_prompt (bool): Whether to add a start token of a bot
                response.
            has_images (bool): Whether to add an image placeholder even if the
                config has no images.

        Returns:
            str: The prompt rendered by the processor's chat template.
        """
        logging.debug('Loading data...')

        content = []
        if self.config.has_images() or has_images:
            content.append({'type': 'image'})
        if prompt:
            content.append({'type': 'text', 'text': prompt})

        input_msgs_formatted = [{'role': 'user', 'content': content}]
        return self.processor.apply_chat_template(
            input_msgs_formatted,
            add_generation_prompt=add_generation_prompt
        )

    def _load_input_data(self) -> Iterator[ModelInput]:
        """From a configuration, loads the input image and text data.

        For each prompt and input image, create a separate batch feature that
        will be run separately and saved separately within the database.

        Yields:
            ModelInput: One input item containing the image (path or object),
                the raw prompt, the optional label and row id, and the
                processor output.
        """
        logging.debug('Generating embeddings through its processor...')

        if self.config.dataset:
            columns = self.config.dataset.column_names
            has_label = 'label' in columns
            has_id = 'id' in columns
            for row in self.config.dataset:
                prompt = self._generate_prompt(row['prompt'])
                data = self._generate_processor_output(
                    prompt=prompt,
                    img_path=row['image']
                )
                yield {
                    'image': row['image'],
                    'prompt': row['prompt'],
                    'label': row['label'] if has_label else None,
                    'data': data,
                    'row_id': row['id'] if has_id else None,
                }
        elif not self.config.has_images():
            # `_generate_prompt` requires the prompt text; it was previously
            # called with no arguments here, which raised a TypeError.
            yield {
                'image': self.config.NO_IMG_PROMPT,
                'prompt': self.config.prompt,
                'label': None,
                'data': self._generate_processor_output(
                    prompt=self._generate_prompt(self.config.prompt),
                    img_path=None
                ),
                'row_id': None,
            }
        else:
            prompt = self._generate_prompt(self.config.prompt)
            for img_path in self.config.image_paths:
                data = self._generate_processor_output(
                    prompt=prompt,
                    img_path=img_path
                )
                yield {
                    'image': img_path,
                    'prompt': self.config.prompt,
                    'label': None,
                    'data': data,
                    'row_id': None,
                }

    @property
    def _data_size(self) -> int:
        """Returns the total number of data points.

        Mirrors the branching in `_load_input_data`.

        Returns:
            int: The total number of data points.
        """
        if self.config.dataset:
            return len(self.config.dataset)
        if not self.config.has_images():
            return 1
        return len(self.config.image_paths)

    def run(self) -> None:
        """Get the hidden states from the model and save them."""
        self._initialize_db()
        try:
            self.model.to(self.config.device)

            if self.config.device.type == 'cuda':
                torch.cuda.reset_peak_memory_stats(self.config.device)

            for item in tqdm.tqdm(self._load_input_data(), desc='Running forward hooks on data', total=self._data_size):
                self._hook_and_eval(item)

            if self.config.device.type == 'cuda':
                logging.debug(f'Peak GPU memory allocated: {torch.cuda.max_memory_allocated(self.config.device) / 1e6:.2f} MB')
        finally:
            # Close the connection even if loading or a forward pass fails.
            self._cleanup()
|
|
|