tinyInstruct / app_enhance.py
AItool's picture
Update app_enhance.py
1d4d2e2 verified
raw
history blame
2.31 kB
import os
import subprocess
import sys
import uuid
from pathlib import Path
from typing import Tuple

import gradio as gr
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Directory holding the interpolated frames (img%d.png) and the generated GIFs.
# Keep the trailing slash: file paths are built from it by string concatenation.
TMP_DIR: str = "/tmp/gradio/output/"
# Extension appended to every generated animation file.
GIF_EXT: str = "gif"
# Intermediate palette image written by create_palette() and read by write_gif().
PALETTE_PATH: str = "/tmp/gradio/palette.png"
def ensure_tmp_dir():
    """Make sure the output directory ``TMP_DIR`` exists.

    Missing parent directories are created as well, and a directory that
    already exists is not treated as an error.
    """
    Path(TMP_DIR).mkdir(parents=True, exist_ok=True)
def generate_interpolation_frames(img_a: str, img_b: str, exp: int = 4):
    """Run ``inference_img.py`` to generate interpolated frames between two images.

    The script is started with the interpreter that is running this app
    (``sys.executable``) instead of whatever ``python3`` resolves to on PATH,
    so it sees the same installed packages (and virtualenv, if any).

    Args:
        img_a: Path to the first input image.
        img_b: Path to the second input image.
        exp: Value forwarded to the script's ``--exp`` option.

    Raises:
        subprocess.CalledProcessError: If the script exits with a non-zero status.
        OSError: If the interpreter cannot be launched.
    """
    # "inference_img.py" is resolved relative to the current working directory.
    # NOTE(review): create_palette()/write_gif() read img%d.png from TMP_DIR;
    # presumably inference_img.py writes its frames there — confirm in that script.
    cmd = [
        sys.executable, "inference_img.py",
        "--img", img_a, img_b,
        "--exp", str(exp),
    ]
    subprocess.run(cmd, check=True)
def create_palette():
    """Generate a single GIF palette from all interpolated frames.

    Reads the ``img%d.png`` frame sequence from ``TMP_DIR`` and writes one
    palette image to ``PALETTE_PATH``, which write_gif() then applies to every
    frame of the GIF.

    Raises:
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
        OSError: If the ffmpeg executable cannot be launched.
    """
    cmd = [
        "ffmpeg", "-y", "-r", "14", "-f", "image2",
        "-i", f"{TMP_DIR}img%d.png",
        # stats_mode=full accumulates one histogram over the whole sequence and
        # emits a single palette. stats_mode=single would emit a new palette per
        # frame, which the image2 muxer will not write to one non-pattern file
        # name (no -update), and paletteuse without new=1 uses one palette anyway.
        "-vf", "palettegen=stats_mode=full",
        PALETTE_PATH,
    ]
    subprocess.run(cmd, check=True)
def write_gif(gif_path: str):
    """Encode the interpolated frames into a GIF using the precomputed palette.

    Args:
        gif_path: Destination file for the animation (overwritten if present).

    Raises:
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
    """
    frame_pattern = f"{TMP_DIR}img%d.png"

    # Input 0: the frame sequence at 14 fps; input 1: the palette image.
    ffmpeg_args = ["ffmpeg", "-y"]
    ffmpeg_args += ["-r", "14", "-f", "image2", "-i", frame_pattern]
    ffmpeg_args += ["-i", PALETTE_PATH]
    # paletteuse maps the frames (input 0) onto the palette (input 1).
    ffmpeg_args += ["-lavfi", "paletteuse", gif_path]

    subprocess.run(ffmpeg_args, check=True)
def enhance_image(img_a: str, img_b: str, mode: str) -> Tuple[str, str]:
    """Interpolate between two images and return the resulting GIF.

    Args:
        img_a: File path of the first image (from a ``type="filepath"`` gr.Image).
        img_b: File path of the second image.
        mode: Value of the "Mode" dropdown. Currently unused; the UI only
            offers "default".

    Returns:
        ``(gif_path, gif_path)`` — the same path, once for the GIF preview and
        once for the read-only path textbox.

    Raises:
        gr.Error: If an input image is missing, or if any pipeline step fails
            or its executable cannot be started.
    """
    # Guard against missing uploads (presumably None when the user has not
    # provided an image); otherwise subprocess fails with an opaque TypeError.
    if not img_a or not img_b:
        raise gr.Error("Please provide both Image A and Image B.")

    ensure_tmp_dir()
    gif_path = f"{TMP_DIR}{uuid.uuid4()}.{GIF_EXT}"
    # NOTE(review): frames and the palette live at fixed shared paths, so
    # concurrent requests may overwrite each other's intermediates.
    try:
        generate_interpolation_frames(img_a, img_b)
        create_palette()
        write_gif(gif_path)
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError covers a missing/non-executable ffmpeg or interpreter.
        raise gr.Error(f"Interpolation failed: {e}") from e
    return gif_path, gif_path
# Gradio UI
def build_interface():
    """Build the Gradio Blocks app for two-image RIFE interpolation.

    Clicking "Interpolate" sends both images and the selected mode to
    enhance_image() and displays the resulting GIF together with its path.

    Returns:
        The constructed ``gr.Blocks`` app (not yet launched).
    """
    with gr.Blocks(title="RIFE Interpolation") as blocks:
        # The two source images, side by side.
        with gr.Row():
            first_frame = gr.Image(label="Image A", type="filepath")
            second_frame = gr.Image(label="Image B", type="filepath")
        mode_choice = gr.Dropdown(choices=["default"], value="default", label="Mode")
        gif_preview = gr.Image(label="Result GIF", type="filepath")
        gif_location = gr.Textbox(label="GIF Path", interactive=False)
        run_button = gr.Button("Interpolate")

        # enhance_image returns (gif_path, gif_path): one copy feeds the preview,
        # the other the path textbox.
        handler_inputs = [first_frame, second_frame, mode_choice]
        handler_outputs = [gif_preview, gif_location]
        run_button.click(
            fn=enhance_image,
            inputs=handler_inputs,
            outputs=handler_outputs,
        )
    return blocks
# Build the UI at import time and keep it bound to the module-level name
# ``demo``, then launch the Gradio app. Both statements are top-level, so they
# run whenever this module is executed or imported.
# NOTE(review): consider an ``if __name__ == "__main__":`` guard around
# launch() if this module is ever imported rather than run directly.
demo = build_interface()
demo.launch()