Spaces:
Sleeping
Sleeping
| """ | |
| ์ ํธ๋ฆฌํฐ ํจ์ | |
| """ | |
| import os | |
| import io | |
| import json | |
| import pickle | |
| import logging | |
| import contextlib | |
| import traceback | |
| import numpy as np | |
| import tensorflow as tf | |
| from pathlib import Path | |
| from tqdm import tqdm | |
class TqdmProgressCallback(tf.keras.callbacks.Callback):
    """Keras callback that reports training progress as a per-epoch tqdm bar.

    Args:
        epochs: Total number of epochs; used as the bar's ``total``.
        verbose: When falsy the callback is a no-op (no bar, no messages).
    """

    def __init__(self, epochs, verbose=1):
        super().__init__()
        self.epochs = epochs
        self.verbose = verbose
        self.tqdm_bar = None

    def on_train_begin(self, logs=None):
        if self.verbose:
            self.tqdm_bar = tqdm(total=self.epochs, desc="Training", unit="epoch")

    def on_epoch_end(self, epoch, logs=None):
        # Guard: the bar only exists if on_train_begin ran with verbose enabled.
        if not self.verbose or self.tqdm_bar is None:
            return
        logs = logs or {}
        # Show training metrics only; keys containing 'val_' are skipped.
        desc = ", ".join(
            f"{key}: {self._format_log_value(value)}"
            for key, value in logs.items()
            if 'val_' not in key
        )
        self.tqdm_bar.set_description(desc)
        self.tqdm_bar.update(1)

    def on_train_end(self, logs=None):
        if self.verbose and self.tqdm_bar is not None:
            self.tqdm_bar.close()
            # Drop the closed bar so a stale handle is never updated again.
            self.tqdm_bar = None
            print("ํ์ต ์๋ฃ!")

    @staticmethod
    def _format_log_value(value):
        """Format a log value with 4 decimals, falling back to ``str()``.

        ``format(value, '.4f')`` raises for strings and non-scalar arrays
        (e.g. a per-class metric), which previously aborted the callback.
        """
        try:
            return f"{value:.4f}"
        except (TypeError, ValueError):
            return str(value)
def get_project_root():
    """Return the project root directory as an absolute ``Path``.

    The root is the directory two levels above the directory containing this
    file (``<root>/<pkg>/<subpkg>/this_file.py``). NOTE(review): this layout
    is assumed from the three ``.parent`` hops; confirm against the repository.

    ``__file__`` is resolved first. Otherwise a relative ``__file__``
    (e.g. when run as a script from another directory) makes the ``.parent``
    chain collapse to ``.``. The "root" would then silently become the
    current working directory.
    """
    return Path(__file__).resolve().parent.parent.parent
def ensure_directory(directory_path):
    """Make sure ``directory_path`` exists, creating missing parents as needed.

    Existing directories are left alone (no error).

    Returns:
        ``Path``: the directory as a ``Path`` object.
    """
    target = Path(directory_path)
    target.mkdir(parents=True, exist_ok=True)
    return target
def normalize_path(path_str, base_dir=None):
    """Anchor a possibly-relative path at a base directory.

    Args:
        path_str: ``str`` or ``Path``. Returned as a ``Path`` unchanged if it is
            already absolute.
        base_dir: Directory that relative paths are joined onto. Defaults to
            ``get_project_root()``. If ``base_dir`` is itself relative, the result
            stays relative.

    Returns:
        ``Path``: ``path_str`` if absolute, otherwise ``Path(base_dir) / path_str``.
    """
    candidate = Path(path_str)
    if not candidate.is_absolute():
        anchor = get_project_root() if base_dir is None else base_dir
        candidate = Path(anchor) / candidate
    return candidate
def _tflite_output_path(model_path):
    """Return ``model_path`` rewritten to carry a ``.tflite`` extension.

    Only the file name is touched, never a directory component:
    a name already ending in ``.tflite`` is kept, a trailing ``.keras`` or
    ``.h5`` suffix is replaced, and any other name gets ``.tflite`` appended.
    """
    model_path = Path(model_path)
    if model_path.name.endswith('.tflite'):
        return model_path
    if model_path.suffix in ('.keras', '.h5'):
        return model_path.with_suffix('.tflite')
    return Path(str(model_path) + '.tflite')


