File size: 650 Bytes
591bc46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24

import json
import os

def load_model_config(model_name, models_dir="models"):
    """Load and parse the JSON config file for a model.

    The file is looked up at ``<models_dir>/<model_name>_config.json``.

    Args:
        model_name: Model name; used as the config filename prefix.
        models_dir: Directory holding the model config files.

    Returns:
        The parsed JSON content of the file (whatever top-level value it holds).

    Raises:
        FileNotFoundError: If no config file exists at the expected path.
        json.JSONDecodeError: If the file content is not valid JSON.
    """
    config_path = os.path.join(models_dir, f"{model_name}_config.json")

    # EAFP: open directly instead of checking os.path.exists first, which
    # avoids the race where the file disappears between the check and the open.
    try:
        # JSON text is UTF-8 (RFC 8259); don't depend on the locale encoding.
        f = open(config_path, "r", encoding="utf-8")
    except FileNotFoundError as err:
        raise FileNotFoundError(
            f"Config file not found for model '{model_name}' at '{config_path}'"
        ) from err

    with f:
        return json.load(f)

def get_input_size(config):
    """Return the model's input size from a config mapping.

    ``"input_size"`` takes precedence; ``"input_dim"`` is used only when
    ``"input_size"`` is absent. The stored value is returned as-is.

    Args:
        config: Model configuration mapping.

    Returns:
        The value stored under the first key found.

    Raises:
        ValueError: If neither key is present in ``config``.
    """
    # Keys are checked in priority order; the first one present wins.
    for candidate_key in ("input_size", "input_dim"):
        if candidate_key in config:
            return config[candidate_key]
    raise ValueError("Input size not found in the model configuration.")