| import os |
| import os.path as osp |
| import pickle |
| import torch |
| import json |
|
|
|
|
def read_from_file(file_path):
    """
    Read content from a file based on its extension.

    Args:
        file_path (str or os.PathLike): Path to the file. The extension
            selects the reader: ``.txt``, ``.json`` or ``.pkl``.

    Returns:
        content: The text of a ``.txt`` file as ``str``, the decoded value
            of a ``.json`` file, or the unpickled object of a ``.pkl`` file.

    Raises:
        NotImplementedError: If the file type is not supported.
    """
    # Normalise path-like objects (e.g. pathlib.Path) to str so the
    # ``endswith`` checks below work for them too.
    path = os.fspath(file_path)

    if path.endswith('.txt'):
        # Explicit UTF-8: the platform default encoding is not UTF-8 on
        # every system, which would garble non-ASCII text.
        with open(path, 'r', encoding='utf-8') as f:
            return f.read()
    if path.endswith('.json'):
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)
    if path.endswith('.pkl'):
        # NOTE: unpickling can execute arbitrary code; only use this on
        # files from a trusted source.
        with open(path, 'rb') as f:
            return pickle.load(f)
    raise NotImplementedError(f'File type not supported: {file_path}')
|
|
|
|
def write_to_file(file_path, content):
    """
    Write content to a file based on its extension.

    Args:
        file_path (str or os.PathLike): Path to the file. The extension
            selects the writer: ``.txt``, ``.json`` or ``.pkl``.
        content: Content to write. Must be a ``str`` for ``.txt``,
            JSON-serializable for ``.json`` and picklable for ``.pkl``.

    Raises:
        NotImplementedError: If the file type is not supported.
    """
    # Normalise path-like objects (e.g. pathlib.Path) to str so the
    # ``endswith`` checks below work for them too.
    path = os.fspath(file_path)

    if path.endswith('.txt'):
        # Explicit UTF-8 so the output does not depend on the platform's
        # default encoding.
        with open(path, 'w', encoding='utf-8') as f:
            f.write(content)
    elif path.endswith('.json'):
        # Serialize BEFORE opening: opening in 'w' mode truncates the
        # target, so a non-serializable object would otherwise destroy an
        # existing file and leave a partial one behind.
        serialized = json.dumps(content, indent=4)
        with open(path, 'w', encoding='utf-8') as f:
            f.write(serialized)
    elif path.endswith('.pkl'):
        # Pickle is streamed straight to disk to avoid holding a second
        # in-memory copy of potentially large objects.
        with open(path, 'wb') as f:
            pickle.dump(content, f)
    else:
        raise NotImplementedError(f'File type not supported: {file_path}')
|
|
|
|
def save_files(save_path, **kwargs):
    """
    Save multiple files in a specified directory.

    Each dict value is pickled to ``<key>.pkl`` and each ``torch.Tensor``
    value is stored with ``torch.save`` as ``<key>.pt``.

    Args:
        save_path (str): Directory to save the files. Created if missing.
        **kwargs: Keyword arguments where keys are filenames (without extension) and values are the contents.

    Raises:
        NotImplementedError: If any value is neither a dict nor a
            ``torch.Tensor``. Raised before anything is written.
    """
    # Validate every value first so an unsupported entry cannot leave a
    # partially written directory behind.
    for key, value in kwargs.items():
        if not isinstance(value, (dict, torch.Tensor)):
            raise NotImplementedError(f'File type not supported for key: {key}')

    os.makedirs(save_path, exist_ok=True)
    for key, value in kwargs.items():
        if isinstance(value, dict):
            with open(osp.join(save_path, f'{key}.pkl'), 'wb') as f:
                pickle.dump(value, f)
        else:
            # Only tensors remain after the validation pass above.
            torch.save(value, osp.join(save_path, f'{key}.pt'))
|
|
|
|
def load_files(save_path):
    """
    Load all files from a specified directory.

    ``.pkl`` files are unpickled and ``.pt`` files are read with
    ``torch.load``. Sub-directories are skipped.

    Args:
        save_path (str): Directory to load the files from.

    Returns:
        dict: Dictionary with filenames (without extension) as keys and file contents as values.
            Entries are inserted in sorted filename order. If two files
            share a stem (``a.pkl`` and ``a.pt``), the later one in that
            order wins.

    Raises:
        NotImplementedError: If the directory contains a file whose
            extension is neither ``.pkl`` nor ``.pt``. Raised before any
            file is loaded.
    """
    # First pass: list the files in a deterministic order and reject
    # unsupported ones before spending time loading anything.
    pending = []
    for entry in sorted(os.listdir(save_path)):
        file_path = osp.join(save_path, entry)
        if osp.isdir(file_path):
            continue
        file_name, file_ext = osp.splitext(entry)
        if file_ext not in ('.pkl', '.pt'):
            raise NotImplementedError(f'File type not supported: {entry}')
        pending.append((file_name, file_ext, file_path))

    # Second pass: actually load the contents.
    loaded_dict = {}
    for file_name, file_ext, file_path in pending:
        if file_ext == '.pkl':
            # NOTE: unpickling can execute arbitrary code; only use this
            # on directories from a trusted source.
            with open(file_path, 'rb') as f:
                loaded_dict[file_name] = pickle.load(f)
        else:
            loaded_dict[file_name] = torch.load(file_path)
    return loaded_dict
|
|