Spaces:
Running
Running
| """ | |
| PyTorch ๋ชจ๋ธ์ TensorFlow Lite ํ์์ผ๋ก ๋ณํํ๋ ์คํฌ๋ฆฝํธ | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| import os | |
# Optional imports — each flag records whether the corresponding package is
# actually usable in this environment.  The conversion pipeline checks these
# flags before doing any work.
ONNX_AVAILABLE = False
TF_AVAILABLE = False
ONNX_TF_AVAILABLE = False
try:
    import onnx
    ONNX_AVAILABLE = True
except Exception as e:
    # ImportError just means "not installed"; anything else (ABI mismatch,
    # SyntaxError inside the package, ...) is worth surfacing to the user.
    ONNX_AVAILABLE = False
    if not isinstance(e, ImportError):
        print(f"โ ๏ธ onnx ํจํค์ง ๋ก๋ ์ค ์ค๋ฅ ๋ฐ์: {type(e).__name__}")
try:
    import tensorflow as tf
    TF_AVAILABLE = True
except Exception as e:
    TF_AVAILABLE = False
    if not isinstance(e, ImportError):
        print(f"โ ๏ธ tensorflow ํจํค์ง ๋ก๋ ์ค ์ค๋ฅ ๋ฐ์: {type(e).__name__}")
# BUG FIX: the original set ONNX_TF_AVAILABLE = True without importing
# anything (the actual import was commented out for deferral), so the flag
# was always True even when onnx-tf is not installed and the early
# availability check could never catch it.  Probe for the package without
# importing it; the heavyweight import stays deferred to the point of use.
try:
    import importlib.util
    ONNX_TF_AVAILABLE = importlib.util.find_spec("onnx_tf") is not None
except Exception as e:
    ONNX_TF_AVAILABLE = False
    if not isinstance(e, ImportError):
        print(f"โ ๏ธ onnx-tf ํจํค์ง ๋ก๋ ์ค ์ค๋ฅ ๋ฐ์: {type(e).__name__}")
class FatigueNet(nn.Module):
    """CNN + GRU fatigue-prediction model (PyTorch version).

    Two pointwise (kernel_size=1) 1-D convolutions lift the raw features to
    a 64-channel representation, a multi-layer GRU consumes the resulting
    sequence, and the last time step is projected to the output.

    NOTE: attribute names (conv1, conv2, relu, gru, fc, dropout) are part of
    the checkpoint contract — state_dict keys depend on them.
    """

    def __init__(self, input_dim=2, hidden_dim=64, num_layers=2, output_dim=1):
        super().__init__()
        # Pointwise convolutions: input_dim -> 32 -> 64 channels.
        self.conv1 = nn.Conv1d(input_dim, 32, kernel_size=1, padding=0)
        self.conv2 = nn.Conv1d(32, 64, kernel_size=1, padding=0)
        self.relu = nn.ReLU()
        # Inter-layer dropout only applies when the GRU is stacked.
        self.gru = nn.GRU(
            64,
            hidden_dim,
            num_layers,
            batch_first=True,
            dropout=0.2 if num_layers > 1 else 0,
        )
        # Final projection from the last GRU hidden state.
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        # Promote a (batch, features) input to a length-1 sequence.
        if x.dim() == 2:
            x = x.unsqueeze(1)
        # Conv1d wants (batch, channels, seq); swap back afterwards for the GRU.
        feats = self.relu(self.conv1(x.permute(0, 2, 1)))
        feats = self.relu(self.conv2(feats))
        seq_out, _ = self.gru(feats.permute(0, 2, 1))
        # Keep only the final time step, regularize, then project.
        return self.fc(self.dropout(seq_out[:, -1, :]))
def convert_pytorch_to_tflite(
    pytorch_model_path='./model/fatigue_net_v2.pt',
    tflite_model_path='./model/fatigue_net_v2.tflite',
    input_shape=(1, 1, 2)  # (batch, seq_len, features)
):
    """
    Convert a PyTorch model to TensorFlow Lite.

    Pipeline: state_dict checkpoint -> FatigueNet -> ONNX -> TensorFlow
    SavedModel (via onnx-tf) -> TFLite flatbuffer, followed by a smoke test
    of the produced model and cleanup of the intermediate artifacts.

    Args:
        pytorch_model_path: path to the PyTorch checkpoint file
        tflite_model_path: path where the TFLite model will be written
        input_shape: input tensor shape (batch, seq_len, features)

    Returns:
        True on success, False when a required package is missing or the
        ONNX->TensorFlow step fails.

    Raises:
        FileNotFoundError: checkpoint file does not exist.
        ValueError: file is TorchScript, or lacks a 'model_state_dict' key.
    """
    print("=" * 80)
    print("PyTorch ๋ชจ๋ธ์ TensorFlow Lite๋ก ๋ณํ")
    print("=" * 80)

    # Bail out early with installation guidance if any required package is
    # missing or failed to load.
    if not ONNX_AVAILABLE or not TF_AVAILABLE or not ONNX_TF_AVAILABLE:
        print("\nโ ํ์ ํจํค์ง๊ฐ ์ค์น๋์ง ์์๊ฑฐ๋ ํธํ์ฑ ๋ฌธ์ ๊ฐ ์์ต๋๋ค.")
        print("\n๐ Python ๋ฒ์ ํ์ธ:")
        import sys
        print(f" ํ์ฌ Python ๋ฒ์ : {sys.version}")
        print(f" ๊ถ์ฅ Python ๋ฒ์ : 3.10 ์ด์")
        if sys.version_info < (3, 10):
            print("\nโ ๏ธ Python 3.9์์๋ ์ผ๋ถ ํจํค์ง ํธํ์ฑ ๋ฌธ์ ๊ฐ ์์ ์ ์์ต๋๋ค.")
            print(" Python 3.10 ์ด์์ผ๋ก ์ ๊ทธ๋ ์ด๋ํ๊ฑฐ๋, ๋ค์์ ์๋ํ์ธ์:")
            print(" - ๊ฐ์ํ๊ฒฝ์์ Python 3.10+ ์ฌ์ฉ")
            print(" - ๋๋ ํธํ๋๋ ํจํค์ง ๋ฒ์ ์ค์น")
        print("\n๐ฆ ์ค์น ๋ช ๋ น์ด:")
        print(" ๊ถ์ฅ ๋ฒ์ (Python 3.10 ์ด์):")
        print(" pip install onnx==1.15.0 onnx-tf==1.10.0 tensorflow==2.15.0")
        print("\nโ ๏ธ ์ฐธ๊ณ : Python 3.9์์๋ ์ผ๋ถ ํจํค์ง ์ค์น ์ค ์๋ฌ๊ฐ ๋ฐ์ํ ์ ์์ต๋๋ค.")
        print(" Python 3.10 ์ด์ ์ฌ์ฉ์ ๊ฐ๋ ฅํ ๊ถ์ฅํฉ๋๋ค.")
        print("\nโ TFLite ๋ณํ์ ํ์์ ๋๋ค. ๋ชจ๋ฐ์ผ ๋๋ฐ์ด์ค์์ ์คํํ๊ธฐ ์ํด ํ์ํฉ๋๋ค.")
        print(" ํ์ ํจํค์ง๋ฅผ ์ค์นํ๊ณ ๋ค์ ์๋ํ์ธ์.")
        return False

    # 1. Load the PyTorch checkpoint.
    print("\n1๏ธโฃ PyTorch ๋ชจ๋ธ ๋ก๋ ์ค...")
    if not os.path.exists(pytorch_model_path):
        raise FileNotFoundError(f"๋ชจ๋ธ ํ์ผ์ ์ฐพ์ ์ ์์ต๋๋ค: {pytorch_model_path}")

    # BUG FIX: the original raised the "this is TorchScript" ValueError
    # *inside* a try whose bare `except:` immediately swallowed it, so
    # TorchScript files were never actually rejected.  Probe first, then
    # raise outside the try.
    is_torchscript = False
    try:
        torch.jit.load(pytorch_model_path, map_location='cpu')
        # torch.jit.load only succeeds on TorchScript archives.
        is_torchscript = True
    except Exception:
        is_torchscript = False  # not TorchScript; fall through to torch.load
    if is_torchscript:
        raise ValueError(
            f"โ {pytorch_model_path}๋ TorchScript ํ์์ ๋๋ค.\n"
            "TFLite ๋ณํ์ ์ํด์๋ PyTorch state_dict ํ์ ๋ชจ๋ธ์ด ํ์ํฉ๋๋ค.\n"
            "๋ชจ๋ธ์ ๋ค์ ํ์ตํ๊ฑฐ๋ ์ฌ๋ฐ๋ฅธ ํ์์ ๋ชจ๋ธ ํ์ผ์ ์ฌ์ฉํ์ธ์."
        )

    # Load the regular PyTorch checkpoint and validate its layout.
    checkpoint = torch.load(pytorch_model_path, map_location='cpu')
    if not isinstance(checkpoint, dict) or 'model_state_dict' not in checkpoint:
        raise ValueError(
            f"โ ์ฌ๋ฐ๋ฅธ PyTorch ๋ชจ๋ธ ํ์์ด ์๋๋๋ค.\n"
            f"'{pytorch_model_path}' ํ์ผ์ 'model_state_dict' ํค๊ฐ ํ์ํฉ๋๋ค.\n"
            "๋ชจ๋ธ์ ๋ค์ ํ์ตํ๊ฑฐ๋ ์ฌ๋ฐ๋ฅธ ํ์์ ๋ชจ๋ธ ํ์ผ์ ์ฌ์ฉํ์ธ์."
        )

    # Fall back to the default architecture when the checkpoint carries no
    # explicit configuration.
    model_config = checkpoint.get('model_config', {
        'input_dim': 2,
        'hidden_dim': 64,
        'num_layers': 2,
        'output_dim': 1
    })
    model = FatigueNet(**model_config)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    print(f"โ ๋ชจ๋ธ ๋ก๋ ์๋ฃ: {pytorch_model_path}")
    print(f" ๋ชจ๋ธ ์ค์ : {model_config}\n")

    # 2. Export to ONNX.
    print("2๏ธโฃ ONNX ํ์์ผ๋ก ๋ณํ ์ค...")
    onnx_model_path = './model/fatigue_net_v2.onnx'
    os.makedirs('./model', exist_ok=True)
    # Dummy input with fixed batch_size=1 improves TFLite compatibility.
    dummy_input = torch.randn(1, 1, 2)  # (batch=1, seq_len=1, features=2)
    try:
        torch.onnx.export(
            model,
            dummy_input,
            onnx_model_path,
            export_params=True,
            opset_version=11,  # lowered to 11 for onnx-tf compatibility
            do_constant_folding=True,
            input_names=['input'],
            output_names=['output'],
            dynamic_axes={
                'input': {0: 'batch_size', 1: 'sequence_length'},
                'output': {0: 'batch_size'}
            },
            custom_opsets=None,
            verbose=False
        )
        print(f"โ ONNX ๋ณํ ์๋ฃ: {onnx_model_path}\n")
    except Exception as e:
        # Best-effort: warn and continue; the ONNX->TF step below will fail
        # loudly (and return False) if no usable ONNX file was produced.
        print(f"โ ๏ธ ONNX ๋ณํ ์ค ๊ฒฝ๊ณ (๊ณ์ ์งํ): {e}\n")

    # 3. Convert ONNX to a TensorFlow SavedModel.
    print("3๏ธโฃ TensorFlow ํ์์ผ๋ก ๋ณํ ์ค...")
    try:
        from onnx_tf.backend import prepare
        onnx_model = onnx.load(onnx_model_path)
        # Force linear_before_reset=0 on every GRU node: TensorFlow's GRU
        # semantics require it, while PyTorch exports may set it to 1.
        for node in onnx_model.graph.node:
            if node.op_type == 'GRU':
                for attr in node.attribute:
                    if attr.name == 'linear_before_reset':
                        attr.i = 0
                        break
                else:
                    # Attribute absent -- add it explicitly.
                    attr = onnx.helper.make_attribute('linear_before_reset', 0)
                    node.attribute.append(attr)
        tf_rep = prepare(onnx_model)
        tf_model_path = './model/tf_model'
        tf_rep.export_graph(tf_model_path)
        print(f"โ TensorFlow ๋ณํ ์๋ฃ: {tf_model_path}\n")
    except Exception as e:
        print(f"โ TensorFlow ๋ณํ ์คํจ: {e}")
        print("โ ๏ธ ONNX-TF ๋ณํ์ด ์คํจํ์ต๋๋ค.\n")
        print("โ TFLite ๋ณํ์ ํ์์ ๋๋ค. ๋ชจ๋ฐ์ผ ๋๋ฐ์ด์ค์์ ์คํํ๊ธฐ ์ํด ํ์ํฉ๋๋ค.")
        print(" ์๋ฌ๋ฅผ ํด๊ฒฐํ๊ณ ๋ค์ ์๋ํ์ธ์.")
        return False

    # 4. Convert the SavedModel to TensorFlow Lite.
    print("4๏ธโฃ TensorFlow Lite ํ์์ผ๋ก ๋ณํ ์ค...")
    converter = tf.lite.TFLiteConverter.from_saved_model(tf_model_path)
    # Allow Flex (SELECT_TF_OPS) ops: the GRU lowers to TF ops that have no
    # builtin TFLite kernel.
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,
        tf.lite.OpsSet.SELECT_TF_OPS
    ]
    converter._experimental_lower_tensor_list_ops = False
    # Default optimizations (quantization-friendly; optional).
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    tflite_model = converter.convert()
    with open(tflite_model_path, 'wb') as f:
        f.write(tflite_model)
    print(f"โ TensorFlow Lite ๋ณํ ์๋ฃ: {tflite_model_path}")
    model_size = os.path.getsize(tflite_model_path) / (1024 * 1024)  # MB
    print(f" ๋ชจ๋ธ ํฌ๊ธฐ: {model_size:.2f} MB\n")

    # 5. Smoke-test the converted model with a fixed-size random input.
    print("5๏ธโฃ ๋ณํ๋ ๋ชจ๋ธ ํ ์คํธ ์ค...")
    try:
        interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
        interpreter.allocate_tensors()
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()
        print(f" ์ ๋ ฅ ํํ: {input_details[0]['shape']}")
        print(f" ์ถ๋ ฅ ํํ: {output_details[0]['shape']}")
        test_input = np.random.randn(1, 1, 2).astype(np.float32)
        interpreter.set_tensor(input_details[0]['index'], test_input)
        interpreter.invoke()
        test_output = interpreter.get_tensor(output_details[0]['index'])
        print(f" ํ ์คํธ ์ถ๋ ฅ: {test_output[0][0]:.4f}")
        print(" โ ๋ชจ๋ธ ํ ์คํธ ์ฑ๊ณต\n")
    except Exception as e:
        # The model file was written; a failing smoke test here usually means
        # the host interpreter lacks Flex ops, not that conversion failed.
        print(f" โ ๏ธ ๋ชจ๋ธ ํ ์คํธ ์ค ๊ฒฝ๊ณ : {e}")
        print(" (๋ชจ๋ธ์ ์์ฑ๋์์ง๋ง ํ ์คํธ๋ ์คํจํ์ต๋๋ค. ๋ชจ๋ฐ์ผ ๋๋ฐ์ด์ค์์ Flex ops๊ฐ ํ์ํ ์ ์์ต๋๋ค.)\n")

    # 6. Best-effort cleanup of intermediate artifacts.
    print("6๏ธโฃ ์ค๊ฐ ํ์ผ ์ ๋ฆฌ ์ค...")
    try:
        os.remove(onnx_model_path)
        import shutil
        shutil.rmtree(tf_model_path)
        print("โ ์ค๊ฐ ํ์ผ ์ ๋ฆฌ ์๋ฃ\n")
    except Exception as e:
        print(f"โ ๏ธ ์ค๊ฐ ํ์ผ ์ ๋ฆฌ ์คํจ (๋ฌด์ ๊ฐ๋ฅ): {e}\n")

    print("=" * 80)
    print(f"โ ๋ณํ ์๋ฃ!")
    print(f" TFLite ๋ชจ๋ธ: {tflite_model_path}")
    print("=" * 80)
    return True
def main():
    """Entry point: run the conversion and map the outcome to an exit code."""
    try:
        ok = convert_pytorch_to_tflite(
            pytorch_model_path='./model/fatigue_net_v2.pt',
            tflite_model_path='./model/fatigue_net_v2.tflite'
        )
    except Exception as e:
        # Surface the failure with a full traceback and signal error to the OS.
        print(f"\nโ ๋ณํ ์คํจ: {e}")
        import traceback
        traceback.print_exc()
        return 1
    return 0 if ok else 1
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())