stm32-modelzoo-app / api /README.md
FBAGSTM's picture
STM32 AI Experimentation Hub
747451d
## 🧠 `get_model` API
The `get_model` function allows you to dynamically retrieve a model from the internal model registry based on key parameters like architecture, dataset, task type, and framework.
---
### Function Signature
```python
get_model(cfg)
```
### Variables expected from CFG
| Name | Type | Required | Description |
|------------------|-----------|----------|-----------------------------------------------------------------------------|
| `model_name` | `str` | βœ… | Name of the model architecture (e.g., `'resnet18'`, `'mobilenet_v2'`) |
| `use_case` | `str` | βœ… | Task type (e.g., `"image_classification"`, `'segmentation'`) |
| `framework` | `str` | βœ… | Framework to use: `'torch'` or `'tensorflow'` |
Plus many more according to model you are loading.
### Returns
- A model object initialized with the specified parameters.
- If `pretrained=True`, the model includes pretrained weights (if available).
---
### Notes
- If the combination of parameters doesn't match a known model, the function will raise an error.
- You can pass additional model-specific keyword arguments such as `dropout_rate=0.3` or `input_shape=(224, 224, 3)` depending on the framework.
### Example Usage
```python
from api import get_model
# load hydra config with required variables, It must have:
cfg.use_case = "image_classification"
cfg.model.framework = "torch"
cfg.model.model_name = "xcit_tiny_12_p8_224_timm"
model = get_model(cfg)
```
## `list_models` API
A helper function to list all existing models based on text filters.
You can provide a string or a list of strings to filter by model name, dataset name, task type, framework, or combined keys like `"model_name_dataset_name"`.
---
### Function Signature
```python
list_models(
filter_string='',
match_all=True,
print_table=True,
with_checkpoint=False,
)
```
### Parameters
| Name | Type | Required | Description |
|-------------------|----------------|----------|-----------------------------------------------------------------------------------------------------------|
| `filter_string` | `str` or `list`| ❌ | String or list of strings containing model name, dataset name, task type, framework, or combined filters. |
| `match_all` | `bool` | ❌ | If `True`, only models matching *all* keywords (intersection) are returned. If `False`, models matching *any* keyword are returned. (default: `True`) |
| `print_table` | `bool` | ❌ | Whether to print a table of matched models. Returns a list. (default: `True`) |
| `with_checkpoint` | `bool` | ❌ | Whether to include only models that have available checkpoints. (default: `False`) |
### Returns
- If `print_table=True`, prints a formatted table of matched models.
- If `print_table=False`, returns a list of matched model identifiers.
---
### Notes
- The `filter_string` can be a single string or a list of strings to filter the models.
- Setting `match_all=True` returns models matching *all* provided keywords (logical AND).
- Setting `match_all=False` returns models matching *any* of the keywords (logical OR).
- This function is useful to quickly explore what models are available in the registry based on flexible filters.
### Example Usage
```python
from api import list_models
# List and print models matching both 'resnet' and 'imagenet'
list_models(filter_string=['resnet', 'imagenet'], match_all=True)
# Get list of models matching either 'mobilenet' or 'cifar10', without printing
models = list_models(filter_string=['mobilenet', 'cifar10'], match_all=False, print_table=False)
print(models)
```
### Keywords for model names
```
'airnext', 'aim', 'alexnet', 'bagnet', 'beit', 'botnet', 'botnext', 'byobnet', 'cait', 'caformer',
'channelnet', 'coat', 'convformer', 'convnext', 'darknet', 'darts', 'deit', 'densnet', 'dicenet', 'diracnet',
'dla', 'dpn', 'drnc', 'drnd', 'edgenet', 'edgenext', 'efficientformer', 'efficientnet', 'espnet', 'eva',
'fasternet', 'fbnet', 'fishnet', 'focalnet', 'gmlp', 'gernet', 'ghostnet', 'googlenet', 'halonet', 'halonext',
'hardcorenas', 'hardnet', 'hgera', 'hgnet', 'hrnet', 'igc', 'inception', 'irevnet', 'lcnet', 'mambaout',
'mixer', 'mixnet', 'mobilenet', 'msdnet', 'nasnet', 'nest', 'nfnet', 'pvt', 'peleenet', 'pit', 'proxylessnas',
'pyramidnet', 'rdnet', 'regnet', 'res2net', 'resattnet', 'resmlp', 'resnet', 'resnest', 'resnext', 'rexnet',
'scnet', 'selecsls', 'senet', 'sequencer', 'shufflenet', 'sknet', 'sparsenet', 'sqnxt', 'squeezenet',
'starnet', 'swiftformer', 'swin', 'tinynet', 'tnt', 'tresnet', 'twins', 'vit', 'vitamin', 'vgg', 'volo',
'vovnet', 'xcit', 'xception', 'zfnet'
```
## `get_dataloaders` API
Loads and returns training and testing dataloaders for a specified dataset, task type, and framework.
### πŸ”§ Function Signature
```python
get_dataloaders(cfg)
```
### Variables expected from CFG
| Name | Type | Required | Description |
| -------------- | ------ | -------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| `data_root` | `str` | βœ… | Path to the root directory containing the dataset folder. |
| `dataset_name` | `str` | βœ… | Name of the dataset (e.g., `'imagenet'`, `'flowers102'`, etc.). |
| `use_case` | `str` | βœ… | Task type (e.g., `"image_classification"`, `'detection'`, `'segmentation'`). |
| `framework` | `str` | βœ… | Framework or model name used to select the appropriate dataset wrapper.
| `prediction_path` | `str` | ❌ | Path to folder containing images or folders of images.
| `qunatization_path` | `str` | ❌ | Path to folder containing images or folders of images. |
| `**kwargs` | `dict` | ❌ | Additional keyword arguments passed to the dataset loader. Common options include:<ul><li>`download` (`bool`) – Whether to download the dataset (if supported).</li><li>`batch_size` (`int`) – Batch size for the dataloaders.</li><li>`img_size` (`int`)</li><li>`num_workers` (`int`) – Number of subprocesses for data loading.</li></ul> |
## Returns
Returns a dictionary containing PyTorch dataloaders for training and testing splits.
```python
Dict
{
'train': torch.utils.data.DataLoader,
'test' : torch.utils.data.DataLoader,
'valid': torch.utils.data.DataLoader,
'pred' : torch.utils.data.DataLoader,
'quant': torch.utils.data.DataLoader,
}
```
## Example Usage
```python
from api import get_dataloaders
cfg.dataset.data_dir= "/neutrino/datasets/"
cfg.dataset.dataset_name = "imagenet"
cfg.use_case = "image_classification"
cfg.model.framework = "torch"
# plus other variables like batch_size, input_size, aumentations, num_classes etc etc
dataloaders = get_dataloaders(cfg)
train_loader = dataloaders['train']
test_loader = dataloaders['test']
```
## Expected folder structure in standard datasets
- FLOWERS102 shud have 'jpg' folder, 'setid.mat' and 'imagelabels.mat' inside.
- VWW should have 'all' , 'annotations/instances_train.json' and 'annotations/instances_train.json' inside.
- FOOD101 should have 'images' and 'meta' folder inside
- IMAGENET should have 'train' and 'val' folder inside it with sub folders of classes.
## Project Structure Overview
```
project/
β”œβ”€β”€ apis/
β”‚ β”œβ”€β”€ get_model.py
β”‚ β”œβ”€β”€ get_dataset.py
β”‚ └── get_trainers.py
β”‚
β”œβ”€β”€ common/
β”‚ β”œβ”€β”€ blocks/ # Shared building blocks
β”‚ β”œβ”€β”€ registry/
β”‚ β”œβ”€β”€ model_registry.py # Model registry system
β”‚ β”œβ”€β”€ dataset_registry.py # Dataset registry system
β”‚ └── trainer_registry.py # Trainer registry system
β”‚
β”œβ”€β”€ image_classification/
β”‚ β”œβ”€β”€ config.py # Config file (e.g. for args, yaml loading, etc.)
β”‚ β”œβ”€β”€ main.py # Entry point: loads config, gets model/dataset/trainer, trains
β”‚ └── pt/
β”‚ β”œβ”€β”€ src/
β”‚ β”‚ β”œβ”€β”€ models/ # Model definitions
β”‚ β”‚ β”œβ”€β”€ dataset/ # Dataset definitions
β”‚ β”‚ └── trainers/ # Training logic
β”‚ β”‚
β”‚ └── wrapper/
β”‚ └── models/ # Wrapper to unify model interfaces and register with registry
β”‚
└── README.md
```
## System Flow
```
main.py (reads config)
↓
apis/ (get_model, get_dataset, get_trainers)
↓
common/registry (e.g., model_registry.get)
↓
image_classification/pt/wrapper/models/ (model registration)
↓
image_classification/pt/src/models/ (actual model implementation)
```
## Important to Avoid Circular Imports
Make sure:
common/ does not import anything from image_classification/
src/trainers/ does not call get_model() from apis/ β€” it should just accept already-prepared objects
wrapper/models/ only handles registration logic (and doesn't depend on training or dataset logic)