Spaces:
Paused
Paused
File size: 2,855 Bytes
bd1f2b1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
import argparse
import yaml
import torch
import os
import sys
from pathlib import Path
import logging
# Configure root logging once at import time so all logger.info calls below emit.
logging.basicConfig(level=logging.INFO)
# Module-level logger, stdlib convention: named after this module.
logger = logging.getLogger(__name__)
def parse_args():
    """Parse command-line arguments for OmniAvatar-14B inference.

    Returns:
        argparse.Namespace with config, input_file, guidance_scale,
        audio_scale, num_steps, sp_size, and tea_cache_l1_thresh.
    """
    arg_parser = argparse.ArgumentParser(description="OmniAvatar-14B Inference")
    # Required path arguments.
    for flag, help_text in (
        ("--config", "Path to config file"),
        ("--input_file", "Path to input samples file"),
    ):
        arg_parser.add_argument(flag, type=str, required=True, help=help_text)
    # Optional tuning knobs: (flag, type, default, help).
    for flag, arg_type, default, help_text in (
        ("--guidance_scale", float, 5.0, "Guidance scale"),
        ("--audio_scale", float, 3.0, "Audio guidance scale"),
        ("--num_steps", int, 30, "Number of inference steps"),
        ("--sp_size", int, 1, "Multi-GPU size"),
        ("--tea_cache_l1_thresh", float, None, "TeaCache threshold"),
    ):
        arg_parser.add_argument(flag, type=arg_type, default=default, help=help_text)
    return arg_parser.parse_args()
def load_config(config_path):
    """Load a YAML configuration file.

    Args:
        config_path: Path to the YAML config file.

    Returns:
        The parsed configuration (typically a dict).

    Raises:
        FileNotFoundError: if the file does not exist.
        yaml.YAMLError: if the file is not valid YAML.
    """
    # Explicit encoding avoids platform-dependent default codecs mangling
    # UTF-8 configs; safe_load refuses to construct arbitrary Python objects.
    with open(config_path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)
def process_input_file(input_file):
    """Parse input file with format: prompt@@image_path@@audio_path

    Blank lines are ignored. Lines with fewer than three '@@'-separated
    fields are skipped with a warning (previously they were dropped
    silently, hiding bad input files).

    Args:
        input_file: Path to the samples file.

    Returns:
        list[dict]: one dict per valid line with keys 'prompt',
        'image_path' (None when the field is empty), and 'audio_path'.
    """
    samples = []
    with open(input_file, 'r', encoding='utf-8') as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            parts = line.split('@@')
            if len(parts) < 3:
                # Surface malformed lines so bad inputs are not silently lost.
                logging.getLogger(__name__).warning(
                    "Skipping malformed line %d in %s: %r", lineno, input_file, line)
                continue
            samples.append({
                'prompt': parts[0],
                'image_path': parts[1] if parts[1] else None,
                'audio_path': parts[2],
            })
    return samples
def main():
    """Entry point: load config, parse input samples, prepare the output
    directory, and log the (placeholder) inference plan."""
    args = parse_args()

    # Load configuration
    config = load_config(args.config)

    # Process input samples
    samples = process_input_file(args.input_file)
    logger.info(f"Processing {len(samples)} samples")

    # Create output directory. parents=True so a nested configured path
    # (e.g. "outputs/run1") does not raise FileNotFoundError.
    output_dir = Path(config['output']['output_dir'])
    output_dir.mkdir(parents=True, exist_ok=True)

    # This is a placeholder - actual inference would require the OmniAvatar model implementation
    logger.info("Note: This is a placeholder inference script.")
    logger.info("Actual implementation would require:")
    logger.info("1. Loading the OmniAvatar model")
    logger.info("2. Processing audio with wav2vec2")
    logger.info("3. Running video generation pipeline")
    logger.info("4. Saving output videos")

    for i, sample in enumerate(samples):
        logger.info(f"Sample {i+1}: {sample['prompt']}")
        logger.info(f"  Audio: {sample['audio_path']}")
        logger.info(f"  Image: {sample['image_path']}")

    logger.info("Inference completed successfully!")


if __name__ == "__main__":
    main()
|