rahul7star commited on
Commit
74fec90
·
verified ·
1 Parent(s): 2b513fa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -26
app.py CHANGED
@@ -5,6 +5,15 @@ import gradio as gr
5
  from huggingface_hub import snapshot_download
6
  from huggingface_hub import snapshot_download
7
  import spaces
 
 
 
 
 
 
 
 
 
8
 
9
  MODEL_SUBFOLDER = "Wan2.1-T2V-14B"
10
  HF_REPO = "RaphaelLiu/PusaV1"
@@ -45,40 +54,39 @@ def download_model_subset():
45
 
46
 
47
  @spaces.GPU
48
- def generate_video(prompt):
49
- download_model_subset()
50
-
51
- temp_output_dir = "/tmp/pusa_video_output"
52
- os.makedirs(temp_output_dir, exist_ok=True)
53
-
54
- command = [
55
- "python", PUSA_SCRIPT_PATH,
56
- "--prompt", prompt,
57
- "--lora_path", FINAL_MODEL_PATH,
58
- "--output_dir", temp_output_dir
59
- ]
60
-
61
  try:
62
- print("🚀 Running inference...")
63
- subprocess.run(command, check=True)
 
 
 
 
 
64
 
65
- # Return first mp4 video found
66
- for file in os.listdir(temp_output_dir):
67
- if file.endswith(".mp4"):
68
- return os.path.join(temp_output_dir, file)
69
 
70
- return "❌ No video generated."
 
 
 
71
 
72
- except subprocess.CalledProcessError as e:
73
- return f"❌ Inference failed: {str(e)}"
74
 
 
 
 
75
 
 
76
  with gr.Blocks() as demo:
77
- gr.Markdown("## πŸ§˜β€β™‚οΈ PusaV1 Text-to-Video Generator (Wan2.1-T2V-14B)")
78
- prompt_input = gr.Textbox(label="Enter your prompt", lines=4, placeholder="A coral reef full of colorful fish...")
79
- generate_button = gr.Button("Generate Video")
 
 
80
  video_output = gr.Video(label="Generated Video")
81
 
82
- generate_button.click(fn=generate_video, inputs=prompt_input, outputs=video_output)
83
 
84
  demo.launch()
 
5
  from huggingface_hub import snapshot_download
6
  from huggingface_hub import snapshot_download
7
  import spaces
8
+ # Add PusaV1 to path to resolve diffsynth imports
9
+ sys.path.append(os.path.abspath("PusaV1"))
10
+
11
+ # Import the actual model runner
12
+ from diffsynth import ModelManager, WanVideoPusaPipeline, save_video
13
+
14
+ # Define paths
15
+ WAN_MODEL_DIR = "./model_zoo/Wan2.1-T2V-14B"
16
+ LORA_PATH = "./model_zoo/PusaV1/pusa_v1.pt"
17
 
18
  MODEL_SUBFOLDER = "Wan2.1-T2V-14B"
19
  HF_REPO = "RaphaelLiu/PusaV1"
 
54
 
55
 
56
  @spaces.GPU
57
+ def generate_video(prompt: str):
 
 
 
 
 
 
 
 
 
 
 
 
58
  try:
59
+ # Load model manager
60
+ manager = ModelManager(base_model_dir=WAN_MODEL_DIR)
61
+ model = manager.load_model()
62
+
63
+ # Create video pipeline and apply LoRA
64
+ pipeline = WanVideoPusaPipeline(model=model)
65
+ pipeline.set_lora_adapters(LORA_PATH)
66
 
67
+ # Generate video
68
+ result = pipeline(prompt=prompt)
 
 
69
 
70
+ # Save video to a temporary file
71
+ tmp_dir = tempfile.mkdtemp()
72
+ video_path = os.path.join(tmp_dir, "output.mp4")
73
+ save_video(result, video_path)
74
 
75
+ return video_path
 
76
 
77
+ except Exception as e:
78
+ print(f"[ERROR] {e}")
79
+ return None
80
 
81
+ # Gradio UI
82
  with gr.Blocks() as demo:
83
+ gr.Markdown("## 🎥 PusaV1 Text-to-Video Generator")
84
+ gr.Markdown("Describe a scene and generate a short video using Wan2.1-T2V + Pusa LoRA!")
85
+
86
+ prompt_input = gr.Textbox(label="Enter Prompt", lines=4, placeholder="E.g. A coral reef full of colorful fish...")
87
+ generate_btn = gr.Button("Generate Video")
88
  video_output = gr.Video(label="Generated Video")
89
 
90
+ generate_btn.click(fn=generate_video, inputs=prompt_input, outputs=video_output)
91
 
92
  demo.launch()