LimaRaed commited on
Commit
2f7fdab
·
verified ·
1 Parent(s): 96c9fdb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -31
app.py CHANGED
@@ -9,9 +9,8 @@ warnings.filterwarnings("ignore")
9
 
10
  # Set to use CPU
11
  torch_device = "cpu"
12
- torch_dtype = torch.float32 # Use float32 for CPU stability
13
 
14
- # Load a lightweight model
15
  def load_model():
16
  model_id = "damo-vilab/text-to-video-ms-1.7b"
17
  pipe = DiffusionPipeline.from_pretrained(
@@ -19,77 +18,73 @@ def load_model():
19
  torch_dtype=torch_dtype
20
  )
21
  pipe = pipe.to(torch_device)
22
- pipe.enable_attention_slicing() # Reduce memory usage
23
  return pipe
24
 
25
  def generate_video(prompt, num_frames=8, num_inference_steps=20):
26
  start_time = time.time()
27
 
28
- # Load model with caching
29
  if not hasattr(generate_video, "pipe"):
30
  generate_video.pipe = load_model()
31
 
32
- # Generate with lower resolution and fewer frames for CPU
33
  with torch.no_grad():
34
  output = generate_video.pipe(
35
  prompt,
36
- num_frames=min(num_frames, 8), # Keep frames low for CPU
37
  num_inference_steps=min(num_inference_steps, 20),
38
- height=256, # Lower resolution
39
  width=256
40
  )
41
 
42
- # Convert numpy arrays to PIL Images
43
- frames = [Image.fromarray((frame * 255).astype(np.uint8)) for frame in output.frames]
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
  # Create GIF
46
  gif_path = "output.gif"
