Spaces:
Runtime error
Runtime error
File size: 5,320 Bytes
6f7e476 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
import gradio as gr
import torch
from safetensors.torch import save_file as torch_save_file
import tensorflow as tf
from safetensors.keras import save_model as keras_save_model
import os
import tempfile
def _extract_state_dict(checkpoint):
    """Return the flat ``{name: tensor}`` mapping held by a loaded checkpoint.

    Accepts a bare state_dict, an ``nn.Module``, or a training checkpoint
    that nests the weights under ``"state_dict"``, ``"model_state_dict"`` or
    ``"model"``.

    Raises:
        ValueError: If no tensor mapping can be found, it is empty, or it
            contains non-tensor values (safetensors can only store tensors).
    """
    if isinstance(checkpoint, torch.nn.Module):
        return checkpoint.state_dict()
    if not isinstance(checkpoint, dict):
        raise ValueError(
            "Expected a state_dict (dict of tensors), got "
            f"{type(checkpoint).__name__}."
        )
    # Training checkpoints often keep the weights next to non-tensor entries
    # such as the epoch number or the optimizer state.
    for key in ("state_dict", "model_state_dict", "model"):
        nested = checkpoint.get(key)
        if isinstance(nested, dict):
            checkpoint = nested
            break
    if not checkpoint:
        raise ValueError("Checkpoint contains no tensors.")
    non_tensor_keys = [
        k for k, v in checkpoint.items() if not isinstance(v, torch.Tensor)
    ]
    if non_tensor_keys:
        shown = ", ".join(map(str, non_tensor_keys[:5]))
        raise ValueError(f"Checkpoint contains non-tensor entries: {shown}")
    return checkpoint


def _convert_pytorch(input_path, output_path):
    """Convert the PyTorch checkpoint at *input_path* to SafeTensors."""
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so an uploaded file cannot run arbitrary code while loading.
    checkpoint = torch.load(input_path, map_location="cpu", weights_only=True)
    tensors = {}
    seen_storages = set()
    for name, tensor in _extract_state_dict(checkpoint).items():
        # safetensors refuses non-contiguous tensors (no-op if already so) ...
        tensor = tensor.contiguous()
        storage_ptr = tensor.untyped_storage().data_ptr()
        if storage_ptr in seen_storages:
            # ... and tensors that share memory (e.g. tied embedding
            # weights), so every repeat gets its own copy.
            tensor = tensor.clone()
        seen_storages.add(storage_ptr)
        tensors[name] = tensor
    torch_save_file(tensors, output_path)


def _find_saved_model_dir(root):
    """Return the first directory under *root* containing ``saved_model.pb``."""
    for dirpath, _dirnames, filenames in os.walk(root):
        if "saved_model.pb" in filenames:
            return dirpath
    raise ValueError("No saved_model.pb found in the uploaded zip archive.")


def _save_keras_weights(model_path, output_path):
    """Load the Keras model at *model_path* and write its weights as SafeTensors."""
    # NOTE(review): the module-level `from safetensors.keras import ...` does
    # not appear to match any safetensors submodule (the bindings are torch,
    # tensorflow, numpy, flax, ...); it is unused here and likely needs
    # removing — confirm.
    from safetensors.tensorflow import save_file as tf_save_file

    # compile=False: only the weights are needed, so custom losses/metrics
    # referenced by the training config need not be importable.
    model = tf.keras.models.load_model(model_path, compile=False)
    tensors = {}
    for index, weight in enumerate(model.weights):
        # Keras 3 exposes the unique hierarchical name as `.path`; tf.keras 2
        # keeps it in `.name` (e.g. "dense/kernel:0").
        name = getattr(weight, "path", None) or weight.name
        if name in tensors:
            name = f"{name}_{index}"
        tensors[name] = tf.convert_to_tensor(weight)
    if not tensors:
        raise ValueError("The loaded model has no weights to convert.")
    tf_save_file(tensors, output_path)


def _convert_tensorflow(input_path, output_path):
    """Convert a Keras ``.h5`` file or a zipped SavedModel to SafeTensors."""
    import zipfile

    if str(input_path).lower().endswith(".zip"):
        # Extract into a directory that is removed once conversion finishes.
        with tempfile.TemporaryDirectory() as extract_dir:
            with zipfile.ZipFile(input_path) as archive:
                # extractall strips absolute paths and ".." components.
                archive.extractall(extract_dir)
            _save_keras_weights(_find_saved_model_dir(extract_dir), output_path)
    else:
        _save_keras_weights(input_path, output_path)


