# AI_Avatar_Chat/scripts/inference.py
# Author: bravedims
# Commit bd1f2b1: Deploy OmniAvatar-14B with ElevenLabs TTS integration to Hugging Face Spaces
import argparse
import yaml
import torch
import os
import sys
from pathlib import Path
import logging
# Configure root logging once at import time; module-level logger is the
# stdlib-recommended pattern (one logger per module, named after it).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def parse_args():
    """Define and parse the command-line interface for OmniAvatar-14B inference.

    Returns:
        argparse.Namespace with: config, input_file (both required),
        guidance_scale, audio_scale, num_steps, sp_size, tea_cache_l1_thresh.
    """
    ap = argparse.ArgumentParser(description="OmniAvatar-14B Inference")

    # Required inputs
    ap.add_argument("--config", type=str, required=True, help="Path to config file")
    ap.add_argument("--input_file", type=str, required=True, help="Path to input samples file")

    # Tunable generation knobs (defaults mirror the recommended settings)
    ap.add_argument("--guidance_scale", type=float, default=5.0, help="Guidance scale")
    ap.add_argument("--audio_scale", type=float, default=3.0, help="Audio guidance scale")
    ap.add_argument("--num_steps", type=int, default=30, help="Number of inference steps")
    ap.add_argument("--sp_size", type=int, default=1, help="Multi-GPU size")
    ap.add_argument("--tea_cache_l1_thresh", type=float, default=None, help="TeaCache threshold")

    return ap.parse_args()
def load_config(config_path):
    """Read the YAML configuration at *config_path* and return the parsed object."""
    with open(config_path, 'r') as fh:
        parsed = yaml.safe_load(fh)
    return parsed
def process_input_file(input_file):
    """Parse an inference samples file.

    Each non-empty line has the format ``prompt@@image_path@@audio_path``.
    An empty image_path field is mapped to None (audio-only sample).

    Args:
        input_file: Path to the samples text file.

    Returns:
        List of dicts with keys 'prompt', 'image_path', 'audio_path'.
    """
    log = logging.getLogger(__name__)
    samples = []
    with open(input_file, 'r') as f:
        for line_no, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            parts = line.split('@@')
            if len(parts) < 3:
                # Previously malformed lines were dropped silently; warn so
                # bad input is visible to the operator instead of vanishing.
                log.warning("Skipping malformed line %d in %s: %r",
                            line_no, input_file, line)
                continue
            samples.append({
                'prompt': parts[0],
                'image_path': parts[1] if parts[1] else None,
                'audio_path': parts[2],
            })
    return samples
def main():
    """Entry point: load config, parse the samples manifest, and log the
    (placeholder) inference pipeline steps for each sample."""
    args = parse_args()

    # Load run configuration from YAML
    config = load_config(args.config)

    # Parse the samples manifest (prompt@@image@@audio per line)
    samples = process_input_file(args.input_file)
    logger.info(f"Processing {len(samples)} samples")

    # Create output directory. parents=True so a nested path from the config
    # (e.g. "outputs/run1") does not raise FileNotFoundError.
    output_dir = Path(config['output']['output_dir'])
    output_dir.mkdir(parents=True, exist_ok=True)

    # This is a placeholder - actual inference would require the OmniAvatar model implementation
    logger.info("Note: This is a placeholder inference script.")
    logger.info("Actual implementation would require:")
    logger.info("1. Loading the OmniAvatar model")
    logger.info("2. Processing audio with wav2vec2")
    logger.info("3. Running video generation pipeline")
    logger.info("4. Saving output videos")

    for i, sample in enumerate(samples):
        logger.info(f"Sample {i+1}: {sample['prompt']}")
        logger.info(f"  Audio: {sample['audio_path']}")
        logger.info(f"  Image: {sample['image_path']}")

    logger.info("Inference completed successfully!")

if __name__ == "__main__":
    main()