# MuscleCare-Train-AI / convert_tflite.py
# Merry99's picture
# Spaces용 코드만 포함 (모델 파일 제외)
# 2b83ee8
"""
PyTorch ๋ชจ๋ธ์„ TensorFlow Lite ํ˜•์‹์œผ๋กœ ๋ณ€ํ™˜ํ•˜๋Š” ์Šคํฌ๋ฆฝํŠธ
"""
import torch
import torch.nn as nn
import numpy as np
import os
# ์„ ํƒ์  ์ž„ํฌํŠธ
ONNX_AVAILABLE = False
TF_AVAILABLE = False
ONNX_TF_AVAILABLE = False
try:
import onnx
ONNX_AVAILABLE = True
except (ImportError, SyntaxError, Exception) as e:
ONNX_AVAILABLE = False
if not isinstance(e, ImportError):
print(f"โš ๏ธ onnx ํŒจํ‚ค์ง€ ๋กœ๋“œ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {type(e).__name__}")
try:
import tensorflow as tf
TF_AVAILABLE = True
except (ImportError, SyntaxError, Exception) as e:
TF_AVAILABLE = False
if not isinstance(e, ImportError):
print(f"โš ๏ธ tensorflow ํŒจํ‚ค์ง€ ๋กœ๋“œ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {type(e).__name__}")
try:
# onnx-tf๋Š” ์‹ค์ œ๋กœ ์‚ฌ์šฉํ•  ๋•Œ ์ž„ํฌํŠธํ•˜๋„๋ก ๋ณ€๊ฒฝ
# from onnx_tf.backend import prepare
ONNX_TF_AVAILABLE = True
except (ImportError, SyntaxError, Exception) as e:
ONNX_TF_AVAILABLE = False
if not isinstance(e, ImportError):
print(f"โš ๏ธ onnx-tf ํŒจํ‚ค์ง€ ๋กœ๋“œ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {type(e).__name__}")
class FatigueNet(nn.Module):
    """CNN + GRU based fatigue-prediction model (PyTorch version)."""

    def __init__(self, input_dim=2, hidden_dim=64, num_layers=2, output_dim=1):
        super().__init__()
        # Pointwise (kernel_size=1) convolutions expand the feature axis
        # from input_dim -> 32 -> 64 ahead of the recurrent stage.
        self.conv1 = nn.Conv1d(in_channels=input_dim, out_channels=32,
                               kernel_size=1, padding=0)
        self.conv2 = nn.Conv1d(in_channels=32, out_channels=64,
                               kernel_size=1, padding=0)
        self.relu = nn.ReLU()
        # GRU stage (comment in the original notes TFLite prefers the
        # linear_before_reset=False GRU variant); inter-layer dropout is
        # only meaningful for stacked GRUs.
        self.gru = nn.GRU(input_size=64, hidden_size=hidden_dim,
                          num_layers=num_layers, batch_first=True,
                          dropout=0.2 if num_layers > 1 else 0)
        # Fully-connected regression head.
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        # Accept (batch, features) inputs by treating them as length-1 sequences.
        if x.dim() == 2:
            x = x.unsqueeze(1)
        # Conv1d expects (batch, channels, seq); GRU expects (batch, seq, channels).
        features = self.relu(self.conv1(x.permute(0, 2, 1)))
        features = self.relu(self.conv2(features))
        seq_out, _ = self.gru(features.permute(0, 2, 1))
        # Predict from the final time step only.
        return self.fc(self.dropout(seq_out[:, -1, :]))
