arif670 commited on
Commit
7c9e1db
·
verified ·
1 Parent(s): 16d1003

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -21
app.py CHANGED
@@ -1,4 +1,3 @@
1
- # Environment config at VERY TOP
2
  import os
3
  os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
4
  os.environ['FONTCONFIG_PATH'] = '/tmp/fontconfig'
@@ -10,26 +9,24 @@ import logging
10
  from models import load_models
11
  from video_generator import generate_video_pipeline
12
 
13
- # Configure logging
14
  logging.basicConfig(level=logging.INFO)
15
  logger = logging.getLogger(__name__)
16
 
17
  # Hardware configuration
18
- if not torch.cuda.is_available():
19
- logger.warning("Using CPU-only mode")
20
- torch.set_num_threads(4)
21
 
22
  # Load models
23
  try:
24
  text_to_image, image_to_video, tts_model = load_models()
25
  except Exception as e:
26
- logger.error(f"Model load failed: {str(e)}")
27
  raise
28
 
29
  def generate_video(prompt, duration=5, fps=24):
30
  with tempfile.TemporaryDirectory() as tmpdir:
31
  try:
32
- video_path = generate_video_pipeline(
33
  prompt=prompt,
34
  text_to_image_model=text_to_image,
35
  image_to_video_model=image_to_video,
@@ -38,36 +35,31 @@ def generate_video(prompt, duration=5, fps=24):
38
  duration=duration,
39
  fps=fps
40
  )
41
- return video_path
42
  except Exception as e:
43
  logger.error(f"Generation failed: {str(e)}")
44
- raise gr.Error(f"Video creation failed: {str(e)}")
45
 
46
- # Gradio interface
47
  with gr.Blocks(title="AI Video Generator") as app:
48
  gr.Markdown("# 🎥 AI Video Generator")
49
 
50
  with gr.Row():
51
- prompt_input = gr.Textbox(label="Input Prompt",
52
- placeholder="A cat walking through a forest...")
53
 
54
  with gr.Row():
55
- duration = gr.Slider(2, 60, value=5, label="Duration (seconds)")
56
- fps = gr.Slider(12, 60, value=24, step=1, label="FPS")
57
 
58
- generate_btn = gr.Button("Generate Video", variant="primary")
59
 
60
  with gr.Row():
61
  video_output = gr.Video(label="Result", format="mp4")
62
- download_btn = gr.File(label="Download Video", type="file", interactive=False)
63
-
64
- progress = gr.Progress()
65
 
66
  generate_btn.click(
67
- fn=generate_video,
68
  inputs=[prompt_input, duration, fps],
69
- outputs=[video_output, download_btn],
70
  )
71
 
72
  if __name__ == "__main__":
73
- app.launch(server_name="0.0.0.0", share=False)
 
 
1
  import os
2
  os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
3
  os.environ['FONTCONFIG_PATH'] = '/tmp/fontconfig'
 
9
  from models import load_models
10
  from video_generator import generate_video_pipeline
11
 
 
12
  logging.basicConfig(level=logging.INFO)
13
  logger = logging.getLogger(__name__)
14
 
15
  # Hardware configuration
16
+ device = "cuda" if torch.cuda.is_available() else "cpu"
17
+ logger.info(f"Running on {device.upper()}")
 
18
 
19
  # Load models
20
  try:
21
  text_to_image, image_to_video, tts_model = load_models()
22
  except Exception as e:
23
+ logger.error(f"Initialization failed: {str(e)}")
24
  raise
25
 
26
  def generate_video(prompt, duration=5, fps=24):
27
  with tempfile.TemporaryDirectory() as tmpdir:
28
  try:
29
+ return generate_video_pipeline(
30
  prompt=prompt,
31
  text_to_image_model=text_to_image,
32
  image_to_video_model=image_to_video,
 
35
  duration=duration,
36
  fps=fps
37
  )
 
38
  except Exception as e:
39
  logger.error(f"Generation failed: {str(e)}")
40
+ raise gr.Error(f"Error: {str(e)}")
41
 
 
42
  with gr.Blocks(title="AI Video Generator") as app:
43
  gr.Markdown("# 🎥 AI Video Generator")
44
 
45
  with gr.Row():
46
+ prompt_input = gr.Textbox(label="Prompt", placeholder="A cat in space...")
 
47
 
48
  with gr.Row():
49
+ duration = gr.Slider(2, 30, 5, label="Duration (s)")
50
+ fps = gr.Slider(12, 60, 24, label="FPS")
51
 
52
+ generate_btn = gr.Button("Generate", variant="primary")
53
 
54
  with gr.Row():
55
  video_output = gr.Video(label="Result", format="mp4")
56
+ download_btn = gr.File(label="Download", type="file")
 
 
57
 
58
  generate_btn.click(
59
+ generate_video,
60
  inputs=[prompt_input, duration, fps],
61
+ outputs=[video_output, download_btn]
62
  )
63
 
64
  if __name__ == "__main__":
65
+ app.launch(server_name="0.0.0.0")