Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from controlnet_aux import OpenposeDetector | |
| import os | |
| import cv2 | |
| import numpy as np | |
| from PIL import Image | |
| from moviepy.editor import * | |
| # Load the OpenPose detector | |
| openpose = OpenposeDetector.from_pretrained('lllyasviel/ControlNet') | |
def get_frames(video_in):
    """Extract every frame of a video as a JPEG file, resized to height 512px.

    The input is first re-encoded to ``video_resized.mp4`` (height 512, frame
    rate capped at 30 fps). That file is then read back with OpenCV and each
    frame is written to ``frame_<index>.jpg``. All files use fixed relative
    names, so they are overwritten on every call.

    Args:
        video_in: Path to the input video file.

    Returns:
        frames: List of paths to extracted frame images, in read order.
        fps: Frames per second reported by OpenCV for the resized video.
    """
    resized_path = "video_resized.mp4"

    clip = VideoFileClip(video_in)
    try:
        if clip.fps > 30:
            print("video rate is over 30, resetting to 30")
            target_fps = 30
        else:
            print("video rate is OK")
            target_fps = clip.fps
        clip.resize(height=512).write_videofile(resized_path, fps=target_fps)
    finally:
        # Release the reader held by the clip even if the re-encode fails;
        # the original never closed it.
        clip.close()

    frames = []
    cap = cv2.VideoCapture(resized_path)
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        # read() reports failure both at end of stream and when the capture
        # could not be opened, so no separate isOpened() check is needed.
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frame_path = f'frame_{len(frames)}.jpg'
            cv2.imwrite(frame_path, frame)
            frames.append(frame_path)
    finally:
        cap.release()
    # NOTE: cv2.destroyAllWindows() was dropped on purpose. No window is ever
    # created here, and the call can fail on OpenCV builds without GUI support.
    return frames, fps
def get_openpose_filter(i):
    """Run the OpenPose detector on one frame and save the resulting image.

    Args:
        i: Path to the image frame.

    Returns:
        Path of the saved result, named ``openpose_frame_<basename of i>.jpeg``
        (a relative path, so it lands in the current working directory).
    """
    pose_image = openpose(Image.open(i))
    frame_name = os.path.basename(i)
    output_path = f"openpose_frame_{frame_name}.jpeg"
    pose_image.save(output_path)
    return output_path
def create_video(frames, fps, type):
    """Create a video from a sequence of image frames.

    Args:
        frames: List of image frame paths, in playback order.
        fps: Frames per second for the output video.
        type: A string used as the prefix for naming the result video.
            (Shadows the builtin, but the name is kept so existing keyword
            callers keep working.)

    Returns:
        video_path: Path to the resulting video file, ``<type>_result.mp4``.

    Raises:
        ValueError: If ``frames`` is empty.
    """
    # Fail early with a clear message instead of an opaque error from the
    # clip constructor when no frame could be extracted from the input.
    if len(frames) == 0:
        raise ValueError("create_video() needs at least one frame")

    video_path = f"{type}_result.mp4"
    clip = ImageSequenceClip(frames, fps=fps)
    clip.write_videofile(video_path, fps=fps)
    return video_path
def convertG2V(imported_gif):
    """Convert a GIF file to a standard MP4 video.

    Args:
        imported_gif: The input GIF, either a file object exposing the path
            as ``.name`` or a plain path string. ``None`` is accepted and
            means "no file" (e.g. the upload widget was cleared).

    Returns:
        Path to the converted MP4 video file (``my_gif_video.mp4``), or
        ``None`` when no GIF was supplied.
    """
    # The change event also fires when the file is removed; there is nothing
    # to convert then, so return None rather than failing on `.name`.
    if imported_gif is None:
        return None

    # Accept both file-like objects with a `.name` path and bare path strings.
    gif_path = getattr(imported_gif, "name", imported_gif)
    output_path = "my_gif_video.mp4"

    clip = VideoFileClip(gif_path)
    try:
        clip.write_videofile(output_path)
    finally:
        # Release the reader held by the clip even if encoding fails.
        clip.close()
    return output_path
def infer(video_in):
    """Generate an OpenPose-filtered video from an input video.

    Pipeline:
        1. Split the input video into frame images and read its frame rate.
        2. Run the OpenPose filter on each frame, in order.
        3. Assemble the processed frames into a new video at that frame rate.

    Args:
        video_in: The uploaded input video file (MP4 or converted GIF).

    Returns:
        final_vid: The path to the OpenPose-filtered output video.
        files: A list containing the output video file (for download).
    """
    frame_paths, fps = get_frames(video_in)
    pose_frames = [get_openpose_filter(frame_path) for frame_path in frame_paths]
    final_vid = create_video(pose_frames, fps, "openpose")
    return final_vid, [final_vid]
# UI layout
# HTML header rendered at the top of the page.
title = """
<div style="text-align: center; max-width: 500px; margin: 0 auto;">
<div style="display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem; margin-bottom: 10px;">
<h1 style="font-weight: 600; margin-bottom: 7px;">Video to OpenPose</h1>
</div>
</div>
"""

# NOTE(review): the nesting below is reconstructed from statement order
# (inputs in one column, outputs in the other) - confirm against the
# deployed layout.
with gr.Blocks() as demo:
    with gr.Column():
        gr.HTML(title)
        with gr.Row():
            # Input side: a video upload, or a GIF that is converted to MP4
            # and fed into the video input.
            with gr.Column():
                video_input = gr.Video(sources=["upload"])
                gif_input = gr.File(label="Import a GIF instead", file_types=['.gif'])
                gif_input.change(fn=convertG2V, inputs=gif_input, outputs=video_input)
                submit_btn = gr.Button("Submit")
            # Output side: the processed video plus the same file for download.
            with gr.Column():
                video_output = gr.Video()
                file_output = gr.Files()
    submit_btn.click(fn=infer, inputs=[video_input], outputs=[video_output, file_output])

demo.launch(mcp_server=True)