# Phillnet-2 / ImageGen / validate_imagegen.py
# Uploaded by ayjays132 in commit 101858b ("Upload 478 files").
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
import torch
# Put the parent of this file's directory on sys.path (once) so that the
# `from ImageGen import ImageGenPipeline` import below resolves when this
# script is executed directly rather than as part of an installed package.
ROOT = Path(__file__).resolve().parents[1]
_root_str = str(ROOT)
if _root_str not in sys.path:
    sys.path.insert(0, _root_str)
from ImageGen import ImageGenPipeline
def main() -> int:
parser = argparse.ArgumentParser()
parser.add_argument("--model-dir", default="ImageGen")
parser.add_argument("--device", default="cpu")
parser.add_argument("--generate", action="store_true")
args = parser.parse_args()
model_dir = Path(args.model_dir)
required = [
"adapter_model.pt",
"config.json",
"training_config.json",
"model_index.json",
"tokenizer/tokenizer.json",
"model/universal_hf_text_to_image_adapter.py",
]
missing = [item for item in required if not (model_dir / item).exists()]
if missing:
raise FileNotFoundError(f"Missing ImageGen files: {missing}")
with (model_dir / "model_index.json").open("r", encoding="utf-8") as f:
model_index = json.load(f)
with (model_dir / "training_config.json").open("r", encoding="utf-8") as f:
training_config = json.load(f)
pipe = ImageGenPipeline.from_pretrained(model_dir, device=args.device)
state = torch.load(model_dir / "adapter_model.pt", map_location="cpu")
model_keys = set(pipe.adapter.adapter_state_dict().keys())
weight_keys = set(state.keys())
missing_in_model = sorted(weight_keys - model_keys)
missing_in_weights = sorted(model_keys - weight_keys)
print("model_index_class", model_index.get("_class_name"))
print("global_step", training_config.get("global_step"))
print("weight_tensors", len(weight_keys))
print("adapter_tensors", len(model_keys))
print("missing_weight_keys_in_model", len(missing_in_model))
print("missing_model_keys_in_weights", len(missing_in_weights))
if missing_in_model:
print("first_missing_weight_keys_in_model", missing_in_model[:10])
if missing_in_weights:
print("first_missing_model_keys_in_weights", missing_in_weights[:10])
encoded = pipe._tokenize("a small neon geometric logo", max_length=32)
print("tokenized_shape", tuple(encoded["input_ids"].shape))
if args.generate:
out = pipe(
"a small neon geometric logo",
height=128,
width=128,
num_inference_steps=1,
output_type="pt",
)
tensor = out.tensors
print("generated_shape", tuple(tensor.shape))
print("generated_finite", bool(torch.isfinite(tensor).all()))
if missing_in_model:
raise RuntimeError("Some saved trained tensors are not represented by the current architecture.")
return 0
if __name__ == "__main__":
    # Use main()'s return value (0 on success) as the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)