def convert_pytorch_to_tflite(
    pytorch_model_path='./model/fatigue_net_v2.pt',
    tflite_model_path='./model/fatigue_net_v2.tflite',
    input_shape=(1, 1, 2)  # (batch, seq_len, features)
):
    """
    Convert a PyTorch checkpoint to TensorFlow Lite.

    Pipeline: state_dict checkpoint -> ONNX export -> onnx-tf SavedModel
    -> TFLite flatbuffer, followed by a smoke test of the converted model
    and best-effort cleanup of the intermediate artifacts.

    Args:
        pytorch_model_path: path to the PyTorch checkpoint; must contain a
            'model_state_dict' key (TorchScript archives are rejected)
        tflite_model_path: output path for the .tflite file
        input_shape: input tensor shape (batch, seq_len, features); kept for
            interface compatibility — the export uses a fixed (1, 1, 2) dummy

    Returns:
        True on success; False when required packages are unavailable or the
        ONNX -> TensorFlow step fails.

    Raises:
        FileNotFoundError: the checkpoint file does not exist.
        ValueError: the checkpoint is TorchScript or lacks 'model_state_dict'.
    """
    print("=" * 80)
    print("PyTorch 모델을 TensorFlow Lite로 변환")
    print("=" * 80)
    # Check the optional dependencies before doing any work.
    if not ONNX_AVAILABLE or not TF_AVAILABLE or not ONNX_TF_AVAILABLE:
        print("\n❌ 필수 패키지가 설치되지 않았거나 호환성 문제가 있습니다.")
        print("\n📋 Python 버전 확인:")
        import sys
        print(f" 현재 Python 버전: {sys.version}")
        print(f" 권장 Python 버전: 3.10 이상")
        if sys.version_info < (3, 10):
            print("\n⚠️ Python 3.9에서는 일부 패키지 호환성 문제가 있을 수 있습니다.")
            print(" Python 3.10 이상으로 업그레이드하거나, 다음을 시도하세요:")
            print(" - 가상환경에서 Python 3.10+ 사용")
            print(" - 또는 호환되는 패키지 버전 설치")
        print("\n📦 설치 명령어:")
        print(" 권장 버전 (Python 3.10 이상):")
        print(" pip install onnx==1.15.0 onnx-tf==1.10.0 tensorflow==2.15.0")
        print("\n⚠️ 참고: Python 3.9에서는 일부 패키지 설치 중 에러가 발생할 수 있습니다.")
        print(" Python 3.10 이상 사용을 강력히 권장합니다.")
        print("\n❌ TFLite 변환은 필수입니다. 모바일 디바이스에서 실행하기 위해 필요합니다.")
        print(" 필수 패키지를 설치하고 다시 시도하세요.")
        return False

    # 1) Load the PyTorch checkpoint.
    print("\n1️⃣ PyTorch 모델 로드 중...")
    if not os.path.exists(pytorch_model_path):
        raise FileNotFoundError(f"모델 파일을 찾을 수 없습니다: {pytorch_model_path}")

    # Reject TorchScript archives first.  NOTE(bugfix): the original wrapped
    # both the load attempt AND the `raise ValueError` in a single bare
    # `try/except: pass`, so the rejection was always swallowed.  Using
    # try/except/else guards only the load, letting the ValueError propagate.
    try:
        torch.jit.load(pytorch_model_path, map_location='cpu')
    except Exception:
        pass  # not a TorchScript file — expected for state_dict checkpoints
    else:
        # torch.jit.load succeeded, so this is a TorchScript archive.
        raise ValueError(
            f"❌ {pytorch_model_path}는 TorchScript 형식입니다.\n"
            "TFLite 변환을 위해서는 PyTorch state_dict 형식 모델이 필요합니다.\n"
            "모델을 다시 학습하거나 올바른 형식의 모델 파일을 사용하세요."
        )

    # Load the regular checkpoint dict.
    checkpoint = torch.load(pytorch_model_path, map_location='cpu')
    if not isinstance(checkpoint, dict) or 'model_state_dict' not in checkpoint:
        raise ValueError(
            f"❌ 올바른 PyTorch 모델 형식이 아닙니다.\n"
            f"'{pytorch_model_path}' 파일에 'model_state_dict' 키가 필요합니다.\n"
            "모델을 다시 학습하거나 올바른 형식의 모델 파일을 사용하세요."
        )
    # Fall back to the architecture defaults when the checkpoint has no config.
    model_config = checkpoint.get('model_config', {
        'input_dim': 2,
        'hidden_dim': 64,
        'num_layers': 2,
        'output_dim': 1
    })
    model = FatigueNet(**model_config)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    print(f"✅ 모델 로드 완료: {pytorch_model_path}")
    print(f" 모델 설정: {model_config}\n")

    # 2) Export to ONNX.
    print("2️⃣ ONNX 형식으로 변환 중...")
    onnx_model_path = './model/fatigue_net_v2.onnx'
    os.makedirs('./model', exist_ok=True)
    # Fixed-size dummy input (batch=1, seq_len=1, features=2) for better
    # TFLite compatibility.
    dummy_input = torch.randn(1, 1, 2)
    try:
        torch.onnx.export(
            model,
            dummy_input,
            onnx_model_path,
            export_params=True,
            opset_version=11,  # lowered to 11 for onnx-tf compatibility
            do_constant_folding=True,
            input_names=['input'],
            output_names=['output'],
            dynamic_axes={
                'input': {0: 'batch_size', 1: 'sequence_length'},
                'output': {0: 'batch_size'}
            },
            custom_opsets=None,
            verbose=False
        )
        print(f"✅ ONNX 변환 완료: {onnx_model_path}\n")
    except Exception as e:
        # Best-effort: some exporter warnings surface as exceptions; the
        # onnx.load below fails loudly if the export produced nothing usable.
        print(f"⚠️ ONNX 변환 중 경고 (계속 진행): {e}\n")

    # 3) ONNX -> TensorFlow SavedModel.
    print("3️⃣ TensorFlow 형식으로 변환 중...")
    try:
        from onnx_tf.backend import prepare
        onnx_model = onnx.load(onnx_model_path)
        # Force linear_before_reset=0 on every GRU node for TensorFlow
        # compatibility.
        for node in onnx_model.graph.node:
            if node.op_type == 'GRU':
                for attr in node.attribute:
                    if attr.name == 'linear_before_reset':
                        attr.i = 0
                        break
                else:
                    # Attribute missing entirely — add it.
                    attr = onnx.helper.make_attribute('linear_before_reset', 0)
                    node.attribute.append(attr)
        tf_rep = prepare(onnx_model)
        # Save as a TensorFlow SavedModel.
        tf_model_path = './model/tf_model'
        tf_rep.export_graph(tf_model_path)
        print(f"✅ TensorFlow 변환 완료: {tf_model_path}\n")
    except Exception as e:
        print(f"❌ TensorFlow 변환 실패: {e}")
        print("⚠️ ONNX-TF 변환이 실패했습니다.\n")
        print("❌ TFLite 변환은 필수입니다. 모바일 디바이스에서 실행하기 위해 필요합니다.")
        print(" 에러를 해결하고 다시 시도하세요.")
        return False

    # 4) SavedModel -> TFLite.
    print("4️⃣ TensorFlow Lite 형식으로 변환 중...")
    converter = tf.lite.TFLiteConverter.from_saved_model(tf_model_path)
    # Allow Flex (SELECT_TF_OPS) fallback so GRU and other complex ops convert.
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,
        tf.lite.OpsSet.SELECT_TF_OPS
    ]
    # Private converter flag: keep tensor-list ops unlowered (needed for RNNs).
    converter._experimental_lower_tensor_list_ops = False
    # Optional default optimizations (quantization-friendly).
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    tflite_model = converter.convert()
    with open(tflite_model_path, 'wb') as f:
        f.write(tflite_model)
    print(f"✅ TensorFlow Lite 변환 완료: {tflite_model_path}")
    model_size = os.path.getsize(tflite_model_path) / (1024 * 1024)  # MB
    print(f" 모델 크기: {model_size:.2f} MB\n")

    # 5) Smoke-test the converted model.
    print("5️⃣ 변환된 모델 테스트 중...")
    try:
        interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
        interpreter.allocate_tensors()
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()
        print(f" 입력 형태: {input_details[0]['shape']}")
        print(f" 출력 형태: {output_details[0]['shape']}")
        # Fixed-size test input matching the export dummy.
        test_input = np.random.randn(1, 1, 2).astype(np.float32)
        interpreter.set_tensor(input_details[0]['index'], test_input)
        interpreter.invoke()
        test_output = interpreter.get_tensor(output_details[0]['index'])
        print(f" 테스트 출력: {test_output[0][0]:.4f}")
        print(" ✅ 모델 테스트 성공\n")
    except Exception as e:
        # Non-fatal: the model file exists; on-device Flex ops may be required.
        print(f" ⚠️ 모델 테스트 중 경고: {e}")
        print(" (모델은 생성되었지만 테스트는 실패했습니다. 모바일 디바이스에서 Flex ops가 필요할 수 있습니다.)\n")

    # 6) Best-effort cleanup of intermediate files.
    print("6️⃣ 중간 파일 정리 중...")
    try:
        os.remove(onnx_model_path)
        import shutil
        shutil.rmtree(tf_model_path)
        print("✅ 중간 파일 정리 완료\n")
    except Exception as e:
        print(f"⚠️ 중간 파일 정리 실패 (무시 가능): {e}\n")

    print("=" * 80)
    print(f"✅ 변환 완료!")
    print(f" TFLite 모델: {tflite_model_path}")
    print("=" * 80)
    return True
def main():
    """Script entry point.

    Runs the conversion with the default model paths.

    Returns:
        int: process exit code — 0 on success, 1 on any failure.
    """
    try:
        success = convert_pytorch_to_tflite(
            pytorch_model_path='./model/fatigue_net_v2.pt',
            tflite_model_path='./model/fatigue_net_v2.tflite'
        )
        if not success:
            return 1
    except Exception as e:
        print(f"\n❌ 변환 실패: {e}")
        import traceback
        traceback.print_exc()
        return 1
    return 0


if __name__ == "__main__":
    # raise SystemExit instead of calling the site-injected exit() builtin,
    # which is not guaranteed to exist (e.g. when run with `python -S`).
    raise SystemExit(main())