darshankr committed on
Commit
994d199
·
verified ·
1 Parent(s): 6a6e465

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +58 -26
inference.py CHANGED
@@ -7,6 +7,7 @@ from glob import glob
7
  import torch, face_detection
8
  from models import Wav2Lip
9
  import platform
 
10
 
11
  parser = argparse.ArgumentParser(description='Inference code to lip-sync videos in the wild using Wav2Lip models')
12
 
@@ -178,41 +179,72 @@ def load_model(path):
178
  model = model.to(device)
179
  return model.eval()
180
 
181
- def main():
182
- if not os.path.isfile(args.face):
183
- raise ValueError('--face argument must be a valid path to video/image file')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
 
185
- elif args.face.split('.')[1] in ['jpg', 'png', 'jpeg']:
186
- full_frames = [cv2.imread(args.face)]
187
- fps = args.fps
 
 
188
 
189
- else:
190
- video_stream = cv2.VideoCapture(args.face)
191
- fps = video_stream.get(cv2.CAP_PROP_FPS)
 
192
 
193
- print('Reading video frames...')
 
194
 
195
- full_frames = []
196
- while 1:
197
- still_reading, frame = video_stream.read()
198
- if not still_reading:
199
- video_stream.release()
200
- break
201
- if args.resize_factor > 1:
202
- frame = cv2.resize(frame, (frame.shape[1]//args.resize_factor, frame.shape[0]//args.resize_factor))
203
 
204
- if args.rotate:
205
- frame = cv2.rotate(frame, cv2.cv2.ROTATE_90_CLOCKWISE)
206
 
207
- y1, y2, x1, x2 = args.crop
208
- if x2 == -1: x2 = frame.shape[1]
209
- if y2 == -1: y2 = frame.shape[0]
210
 
211
- frame = frame[y1:y2, x1:x2]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
 
213
- full_frames.append(frame)
214
 
215
- print ("Number of frames available for inference: "+str(len(full_frames)))
216
 
217
  if not args.audio.endswith('.wav'):
218
  print('Extracting raw audio...')
 
7
  import torch, face_detection
8
  from models import Wav2Lip
9
  import platform
10
+ import ffmpeg
11
 
12
  parser = argparse.ArgumentParser(description='Inference code to lip-sync videos in the wild using Wav2Lip models')
13
 
 
179
  model = model.to(device)
180
  return model.eval()
181
 
182
def convert_video_to_h264(input_path, output_path="converted_video.mp4"):
    """Convert AV1 or unsupported videos to H.264 using ffmpeg."""
    # Re-encode through the ffmpeg CLI (argument list, shell=False) so that
    # OpenCV can decode the result with its bundled H.264 support.
    cmd = ["ffmpeg", "-y", "-i", input_path, "-c:v", "libx264", "-c:a", "aac", output_path]
    print(f"Converting {input_path} to {output_path}...")
    try:
        subprocess.run(cmd, check=True)
    except subprocess.CalledProcessError as e:
        # Surface the failure to the caller; a partial/absent output file
        # would only fail more confusingly downstream.
        print(f"Error during video conversion: {e}")
        raise
    print("Conversion successful.")
    return output_path
195
def load_video_frames(video_path):
    """Load and preprocess all frames of a video.

    Applies the module-global CLI options on every frame: downscaling by
    ``args.resize_factor``, optional 90-degree rotation when ``args.rotate``
    is set, and cropping to ``args.crop`` (``-1`` in x2/y2 means "to the
    frame edge").

    Args:
        video_path: Path to a video file readable by OpenCV.

    Returns:
        Tuple ``(full_frames, fps)`` — the list of preprocessed BGR frames
        and the FPS reported by the stream.

    Raises:
        ValueError: If OpenCV cannot open ``video_path``.
    """
    video_stream = cv2.VideoCapture(video_path)
    if not video_stream.isOpened():
        raise ValueError(f"Could not open video: {video_path}")

    fps = video_stream.get(cv2.CAP_PROP_FPS)
    print(f"Reading video frames at {fps} FPS...")

    full_frames = []
    try:
        while True:
            still_reading, frame = video_stream.read()
            if not still_reading:
                break

            if args.resize_factor > 1:
                frame = cv2.resize(
                    frame, (frame.shape[1] // args.resize_factor, frame.shape[0] // args.resize_factor)
                )

            if args.rotate:
                frame = cv2.rotate(frame, cv2.ROTATE_90_CLOCKWISE)

            # Crop after resize/rotate, so crop coordinates refer to the
            # transformed frame; -1 expands to the full width/height.
            y1, y2, x1, x2 = args.crop
            if x2 == -1: x2 = frame.shape[1]
            if y2 == -1: y2 = frame.shape[0]

            frame = frame[y1:y2, x1:x2]
            full_frames.append(frame)
    finally:
        # Fix: the original released the capture only on the clean
        # end-of-stream path; an exception while decoding or cropping
        # would leak the VideoCapture handle.
        video_stream.release()

    return full_frames, fps
 
227
 
228
+ def main():
229
+ if not os.path.isfile(args.face):
230
+ raise ValueError("--face argument must be a valid path to video/image file")
231
+
232
+ if args.face.split('.')[-1] in ['jpg', 'png', 'jpeg']:
233
+ full_frames = [cv2.imread(args.face)]
234
+ fps = args.fps
235
+
236
+ else:
237
+ # Try loading the video with OpenCV first
238
+ video_path = args.face
239
+ video_stream = cv2.VideoCapture(video_path)
240
+
241
+ if not video_stream.isOpened():
242
+ print("OpenCV failed to open video. Attempting ffmpeg conversion...")
243
+ video_path = convert_video_to_h264(args.face)
244
 
245
+ full_frames, fps = load_video_frames(video_path)
246
 
247
+ print(f"Loaded {len(full_frames)} frames.")
248
 
249
  if not args.audio.endswith('.wav'):
250
  print('Extracting raw audio...')