# stashtag_onnx / app.py
# Snapshot of commit f6da4f5 ("init") by cc1234.
from stashtag.models.onnx_predictor import ONNXStashtagPredictor
from stashtag.utils.vtt_parser import process_sprite_frames
import gradio as gr
# Shared ONNX predictor: constructed once at import time and reused by both
# request handlers (predict_tags and predict_markers) defined below.
predictor: ONNXStashtagPredictor = ONNXStashtagPredictor()
def predict_tags(image, vtt, threshold=0.4):
    """Run tag prediction on a sprite sheet and annotate tags with frame info.

    Args:
        image: Sprite image as delivered by the ``gr.Image`` input.
        vtt: VTT text; split together with ``image`` into frames by
            ``process_sprite_frames``.
        threshold: Cut-off forwarded to ``predictor.predict_tags``. ``None``
            (what a cleared ``gr.Number`` field sends) falls back to 0.4.

    Returns:
        The mapping returned by ``predictor.predict_tags``, mutated in place:
        every tag entry whose ``'frame'`` index refers to an existing parsed
        frame gains ``'offset'`` and ``'time'`` keys taken from that frame.
    """
    if threshold is None:
        threshold = 0.4
    images, offsets, times, _frames = process_sprite_frames(image, vtt)
    tags = predictor.predict_tags(images, threshold)

    # Attach position info only for indices valid in BOTH parallel lists.
    # The lower bound matters: a negative index would otherwise wrap around
    # and silently report the last frame's offset/time.
    n_frames = min(len(offsets), len(times))
    for tag_info in tags.values():
        frame_idx = tag_info['frame']
        if 0 <= frame_idx < n_frames:
            tag_info['offset'] = offsets[frame_idx]
            tag_info['time'] = times[frame_idx]
    return tags
def predict_markers(image, vtt, threshold=0.4):
    """Run marker prediction on a sprite sheet.

    Args:
        image: Sprite image as delivered by the ``gr.Image`` input.
        vtt: VTT text; split together with ``image`` into frames by
            ``process_sprite_frames``.
        threshold: Cut-off forwarded to ``predictor.predict_markers``. ``None``
            (what a cleared ``gr.Number`` field sends) falls back to 0.4.

    Returns:
        Whatever ``predictor.predict_markers`` returns for the parsed frames,
        their offsets and their times.
    """
    if threshold is None:
        threshold = 0.4
    images, offsets, times, _frames = process_sprite_frames(image, vtt)
    return predictor.predict_markers(images, offsets, times, threshold)
# Gradio Blocks UI: two tabs sharing the same layout. The left column holds
# the inputs (sprite image, VTT text, threshold, submit button). The right
# column shows the JSON result from the matching handler.
with gr.Blocks(title="Stashtag ONNX") as demo:
    gr.Markdown("# Stashtag - Video Frame Tagging (ONNX)")

    # ---- Tag Mode: results from predict_tags --------------------------------
    with gr.Tab("Tag Mode"), gr.Row():
        with gr.Column():
            tag_image = gr.Image(label="Sprite Image")
            tag_vtt = gr.Textbox(label="VTT file", lines=3)
            tag_threshold = gr.Number(value=0.4, label="Threshold")
            tag_submit = gr.Button("Predict Tags", variant="primary")
        with gr.Column():
            tag_output = gr.JSON(label="Results")
    tag_submit.click(
        fn=predict_tags,
        inputs=[tag_image, tag_vtt, tag_threshold],
        outputs=[tag_output],
    )

    # ---- Marker Mode: results from predict_markers --------------------------
    with gr.Tab("Marker Mode"), gr.Row():
        with gr.Column():
            marker_image = gr.Image(label="Sprite Image")
            marker_vtt = gr.Textbox(label="VTT file", lines=3)
            marker_threshold = gr.Number(value=0.4, label="Threshold")
            marker_submit = gr.Button("Predict Markers", variant="primary")
        with gr.Column():
            marker_output = gr.JSON(label="Results")
    marker_submit.click(
        fn=predict_markers,
        inputs=[marker_image, marker_vtt, marker_threshold],
        outputs=[marker_output],
    )
if __name__ == "__main__":
    # Bind to 0.0.0.0 (all network interfaces) rather than loopback so the UI
    # is reachable from other hosts. This is presumably for container/hosted
    # deployment; confirm before running on an exposed machine.
    demo.launch(server_name="0.0.0.0")