aiwithshweta commited on
Commit
bcea283
·
verified ·
1 Parent(s): 3557b0a
Files changed (1) hide show
  1. app.py +22 -13
app.py CHANGED
@@ -5,10 +5,9 @@ import gradio as gr
5
  import zipfile
6
 
7
  # ------------------------------------------------------
8
- # 1) DOWNLOAD SadTalker ZIP FROM GOOGLE DRIVE
9
  # ------------------------------------------------------
10
 
11
- # ← Put your ZIP FILE ID here
12
  SADTALKER_ZIP_ID = "1ERCCvqt2YTNiwfqY1N95aQIIBm3maApu"
13
 
14
  if not os.path.exists("SadTalker"):
@@ -19,17 +18,15 @@ if not os.path.exists("SadTalker"):
19
  with zipfile.ZipFile("sad.zip", 'r') as zip_ref:
20
  zip_ref.extractall("./")
21
 
22
- # Detect extracted folder name
23
  if os.path.exists("SadTalker-main"):
24
  os.rename("SadTalker-main", "SadTalker")
25
- elif os.path.exists("SadTalker"):
26
- pass
27
 
28
  print("✔ SadTalker ready!")
29
  os.chdir("SadTalker")
30
 
31
  # ------------------------------------------------------
32
- # 2) DOWNLOAD CHECKPOINT MODELS (YOUR OLD CODE)
33
  # ------------------------------------------------------
34
 
35
  def gd(file_id, out):
@@ -42,7 +39,6 @@ def gd(file_id, out):
42
 
43
  os.makedirs("checkpoints", exist_ok=True)
44
 
45
- # Your original model IDs (replace with yours)
46
  gd("1g3VIpU3yhpITMZtrWU2mbmyzLoF9K3Gz", "checkpoints/audio2exp_00300-model.pth")
47
  gd("1Jp4i_Qc-6qCms7v1kN61RE3qnT_9J8vg", "checkpoints/audio2pose_00140-model.pth")
48
  gd("1pTbKmpOeWRSA1NYQ3DnWe3W-9Qeyy7PK", "checkpoints/mapping_00109-model.pth.tar")
@@ -54,7 +50,7 @@ gd("1cag0u7e5RgdKoxBa7exY599VNGtYgNQb", "checkpoints/shape_predictor_68_face_lan
54
  print("✔ All checkpoints ready!")
55
 
56
  # ------------------------------------------------------
57
- # 3) GENERATE VIDEO (Auto-detect MP4)
58
  # ------------------------------------------------------
59
 
60
  def generate_video(image, audio):
@@ -69,12 +65,13 @@ def generate_video(image, audio):
69
  "--size", "256",
70
  "--enhancer", "None",
71
  "--expression_scale", "1.8",
72
- "--preprocess", "crop"
73
  ]
74
 
 
75
  subprocess.run(cmd)
76
 
77
- # find latest mp4
78
  mp4s = []
79
  for root, dirs, files in os.walk("results"):
80
  for f in files:
@@ -84,12 +81,24 @@ def generate_video(image, audio):
84
  if not mp4s:
85
  return "❌ No video generated."
86
 
 
87
  mp4s.sort(key=lambda x: os.path.getmtime(x), reverse=True)
88
  latest = mp4s[0]
89
 
90
- import shutil
91
- shutil.copy(latest, "output.mp4")
92
 
93
- return "output.mp4"
94
 
 
 
 
 
 
 
 
 
 
 
 
95
 
 
 
5
  import zipfile
6
 
7
  # ------------------------------------------------------
8
+ # 1) DOWNLOAD SadTalker ZIP FROM GOOGLE DRIVE (ONLY ONCE)
9
  # ------------------------------------------------------
10
 
 
11
  SADTALKER_ZIP_ID = "1ERCCvqt2YTNiwfqY1N95aQIIBm3maApu"
12
 
13
  if not os.path.exists("SadTalker"):
 
18
  with zipfile.ZipFile("sad.zip", 'r') as zip_ref:
19
  zip_ref.extractall("./")
20
 
21
+ # Rename folder if extracted as SadTalker-main
22
  if os.path.exists("SadTalker-main"):
23
  os.rename("SadTalker-main", "SadTalker")
 
 
24
 
25
  print("✔ SadTalker ready!")
26
  os.chdir("SadTalker")
27
 
28
  # ------------------------------------------------------
29
+ # 2) DOWNLOAD CHECKPOINT MODELS (ONLY IF MISSING)
30
  # ------------------------------------------------------
31
 
32
  def gd(file_id, out):
 
39
 
40
  os.makedirs("checkpoints", exist_ok=True)
41
 
 
42
  gd("1g3VIpU3yhpITMZtrWU2mbmyzLoF9K3Gz", "checkpoints/audio2exp_00300-model.pth")
43
  gd("1Jp4i_Qc-6qCms7v1kN61RE3qnT_9J8vg", "checkpoints/audio2pose_00140-model.pth")
44
  gd("1pTbKmpOeWRSA1NYQ3DnWe3W-9Qeyy7PK", "checkpoints/mapping_00109-model.pth.tar")
 
50
  print("✔ All checkpoints ready!")
51
 
52
  # ------------------------------------------------------
53
+ # 3) GENERATE VIDEO (FAST + ALWAYS RETURNS OUTPUT)
54
  # ------------------------------------------------------
55
 
56
  def generate_video(image, audio):
 
65
  "--size", "256",
66
  "--enhancer", "None",
67
  "--expression_scale", "1.8",
68
+ "--preprocess", "full"
69
  ]
70
 
71
+ print("▶ Running inference...")
72
  subprocess.run(cmd)
73
 
74
+ # Find all mp4 in results
75
  mp4s = []
76
  for root, dirs, files in os.walk("results"):
77
  for f in files:
 
81
  if not mp4s:
82
  return "❌ No video generated."
83
 
84
+ # Pick latest file
85
  mp4s.sort(key=lambda x: os.path.getmtime(x), reverse=True)
86
  latest = mp4s[0]
87
 
88
+ print("🎬 Video generated:", latest)
89
+ return latest
90
 
 
91
 
92
+ # ------------------------------------------------------
93
+ # 4) GRADIO UI
94
+ # ------------------------------------------------------
95
+
96
+ demo = gr.Interface(
97
+ fn=generate_video,
98
+ inputs=[gr.Image(type="pil"), gr.Audio(type="filepath")],
99
+ outputs=gr.Video(),
100
+ title="SadTalker (Google Drive Version)",
101
+ description="Fast loading + no duplicate downloads + auto video return"
102
+ )
103
 
104
+ demo.launch()