# Janus-first2last / api_ddpm.py
# Hugging Face upload: commit 9e1916e ("Upload 14 files", verified); file size 2.67 kB.
"""CLI and helper API for the Hugging Face Janus-first2last release."""
import argparse
import os
import numpy as np
from transformers import AutoModel
def load_density(path, field="av_density"):
    """Load a density array from a ``.npz`` or ``.npy`` file as float32.

    Args:
        path: Input file location (``str`` or ``os.PathLike``). The file
            suffix selects the reader: ``.npz`` or ``.npy``.
        field: Name of the array to read from an ``.npz`` archive. Ignored
            for ``.npy`` input.

    Returns:
        The loaded array converted to ``np.float32``.

    Raises:
        ValueError: If ``path`` ends in neither ``.npz`` nor ``.npy``.
        KeyError: If ``field`` is not present in the ``.npz`` archive.
    """
    # Normalise so that pathlib.Path inputs work with the suffix checks below.
    path = os.fspath(path)
    if path.endswith(".npz"):
        # np.load returns an NpzFile that keeps the archive open; the context
        # manager closes it once the requested array has been copied out.
        with np.load(path) as archive:
            return archive[field].astype(np.float32)
    if path.endswith(".npy"):
        return np.load(path).astype(np.float32)
    raise ValueError(f"Unsupported input format: {path} (use .npz or .npy)")
def save_prediction(path, prediction, std=None):
    """Write a prediction (and optional spread) to a ``.npz`` or ``.npy`` file.

    For ``.npz`` output the archive always contains ``av_density``; when
    ``std`` is given it additionally contains ``mean`` (the same array as
    ``av_density``) and ``std``. For ``.npy`` output only ``prediction`` is
    written, because the format holds a single array.

    Args:
        path: Output file location (``str`` or ``os.PathLike``). Missing
            parent directories are created.
        prediction: Array to store.
        std: Optional array stored alongside the prediction in ``.npz``
            output. It cannot be stored in ``.npy`` output; a warning is
            issued in that case.

    Raises:
        ValueError: If ``path`` ends in neither ``.npz`` nor ``.npy``.
    """
    # Normalise so that pathlib.Path inputs work with the suffix checks below.
    path = os.fspath(path)
    as_npz = path.endswith(".npz")
    # Validate the format before touching the filesystem so a bad path does
    # not leave an empty output directory behind.
    if not as_npz and not path.endswith(".npy"):
        raise ValueError(f"Unsupported output format: {path} (use .npz or .npy)")

    output_dir = os.path.dirname(path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    if as_npz:
        payload = {"av_density": prediction}
        if std is not None:
            payload.update({"mean": prediction, "std": std})
        np.savez(path, **payload)
        return

    if std is not None:
        import warnings

        # A .npy file holds one array, so the spread would otherwise be
        # dropped without any notice to the caller.
        warnings.warn(
            "std is not stored in .npy output; use a .npz path to keep it",
            stacklevel=2,
        )
    np.save(path, prediction)
def load_model(model_id_or_path=".", device=None):
    """Load the model through ``AutoModel`` and return it in eval mode.

    Args:
        model_id_or_path: Value handed to ``AutoModel.from_pretrained``;
            defaults to the current directory.
        device: Optional target passed to ``model.to``. When ``None`` the
            model is left wherever ``from_pretrained`` placed it.

    Returns:
        The result of calling ``.eval()`` on the loaded model.

    Note:
        ``trust_remote_code=True`` is passed to ``from_pretrained``, so only
        point this at repositories you trust.
    """
    loaded = AutoModel.from_pretrained(model_id_or_path, trust_remote_code=True)
    placed = loaded if device is None else loaded.to(device)
    return placed.eval()
def main():
    """Command-line entry point: load a density file, run inference, save it.

    Parses the CLI arguments, loads the input array with ``load_density`` and
    the model with ``load_model``, then either

    * calls ``model.predict`` once (``--n_samples 1``, the default) and saves
      the single prediction, or
    * calls ``model.predict_ensemble`` and saves the returned mean and std.

    ``--seed`` is only forwarded on the single-sample path; the ensemble call
    does not receive it.

    Exits with argparse's usage error (status 2) when ``--n_samples`` is
    smaller than 1.
    """
    parser = argparse.ArgumentParser(description="Run Janus-first2last DDIM inference.")
    parser.add_argument("--model", default=".", help="HF repo id or local model directory")
    parser.add_argument("--input", required=True, help="Input .npz or .npy file")
    parser.add_argument("--output", required=True, help="Output .npz or .npy file")
    parser.add_argument("--field", default="av_density", help="Field name in input NPZ")
    parser.add_argument("--device", default=None, help="cuda, cpu, mps, or default auto device")
    parser.add_argument("--num_steps", type=int, default=50, help="DDIM sampling steps")
    parser.add_argument("--eta", type=float, default=1.0, help="DDIM stochasticity")
    parser.add_argument("--seed", type=int, default=None, help="Random seed for one sample")
    parser.add_argument("--n_samples", type=int, default=1, help="Number of ensemble samples")
    args = parser.parse_args()

    # Reject a meaningless sample count up front, before the input file and
    # the model are loaded; otherwise 0 or a negative value would fall into
    # the ensemble branch below.
    if args.n_samples < 1:
        parser.error("--n_samples must be >= 1")

    density = load_density(args.input, args.field)
    model = load_model(args.model, args.device)

    if args.n_samples == 1:
        prediction = model.predict(
            density, num_steps=args.num_steps, eta=args.eta, seed=args.seed
        )
        save_prediction(args.output, prediction)
    else:
        mean, std = model.predict_ensemble(
            density, n_samples=args.n_samples, num_steps=args.num_steps, eta=args.eta
        )
        save_prediction(args.output, mean, std=std)
    print(f"Saved: {args.output}")
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()