Converter / app.py
vertalius's picture
Upload 2 files
1ab1659 verified
raw
history blame
4.27 kB
import gradio as gr
import os
import subprocess
import tempfile
from huggingface_hub import snapshot_download
def convert_model(model_name, quant_type):
try:
# 1. Загрузка модели с Hugging Face
model_dir = tempfile.mkdtemp()
snapshot_download(repo_id=model_name, local_dir=model_dir, local_dir_use_symlinks=False)
# 2. Определение типа модели (PyTorch или Keras)
model_files = os.listdir(model_dir)
if "model.ckpt" in model_files or ".pt" in " ".join(model_files):
# PyTorch → ONNX → TFLite
unet_path = os.path.join(model_dir, "unet", "diffusion_pytorch_model.bin")
if os.path.exists(unet_path):
# Экспорт UNet в ONNX (для Stable Diffusion)
os.chdir(model_dir)
subprocess.run([
"python", "-c",
f"from diffusers.models.unet_2d_condition import UNet2DConditionModel; "
f"import torch; "
f"model = UNet2DConditionModel.from_pretrained('{model_dir}', subfolder='unet'); "
f"dummy = torch.randn(1, 4, 64, 64); "
f"torch.onnx.export(model, dummy, 'unet.onnx')"
])
# ONNX → TFLite
subprocess.run([
"python", "-m", "tf2onnx.convert",
"--input", "unet.onnx",
"--output", "unet.tflite",
"--opset", "13"
])
tflite_path = os.path.join(model_dir, "unet.tflite")
else:
raise ValueError("Модель не содержит UNet-слои для конвертации")
elif "saved_model.pb" in model_files:
# TensorFlow SavedModel → TFLite
converter = tf.lite.TFLiteConverter.from_saved_model(model_dir)
if quant_type == "int8":
converter.optimizations = [tf.lite.Optimize.DEFAULT]
def representative_dataset():
for _ in range(100):
yield [np.random.rand(1, 224, 224, 3).astype(np.float32)]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
elif quant_type == "float16":
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]
tflite_model = converter.convert()
tflite_path = os.path.join(model_dir, "model.tflite")
with open(tflite_path, "wb") as f:
f.write(tflite_model)
else:
# Использование Optimum CLI для Hugging Face моделей
output_path = os.path.join(model_dir, "model.tflite")
cmd = [
"optimum-cli", "export", "tflite",
"--model", model_name,
"--output", output_path
]
if quant_type == "int8":
cmd += ["--quantize", "int8"]
elif quant_type == "float16":
cmd += ["--quantize", "float16"]
subprocess.run(cmd)
tflite_path = output_path
# 3. Переименование в .task
task_path = tflite_path.replace(".tflite", ".task")
os.rename(tflite_path, task_path)
return task_path
except Exception as e:
return f"Ошибка: {str(e)}"
# Интерфейс Gradio
demo = gr.Interface(
fn=convert_model,
inputs=[
gr.Textbox(label="Модель Hugging Face (например, ImNoOne/f222-nsfw-inpainting-sd)"),
gr.Dropdown(choices=["no_quant", "int8", "float16"], label="Тип квантования")
],
outputs=gr.File(label="Скачать .task файл"),
title="Конвертер моделей в .task (TFLite)",
description="Конвертирует модели с Hugging Face и Civitai в формат .task с поддержкой квантования для Android."
)
demo.launch()