Spaces:
Running
Running
| from stashtag.models.onnx_predictor import ONNXStashtagPredictor | |
| from stashtag.utils.vtt_parser import process_sprite_frames | |
| import gradio as gr | |
# Construct the ONNX-backed predictor once, at module import time; both
# prediction handlers in this file call methods on this shared instance.
predictor: ONNXStashtagPredictor = ONNXStashtagPredictor()
def predict_tags(image, vtt, threshold=0.4):
    """Run tag prediction over the frames of a sprite sheet.

    Args:
        image: Sprite sheet image as delivered by the ``gr.Image`` input.
            Gradio passes ``None`` when nothing has been uploaded.
        vtt: Text of the VTT file (from the ``gr.Textbox`` input) that is
            handed to ``process_sprite_frames`` together with the image.
        threshold: Score threshold forwarded to ``predictor.predict_tags``.
            ``None`` (e.g. an emptied ``gr.Number`` field) falls back to 0.4.

    Returns:
        The mapping returned by ``predictor.predict_tags``. Each entry whose
        ``"frame"`` index refers to a parsed frame also gets ``"offset"`` and
        ``"time"`` keys copied from the parsed sprite data.

    Raises:
        gr.Error: If the sprite image or the VTT text is missing, so the UI
            shows a readable message instead of an internal traceback.
    """
    if image is None:
        raise gr.Error("Please upload a sprite image.")
    if not vtt or not vtt.strip():
        raise gr.Error("Please provide the VTT file contents.")
    if threshold is None:
        threshold = 0.4

    images, offsets, times, _frames = process_sprite_frames(image, vtt)
    tags = predictor.predict_tags(images, threshold)

    # Attach the source frame's offset/time to each tag. Bound the index by
    # BOTH lists, and reject negative indices explicitly: Python would
    # otherwise wrap them to the end of the list and attach the wrong
    # frame's data.
    frame_count = min(len(offsets), len(times))
    for tag_info in tags.values():
        frame_idx = tag_info["frame"]
        if 0 <= frame_idx < frame_count:
            tag_info["offset"] = offsets[frame_idx]
            tag_info["time"] = times[frame_idx]
    return tags
def predict_markers(image, vtt, threshold=0.4):
    """Run marker prediction over the frames of a sprite sheet.

    Args:
        image: Sprite sheet image as delivered by the ``gr.Image`` input.
            Gradio passes ``None`` when nothing has been uploaded.
        vtt: Text of the VTT file (from the ``gr.Textbox`` input) that is
            handed to ``process_sprite_frames`` together with the image.
        threshold: Score threshold forwarded to ``predictor.predict_markers``.
            ``None`` (e.g. an emptied ``gr.Number`` field) falls back to 0.4.

    Returns:
        Whatever ``predictor.predict_markers`` returns for the parsed frames,
        their offsets and their times.

    Raises:
        gr.Error: If the sprite image or the VTT text is missing, so the UI
            shows a readable message instead of an internal traceback.
    """
    if image is None:
        raise gr.Error("Please upload a sprite image.")
    if not vtt or not vtt.strip():
        raise gr.Error("Please provide the VTT file contents.")
    if threshold is None:
        threshold = 0.4

    images, offsets, times, _frames = process_sprite_frames(image, vtt)
    return predictor.predict_markers(images, offsets, times, threshold)
def _build_prediction_tab(tab_label, button_label, handler):
    """Lay out one prediction tab and wire its button to *handler*.

    Must be called while a ``gr.Blocks`` context is active. The tab has an
    input column (sprite image, VTT text, threshold, submit button) beside
    a JSON results column. The button sends the three inputs to *handler*
    and shows its return value in the results panel.

    Returns:
        The tab's components as ``(image, vtt, threshold, submit, output)``.
    """
    with gr.Tab(tab_label):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Sprite Image")
                vtt_input = gr.Textbox(label="VTT file", lines=3)
                threshold_input = gr.Number(value=0.4, label="Threshold")
                submit_button = gr.Button(button_label, variant="primary")
            with gr.Column():
                results_output = gr.JSON(label="Results")
        submit_button.click(
            fn=handler,
            inputs=[image_input, vtt_input, threshold_input],
            outputs=[results_output],
        )
    return image_input, vtt_input, threshold_input, submit_button, results_output


# Assemble the Gradio Blocks UI: a heading followed by one tab per
# prediction mode (tags, then markers).
with gr.Blocks(title="Stashtag ONNX") as demo:
    gr.Markdown("# Stashtag - Video Frame Tagging (ONNX)")

    (
        tag_image,
        tag_vtt,
        tag_threshold,
        tag_submit,
        tag_output,
    ) = _build_prediction_tab("Tag Mode", "Predict Tags", predict_tags)

    (
        marker_image,
        marker_vtt,
        marker_threshold,
        marker_submit,
        marker_output,
    ) = _build_prediction_tab("Marker Mode", "Predict Markers", predict_markers)
if __name__ == "__main__":
    # Listen on every network interface (0.0.0.0) rather than only
    # localhost, so the app is reachable from outside the host/container.
    launch_options = {"server_name": "0.0.0.0"}
    demo.launch(**launch_options)