def convert_to_safetensors(framework, model_file):
    """
    Convert an uploaded PyTorch or TensorFlow/Keras model to SafeTensors.

    Args:
        framework: "PyTorch" or "TensorFlow", as chosen in the dropdown.
        model_file: The ``gr.File`` value — a path string (Gradio 4+) or a
            tempfile-like object exposing ``.name`` (Gradio 3).

    Returns:
        Path to ``model.safetensors`` inside a fresh temporary directory, so
        simultaneous requests never overwrite each other's output. The
        directory is not removed by this function.

    Raises:
        gr.Error: If no file was uploaded, the framework is unknown, or the
            conversion fails; Gradio displays the message to the user.
    """
    if not model_file:
        raise gr.Error("Please upload a model file.")
    # Validate before the try block so this error is not re-wrapped below.
    if framework not in ("PyTorch", "TensorFlow"):
        raise gr.Error("Please select a valid framework (PyTorch or TensorFlow).")

    input_path = getattr(model_file, "name", model_file)
    output_path = os.path.join(tempfile.mkdtemp(), "model.safetensors")

    try:
        if framework == "PyTorch":
            _convert_pytorch(input_path, output_path)
        else:
            _convert_tensorflow(input_path, output_path)
    except Exception as e:
        tips = {
            "PyTorch": "\n\nTips:\n• Ensure the file is a valid PyTorch model (.pt, .pth)\n• Model should contain state_dict or be loadable with torch.load()",
            "TensorFlow": "\n\nTips:\n• Ensure the file is a valid TensorFlow model (.h5, SavedModel)\n• For SavedModel format, upload as a zip file containing the model directory",
        }
        error_msg = f"{framework} Conversion Error: {e}" + tips[framework]
        raise gr.Error(error_msg) from e
    return output_path
# Build the Gradio interface and wire the convert button to
# convert_to_safetensors.
with gr.Blocks(
    title="SafeTensors Model Converter",
    theme=gr.themes.Soft(),
) as iface:
    gr.Markdown(
        """
# No-Code SafeTensors Model Creator
Convert your machine learning models to the secure **SafeTensors** format with zero coding required!

## Why SafeTensors?
- **Security**: Prevents arbitrary code execution during model loading
- **Speed**: Faster loading times compared to pickle-based formats
- **Memory Efficiency**: Zero-copy deserialization
- **Cross-Platform**: Works across different ML frameworks

## Supported Formats
- **PyTorch**: `.pt`, `.pth` files containing model weights
- **TensorFlow**: `.h5` files or SavedModel directories (as zip)
"""
    )
    with gr.Row():
        with gr.Column():
            framework_dropdown = gr.Dropdown(
                choices=["PyTorch", "TensorFlow"],
                label="🔧 Select Framework",
                info="Choose the framework your model was trained with",
                value="PyTorch",
            )
            # gr.File has no `info` parameter, so the hint is rendered as
            # Markdown underneath the upload widget instead.
            model_upload = gr.File(
                label="Upload Model File",
                file_types=[".pt", ".pth", ".h5", ".zip"],
            )
            gr.Markdown(
                "Upload your model file (.pt/.pth for PyTorch, "
                ".h5 or zipped SavedModel for TensorFlow)"
            )
            convert_btn = gr.Button(
                "Convert to SafeTensors",
                variant="primary",
                size="lg",
            )
        with gr.Column():
            output_file = gr.File(label="💾 Download SafeTensors File")
            gr.Markdown("Your converted model will appear here")
            gr.Markdown(
                """
### Usage Instructions
1. **Select Framework**: Choose PyTorch or TensorFlow
2. **Upload Model**: Select your model file from your computer
3. **Convert**: Click the convert button
4. **Download**: Get your secure SafeTensors file

### ⚠️ Important Notes
- Only model weights are converted (no training code)
- Original model architecture code is still needed for inference
- Tensor names, shapes, and dtypes are preserved; non-tensor metadata is not stored
"""
            )

    # Gradio's default progress indicator is used (the legacy boolean
    # `show_progress=True` is not passed).
    convert_btn.click(
        fn=convert_to_safetensors,
        inputs=[framework_dropdown, model_upload],
        outputs=output_file,
    )

    gr.Markdown(
        """
---
### 🛡️ Security Benefits
SafeTensors format eliminates security risks associated with pickle-based model formats by:
- Storing only tensor data (no executable code)
- Using a simple, well-defined file format
- Enabling safe model sharing and deployment

### Learn More
- [SafeTensors Documentation](https://huggingface.co/docs/safetensors)
- [Hugging Face Model Hub](https://huggingface.co/models)
"""
    )

# For Hugging Face Spaces deployment
if __name__ == "__main__":
    iface.launch()
|