def save_model(model, model_path, config=None, encoders=None):
    """Convert a Keras model to TensorFlow Lite and write it, plus optional JSON sidecars.

    Args:
        model: Keras model passed to ``tf.lite.TFLiteConverter.from_keras_model``.
        model_path: Destination path. A trailing ``.keras``/``.h5`` extension is
            replaced by ``.tflite``; any other name not ending in ``.tflite`` gets
            ``.tflite`` appended. Parent directories are created.
        config: Optional mapping written to ``<stem>_config.json`` next to the
            model. Top-level values that are not int/float/str/bool/list/dict are
            stored via ``str()``.
        encoders: Optional JSON-serializable object written to
            ``<stem>_encoders.json`` next to the model.

    Returns:
        ``True`` on success. Conversion and I/O errors propagate to the caller.
    """
    model_path = Path(model_path)
    ensure_directory(model_path.parent)

    # Quiet TensorFlow logging during conversion; both settings are restored in `finally`.
    # NOTE(review): TF's C++ logger presumably reads TF_CPP_MIN_LOG_LEVEL at startup,
    # so changing it after TensorFlow is loaded may have little effect — confirm.
    original_tf_log_level = os.environ.get('TF_CPP_MIN_LOG_LEVEL')  # None when unset
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
    tf_logger = logging.getLogger('tensorflow')
    original_tf_level = tf_logger.level
    tf_logger.setLevel(logging.ERROR)

    try:
        model_path = _tflite_output_path(model_path)

        print("TensorFlow Lite ๋ชจ๋ธ๋ก ๋ณํ ์ค...")
        converter = tf.lite.TFLiteConverter.from_keras_model(model)
        # Converter settings kept from the original for LSTM compatibility.
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.target_spec.supported_ops = [
            tf.lite.OpsSet.TFLITE_BUILTINS,  # built-in TFLite ops
            tf.lite.OpsSet.SELECT_TF_OPS,    # allow falling back to TensorFlow ops
        ]
        converter._experimental_lower_tensor_list_ops = False  # keep TensorList ops unlowered
        converter.allow_custom_ops = True

        # Suppress Python-level converter chatter. NOTE(review): native output written
        # directly to the process file descriptors is presumably not captured here.
        with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(io.StringIO()):
            tflite_model = converter.convert()

        with open(model_path, 'wb') as f:
            f.write(tflite_model)
        print(f"TensorFlow Lite ๋ชจ๋ธ์ด {model_path}์ ์ ์ฅ๋์์ต๋๋ค.")

        model_stem = model_path.stem

        if encoders is not None:
            encoder_path = model_path.parent / f"{model_stem}_encoders.json"
            with open(encoder_path, 'w', encoding='utf-8') as f:
                json.dump(encoders, f, indent=2)
            print(f"์ธ์ฝ๋ ์ ๋ณด๊ฐ {encoder_path}์ ์ ์ฅ๋์์ต๋๋ค.")

        if config is not None:
            config_path = model_path.parent / f"{model_stem}_config.json"
            # Build the serializable copy before opening, so a failure cannot leave a truncated file.
            json_safe_config = {
                k: v if isinstance(v, (int, float, str, bool, list, dict)) else str(v)
                for k, v in config.items()
            }
            with open(config_path, 'w', encoding='utf-8') as f:
                json.dump(json_safe_config, f, indent=2)
            print(f"๋ชจ๋ธ ์ค์ ์ด {config_path}์ ์ ์ฅ๋์์ต๋๋ค.")

        return True
    finally:
        tf_logger.setLevel(original_tf_level)
        # Restore the exact prior state, including an explicitly empty value.
        if original_tf_log_level is None:
            os.environ.pop('TF_CPP_MIN_LOG_LEVEL', None)
        else:
            os.environ['TF_CPP_MIN_LOG_LEVEL'] = original_tf_log_level
def load_tflite_model(model_path):
    """Create a TensorFlow Lite interpreter for ``model_path`` and allocate its tensors.

    Args:
        model_path: ``str`` or ``Path`` of the ``.tflite`` file.

    Returns:
        The ready ``tf.lite.Interpreter``, or ``None`` if any step fails. On
        failure the error and its traceback are printed rather than raised.
    """
    try:
        tflite_interpreter = tf.lite.Interpreter(model_path=str(model_path))
        tflite_interpreter.allocate_tensors()
        print(f"TensorFlow Lite ๋ชจ๋ธ์ด {model_path}์์ ์ฑ๊ณต์ ์ผ๋ก ๋ก๋๋์์ต๋๋ค.")
    except Exception as err:
        print(f"TensorFlow Lite ๋ชจ๋ธ ๋ก๋ ์คํจ: {err}")
        print(traceback.format_exc())
        tflite_interpreter = None
    return tflite_interpreter
