from __future__ import annotations import argparse import json import sys from pathlib import Path import torch ROOT = Path(__file__).resolve().parents[1] if str(ROOT) not in sys.path: sys.path.insert(0, str(ROOT)) from ImageGen import ImageGenPipeline def main() -> int: parser = argparse.ArgumentParser() parser.add_argument("--model-dir", default="ImageGen") parser.add_argument("--device", default="cpu") parser.add_argument("--generate", action="store_true") args = parser.parse_args() model_dir = Path(args.model_dir) required = [ "adapter_model.pt", "config.json", "training_config.json", "model_index.json", "tokenizer/tokenizer.json", "model/universal_hf_text_to_image_adapter.py", ] missing = [item for item in required if not (model_dir / item).exists()] if missing: raise FileNotFoundError(f"Missing ImageGen files: {missing}") with (model_dir / "model_index.json").open("r", encoding="utf-8") as f: model_index = json.load(f) with (model_dir / "training_config.json").open("r", encoding="utf-8") as f: training_config = json.load(f) pipe = ImageGenPipeline.from_pretrained(model_dir, device=args.device) state = torch.load(model_dir / "adapter_model.pt", map_location="cpu") model_keys = set(pipe.adapter.adapter_state_dict().keys()) weight_keys = set(state.keys()) missing_in_model = sorted(weight_keys - model_keys) missing_in_weights = sorted(model_keys - weight_keys) print("model_index_class", model_index.get("_class_name")) print("global_step", training_config.get("global_step")) print("weight_tensors", len(weight_keys)) print("adapter_tensors", len(model_keys)) print("missing_weight_keys_in_model", len(missing_in_model)) print("missing_model_keys_in_weights", len(missing_in_weights)) if missing_in_model: print("first_missing_weight_keys_in_model", missing_in_model[:10]) if missing_in_weights: print("first_missing_model_keys_in_weights", missing_in_weights[:10]) encoded = pipe._tokenize("a small neon geometric logo", max_length=32) print("tokenized_shape", tuple(encoded["input_ids"].shape)) if args.generate: out = pipe( "a small neon geometric logo", height=128, width=128, num_inference_steps=1, output_type="pt", ) tensor = out.tensors print("generated_shape", tuple(tensor.shape)) print("generated_finite", bool(torch.isfinite(tensor).all())) if missing_in_model: raise RuntimeError("Some saved trained tensors are not represented by the current architecture.") return 0 if __name__ == "__main__": raise SystemExit(main())