File size: 4,274 Bytes
1ab1659
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import gradio as gr
import os
import subprocess
import tempfile
from huggingface_hub import snapshot_download

def convert_model(model_name, quant_type):
    try:
        # 1. Загрузка модели с Hugging Face
        model_dir = tempfile.mkdtemp()
        snapshot_download(repo_id=model_name, local_dir=model_dir, local_dir_use_symlinks=False)
        
        # 2. Определение типа модели (PyTorch или Keras)
        model_files = os.listdir(model_dir)
        if "model.ckpt" in model_files or ".pt" in " ".join(model_files):
            # PyTorch → ONNX → TFLite
            unet_path = os.path.join(model_dir, "unet", "diffusion_pytorch_model.bin")
            if os.path.exists(unet_path):
                # Экспорт UNet в ONNX (для Stable Diffusion)
                os.chdir(model_dir)
                subprocess.run([
                    "python", "-c",
                    f"from diffusers.models.unet_2d_condition import UNet2DConditionModel; "
                    f"import torch; "
                    f"model = UNet2DConditionModel.from_pretrained('{model_dir}', subfolder='unet'); "
                    f"dummy = torch.randn(1, 4, 64, 64); "
                    f"torch.onnx.export(model, dummy, 'unet.onnx')"
                ])
                # ONNX → TFLite
                subprocess.run([
                    "python", "-m", "tf2onnx.convert",
                    "--input", "unet.onnx",
                    "--output", "unet.tflite",
                    "--opset", "13"
                ])
                tflite_path = os.path.join(model_dir, "unet.tflite")
            else:
                raise ValueError("Модель не содержит UNet-слои для конвертации")
        elif "saved_model.pb" in model_files:
            # TensorFlow SavedModel → TFLite
            converter = tf.lite.TFLiteConverter.from_saved_model(model_dir)
            if quant_type == "int8":
                converter.optimizations = [tf.lite.Optimize.DEFAULT]
                def representative_dataset():
                    for _ in range(100):
                        yield [np.random.rand(1, 224, 224, 3).astype(np.float32)]
                converter.representative_dataset = representative_dataset
                converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
            elif quant_type == "float16":
                converter.optimizations = [tf.lite.Optimize.DEFAULT]
                converter.target_spec.supported_types = [tf.float16]
            tflite_model = converter.convert()
            tflite_path = os.path.join(model_dir, "model.tflite")
            with open(tflite_path, "wb") as f:
                f.write(tflite_model)
        else:
            # Использование Optimum CLI для Hugging Face моделей
            output_path = os.path.join(model_dir, "model.tflite")
            cmd = [
                "optimum-cli", "export", "tflite",
                "--model", model_name,
                "--output", output_path
            ]
            if quant_type == "int8":
                cmd += ["--quantize", "int8"]
            elif quant_type == "float16":
                cmd += ["--quantize", "float16"]
            subprocess.run(cmd)
            tflite_path = output_path

        # 3. Переименование в .task
        task_path = tflite_path.replace(".tflite", ".task")
        os.rename(tflite_path, task_path)
        return task_path

    except Exception as e:
        return f"Ошибка: {str(e)}"

# Интерфейс Gradio
demo = gr.Interface(
    fn=convert_model,
    inputs=[
        gr.Textbox(label="Модель Hugging Face (например, ImNoOne/f222-nsfw-inpainting-sd)"),
        gr.Dropdown(choices=["no_quant", "int8", "float16"], label="Тип квантования")
    ],
    outputs=gr.File(label="Скачать .task файл"),
    title="Конвертер моделей в .task (TFLite)",
    description="Конвертирует модели с Hugging Face и Civitai в формат .task с поддержкой квантования для Android."
)

demo.launch()