def predict_with_tflite(interpreter, inputs, verbose=False):
    """Run one inference pass on a TensorFlow Lite interpreter.

    Args:
        interpreter: Interpreter with tensors already allocated (for example the
            value returned by ``load_tflite_model``).
        inputs: Sequence of arrays, matched by position to
            ``interpreter.get_input_details()``; entries beyond the model's input
            count are ignored. For a single-input model, a bare ``np.ndarray``
            whose rank equals that input's rank is treated as ``[inputs]``.
            A floating-point array is cast to the input's floating dtype
            (e.g. NumPy's default float64 into a float32 input), because
            ``set_tensor`` rejects dtype mismatches.
        verbose: When True, the full traceback is printed on failure.

    Returns:
        The single output array if the model has one output, otherwise a list of
        outputs in ``get_output_details()`` order; ``None`` on any error.
    """
    try:
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()

        # A bare array for a single-input model would otherwise be iterated row by row.
        if (len(input_details) == 1 and isinstance(inputs, np.ndarray)
                and inputs.ndim == len(input_details[0]['shape'])):
            inputs = [inputs]

        # zip() stops at the shorter sequence, so surplus inputs are ignored.
        for detail, value in zip(input_details, inputs):
            array = np.asarray(value)
            expected_dtype = detail['dtype']
            # Only float -> float casts; integer/quantized inputs are never silently converted.
            if (array.dtype != expected_dtype
                    and np.issubdtype(array.dtype, np.floating)
                    and np.issubdtype(expected_dtype, np.floating)):
                array = array.astype(expected_dtype)
            interpreter.set_tensor(detail['index'], array)

        interpreter.invoke()

        # NOTE(review): TFLite output order may differ from the Keras model's output order — confirm.
        outputs = [interpreter.get_tensor(detail['index']) for detail in output_details]
        return outputs if len(outputs) > 1 else outputs[0]
    except Exception as e:
        # Report the cause (as load_tflite_model does) instead of silently returning None.
        print(f"์์ธก ์คํจ: {e}")
        if verbose:
            print(traceback.format_exc())
        return None
def save_results(results, output_path, include_model=False):
    """Pickle optimization results to ``output_path``.

    Args:
        results: Mapping that may contain:
            ``'results'`` — list of per-run dicts. If it is missing or empty but
            ``'best_result'`` is present, ``[results['best_result']]`` is used instead.
            ``'best_config'`` — stored as-is (defaults to ``{}``).
            ``'test_backtest'`` — copied through when present.
        output_path: Destination file; parent directories are created.
        include_model: When False, each run's (shallow) copy has its ``'model'``
            entry removed before pickling.

    Returns:
        ``True`` once the file has been written.

    The pickled dict has keys ``'grid_results'``, ``'best_config'`` and, when
    available, ``'test_backtest'``. Only load it from trusted sources (pickle).
    """
    output_path = Path(output_path)
    ensure_directory(output_path.parent)

    payload = {
        'grid_results': [],
        'best_config': results.get('best_config', {}),
    }

    runs = results.get('results', [])
    if not runs and 'best_result' in results:
        runs = [results['best_result']]

    for run in runs:
        run_copy = run.copy()
        drop_model = not include_model and 'model' in run_copy
        if drop_model:
            del run_copy['model']
        payload['grid_results'].append(run_copy)

    if 'test_backtest' in results:
        payload['test_backtest'] = results['test_backtest']

    with open(output_path, 'wb') as handle:
        pickle.dump(payload, handle)
    print(f"์ต์ ํ ๊ฒฐ๊ณผ๊ฐ {output_path}์ ์ ์ฅ๋์์ต๋๋ค.")
    return True
def save_metadata(metadata, output_path):
    """Write ``metadata`` to ``output_path`` as indented JSON.

    Top-level values that are not int/float/str/bool/list/dict are converted
    with ``str()`` first. Nested values are written as-is and must already be
    JSON-serializable. Parent directories are created.

    Returns:
        ``True`` once the file has been written.
    """
    output_path = Path(output_path)
    ensure_directory(output_path.parent)

    json_native_types = (int, float, str, bool, list, dict)
    serializable = {}
    for key, value in metadata.items():
        if isinstance(value, json_native_types):
            serializable[key] = value
        else:
            serializable[key] = str(value)

    with open(output_path, 'w') as handle:
        json.dump(serializable, handle, indent=2)
    print(f"๋ฉํ๋ฐ์ดํฐ๊ฐ {output_path}์ ์ ์ฅ๋์์ต๋๋ค.")
    return True