| import torch |
| import os |
| import onnx |
| from model import GPTConfig, GPT |
|
|
# --- Configuration ---------------------------------------------------------
# Source checkpoint (output of training) and the ONNX export target.
ckpt_path = '/media/leo/Data/checkpoints/350m_SmaLLMPro_Final/SmaLLMPro_Final.pt'
out_path_full = 'SmaLLMPro_350M.onnx'
device = 'cpu'  # export is traced on CPU; no GPU required
|
|
# --- Load checkpoint and rebuild the model ---------------------------------
print(f"Lade Checkpoint: {ckpt_path}")
# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True,
# which can reject pickle-based checkpoints; if loading fails here, pass
# weights_only=False explicitly (the file is a trusted local training output).
checkpoint = torch.load(ckpt_path, map_location=device)
gptconf = GPTConfig(**checkpoint['model_args'])
model = GPT(gptconf)

# Checkpoints saved from a torch.compile()-wrapped module prefix every
# parameter key with '_orig_mod.'; strip it so the keys match the plain
# (un-compiled) module. Iterate over a snapshot of the keys because the
# dict is mutated inside the loop (the unused value unpack was removed).
state_dict = checkpoint['model']
unwanted_prefix = '_orig_mod.'
for k in list(state_dict.keys()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)

model.load_state_dict(state_dict)
model.eval()  # inference mode: disables dropout for a deterministic export
|
|
# Dummy input for tracing: one full-length sequence of random token ids.
batch_size, seq_len = 1, gptconf.block_size
x = torch.randint(0, gptconf.vocab_size, (batch_size, seq_len), dtype=torch.long)
|
|
# --- Export to ONNX --------------------------------------------------------
print("Exportiere sauberes ONNX Modell (Opset 18)...")
torch.onnx.export(
    model,
    (x,),                       # traced with a (1, block_size) token tensor
    out_path_full,
    export_params=True,         # bake the weights into the graph
    opset_version=18,
    do_constant_folding=True,
    input_names=['input'],
    output_names=['logits'],
    # Mark the batch dimension as dynamic so the exported graph is not
    # hard-coded to batch size 1. The sequence axis is left fixed at
    # block_size because the positional-embedding handling in the traced
    # graph may not support other lengths — TODO confirm before making
    # axis 1 dynamic as well.
    dynamic_axes={'input': {0: 'batch'}, 'logits': {0: 'batch'}},
)
|
|
# --- Re-save as a single self-contained file -------------------------------
print("Erzwinge Speicherung in einer einzelnen Datei...")
try:
    # Round-trip through onnx so any external-data tensors are folded back
    # into one protobuf file (only possible below the 2 GB protobuf limit).
    proto = onnx.load(out_path_full)
    onnx.save(proto, "SmaLLMPro_350M_Final.onnx")
    print("✅ Full Precision Modell erfolgreich als Einzeldokument gespeichert: SmaLLMPro_350M_Final.onnx")
except Exception as e:
    # Best-effort step: report the failure instead of aborting the script.
    print(f"⚠️ Hinweis: Single-File Save fehlgeschlagen (evtl. doch über 2GB?). Fehler: {e}")
|
|