47
- duration = max(1000 // 3, 100) # Minimum 100ms per frame
48
  frames[0].save(
49
  gif_path,
50
  save_all=True,
51
  append_images=frames[1:],
52
- duration=duration,
53
  loop=0,
54
- save_format='GIF'
55
  )
56
 
57
- gen_time = time.time() - start_time
58
- print(f"Generation took {gen_time:.2f} seconds")
59
  return gif_path
60
 
61
  # Gradio Interface
62
  with gr.Blocks(title="CPU Text-to-Video") as demo:
63
  gr.Markdown("# 🐢 CPU Text-to-Video Generator")
64
- gr.Markdown("This version runs entirely on CPU - generations will be slower and lower quality")
65
 
66
  with gr.Row():
67
  with gr.Column():
68
- prompt = gr.Textbox(label="Prompt", placeholder="A fish swimming in space")
69
  with gr.Accordion("Advanced Options", open=False):
70
  frames = gr.Slider(4, 12, value=8, step=4, label="Frames")
71
  steps = gr.Slider(10, 30, value=20, step=5, label="Steps")
72
- submit = gr.Button("Generate", variant="primary")
73
 
74
  with gr.Column():
75
  output = gr.Image(label="Result", format="gif")
76
- gr.Markdown("Note: On CPU, generation may take 5-15 minutes")
77
-
78
- examples = gr.Examples(
79
- examples=[
80
- ["A paper boat floating on water"],
81
- ["A sloth wearing sunglasses"],
82
- ["A candle flame in the wind"]
83
- ],
84
- inputs=prompt,
85
- label="Try these examples"
86
- )
87
 
88
  submit.click(
89
  fn=generate_video,
90
  inputs=[prompt, frames, steps],
91
- outputs=output,
92
- api_name="generate"
93
  )
94
 
95
- demo.launch(show_api=False)
 
9
 
10
# Force CPU execution; float32 avoids half-precision instability on CPU.
torch_device = "cpu"
torch_dtype = torch.float32

def load_model():
    """Build the text-to-video diffusion pipeline configured for CPU.

    Returns:
        The DiffusionPipeline moved to the CPU with attention slicing
        enabled (slicing trades a little speed for lower peak memory).
    """
    pipeline = DiffusionPipeline.from_pretrained(
        "damo-vilab/text-to-video-ms-1.7b",
        torch_dtype=torch_dtype,
    )
    pipeline = pipeline.to(torch_device)
    # NOTE(review): enable_attention_slicing mutates in place and returns
    # None, so it must not be chained onto the .to(...) call.
    pipeline.enable_attention_slicing()
    return pipeline
23
 
24
def to_pil_frames(video_frames):
    """Normalize pipeline output frames into a list of PIL images.

    Accepts either a numpy array — 5D (batch, frames, H, W, C) or
    4D (frames, H, W, C), float in [0, 1] or already uint8 — or a
    (possibly nested, per-batch) list of PIL images, which newer
    diffusers versions return.

    Raises:
        ValueError: if the frame container is of an unexpected type.
    """
    if isinstance(video_frames, np.ndarray):
        if video_frames.ndim == 5:
            video_frames = video_frames[0]  # drop the batch dimension
        if video_frames.dtype != np.uint8:
            # Clip before casting: float values slightly outside [0, 1]
            # would otherwise wrap around when cast to uint8.
            video_frames = (video_frames * 255).clip(0, 255).astype(np.uint8)
        return [Image.fromarray(frame) for frame in video_frames]
    if isinstance(video_frames, (list, tuple)):
        # Newer diffusers return a list (one entry per batch item) of
        # per-frame PIL images; unwrap one nesting level if present.
        first = video_frames[0]
        inner = first if isinstance(first, (list, tuple)) else video_frames
        return [
            f if isinstance(f, Image.Image) else Image.fromarray(np.asarray(f))
            for f in inner
        ]
    raise ValueError("Unexpected frame format")

def generate_video(prompt, num_frames=8, num_inference_steps=20):
    """Generate a short animated GIF from a text prompt on the CPU.

    Args:
        prompt: Text description of the video to generate.
        num_frames: Requested frame count (capped at 8 for CPU speed).
        num_inference_steps: Requested diffusion steps (capped at 20).

    Returns:
        Path to the saved GIF file ("output.gif").
    """
    start_time = time.time()

    # Load the pipeline once and cache it on the function object so
    # repeated calls reuse the same weights.
    if not hasattr(generate_video, "pipe"):
        generate_video.pipe = load_model()

    # Low resolution and capped frames/steps keep CPU runtime tolerable.
    with torch.no_grad():
        output = generate_video.pipe(
            prompt,
            num_frames=min(num_frames, 8),
            num_inference_steps=min(num_inference_steps, 20),
            height=256,
            width=256,
        )

    frames = to_pil_frames(output.frames)

    # Create GIF. `quality` is a JPEG option the GIF writer ignores,
    # so it is deliberately omitted here.
    gif_path = "output.gif"
    frames[0].save(
        gif_path,
        save_all=True,
        append_images=frames[1:],
        duration=100,  # 100 ms per frame
        loop=0,
    )

    print(f"Generation took {time.time() - start_time:.2f} seconds")
    return gif_path
67
 
68
# Gradio Interface: two-column layout — inputs on the left, the rendered
# GIF on the right. Widget variable names are part of the module surface
# and are kept unchanged.
with gr.Blocks(title="CPU Text-to-Video") as demo:
    gr.Markdown("# 🐢 CPU Text-to-Video Generator")

    with gr.Row():
        # Left column: prompt plus tunable generation parameters.
        with gr.Column():
            prompt = gr.Textbox(label="Prompt")
            with gr.Accordion("Advanced Options", open=False):
                frames = gr.Slider(4, 12, value=8, step=4, label="Frames")
                steps = gr.Slider(10, 30, value=20, step=5, label="Steps")
            submit = gr.Button("Generate")

        # Right column: the generated GIF and a speed expectation note.
        with gr.Column():
            output = gr.Image(label="Result", format="gif")
            gr.Markdown("Note: CPU generation may take several minutes")

    # Wire the button to the generator; slider values map onto the
    # num_frames / num_inference_steps parameters in order.
    submit.click(
        fn=generate_video,
        inputs=[prompt, frames, steps],
        outputs=output,
    )

demo.launch()