# Hugging Face Spaces app: convert HF models to a .task (TFLite) file.
# (The Space was showing "Runtime error": tensorflow/numpy were used below
# but never imported — fixed in the import block.)
import os
import subprocess
import sys
import tempfile

import gradio as gr
import numpy as np
from huggingface_hub import snapshot_download
def convert_model(model_name, quant_type):
    """Download a Hugging Face model and convert it to a .task (TFLite) file.

    Args:
        model_name: Hugging Face repo id, e.g. "ImNoOne/f222-nsfw-inpainting-sd".
        quant_type: One of "no_quant", "int8", "float16".

    Returns:
        Path to the generated .task file on success, or an error string
        prefixed with "Ошибка: " on failure (original contract preserved so
        existing UI wiring keeps working).
    """
    try:
        # 1. Download the model snapshot from Hugging Face.
        # The temp dir is intentionally NOT removed here: Gradio must still be
        # able to serve the produced file after this function returns.
        model_dir = tempfile.mkdtemp()
        snapshot_download(repo_id=model_name, local_dir=model_dir,
                          local_dir_use_symlinks=False)

        # 2. Detect the model flavour from the files present.
        # Use endswith instead of the original substring test, which matched
        # any filename merely containing ".pt".
        model_files = os.listdir(model_dir)
        is_pytorch = ("model.ckpt" in model_files
                      or any(f.endswith((".pt", ".pth")) for f in model_files))
        if is_pytorch:
            tflite_path = _convert_pytorch_unet(model_dir)
        elif "saved_model.pb" in model_files:
            tflite_path = _convert_saved_model(model_dir, quant_type)
        else:
            tflite_path = _convert_with_optimum(model_name, model_dir, quant_type)

        # 3. Rename to the .task extension expected on the Android side.
        task_path = tflite_path.replace(".tflite", ".task")
        os.rename(tflite_path, task_path)
        return task_path
    except Exception as e:
        # The original API returned the error text instead of raising; keep it.
        return f"Ошибка: {str(e)}"


def _convert_pytorch_unet(model_dir):
    """Export a Stable-Diffusion UNet to ONNX, then convert ONNX -> TFLite.

    Returns the path to the produced .tflite file; raises ValueError when the
    repo has no UNet weights or the conversion produced nothing.
    """
    unet_path = os.path.join(model_dir, "unet", "diffusion_pytorch_model.bin")
    if not os.path.exists(unet_path):
        raise ValueError("Модель не содержит UNet-слои для конвертации")
    # Export in a subprocess. `cwd=` replaces the original os.chdir(), which
    # mutated global state of the whole (multi-request) Gradio server process.
    # NOTE(review): a real SD UNet forward also takes timestep and
    # encoder_hidden_states; a single dummy latent may not be enough for
    # torch.onnx.export on every checkpoint — confirm against the target model.
    export_code = (
        "from diffusers.models.unet_2d_condition import UNet2DConditionModel; "
        "import torch; "
        f"model = UNet2DConditionModel.from_pretrained({model_dir!r}, subfolder='unet'); "
        "dummy = torch.randn(1, 4, 64, 64); "
        "torch.onnx.export(model, dummy, 'unet.onnx')"
    )
    subprocess.run([sys.executable, "-c", export_code], cwd=model_dir, check=True)
    # ONNX -> TFLite. The original invoked tf2onnx here, which converts in the
    # OPPOSITE direction (TensorFlow -> ONNX) and could never produce a
    # .tflite file; onnx2tf is the tool that goes ONNX -> TFLite.
    out_dir = os.path.join(model_dir, "tflite_out")
    subprocess.run(["onnx2tf", "-i", "unet.onnx", "-o", out_dir],
                   cwd=model_dir, check=True)
    tflite_files = [f for f in os.listdir(out_dir) if f.endswith(".tflite")]
    if not tflite_files:
        raise ValueError("Конвертация ONNX в TFLite не создала файлов")
    return os.path.join(out_dir, tflite_files[0])


def _convert_saved_model(model_dir, quant_type):
    """Convert a TensorFlow SavedModel to TFLite with optional quantization."""
    # Deferred import: tensorflow was referenced but never imported in the
    # original file — the cause of the Space's runtime NameError.
    import tensorflow as tf

    converter = tf.lite.TFLiteConverter.from_saved_model(model_dir)
    if quant_type == "int8":
        converter.optimizations = [tf.lite.Optimize.DEFAULT]

        def representative_dataset():
            # NOTE(review): assumes a 1x224x224x3 float32 input — confirm this
            # matches the actual SavedModel signature before trusting int8 output.
            for _ in range(100):
                yield [np.random.rand(1, 224, 224, 3).astype(np.float32)]

        converter.representative_dataset = representative_dataset
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    elif quant_type == "float16":
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.target_spec.supported_types = [tf.float16]
    tflite_model = converter.convert()
    tflite_path = os.path.join(model_dir, "model.tflite")
    with open(tflite_path, "wb") as f:
        f.write(tflite_model)
    return tflite_path


def _convert_with_optimum(model_name, model_dir, quant_type):
    """Fall back to the Optimum CLI exporter for generic Hugging Face models."""
    output_path = os.path.join(model_dir, "model.tflite")
    cmd = [
        "optimum-cli", "export", "tflite",
        "--model", model_name,
        "--output", output_path,
    ]
    if quant_type in ("int8", "float16"):
        cmd += ["--quantize", quant_type]
    # check=True: the original ignored exporter failures and then crashed
    # later on the missing output file.
    subprocess.run(cmd, check=True)
    return output_path
# Gradio UI wiring: a repo-id textbox plus a quantization choice, returning
# the produced .task file for download.
_ui_inputs = [
    gr.Textbox(label="Модель Hugging Face (например, ImNoOne/f222-nsfw-inpainting-sd)"),
    gr.Dropdown(choices=["no_quant", "int8", "float16"], label="Тип квантования"),
]

demo = gr.Interface(
    fn=convert_model,
    inputs=_ui_inputs,
    outputs=gr.File(label="Скачать .task файл"),
    title="Конвертер моделей в .task (TFLite)",
    description="Конвертирует модели с Hugging Face и Civitai в формат .task с поддержкой квантования для Android.",
)

demo.launch()