File size: 2,801 Bytes
101858b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path

import torch

# Resolve the parent of the directory containing this file and put it at the
# front of sys.path, so that the ``from ImageGen import ImageGenPipeline``
# import below can be resolved when this file is executed directly as a script.
ROOT = Path(__file__).resolve().parents[1]
# Insert only once (position 0 gives it precedence over other sys.path entries).
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from ImageGen import ImageGenPipeline


# Files that must exist inside the model directory (paths relative to --model-dir)
# before anything is loaded.
_REQUIRED_FILES = (
    "adapter_model.pt",
    "config.json",
    "training_config.json",
    "model_index.json",
    "tokenizer/tokenizer.json",
    "model/universal_hf_text_to_image_adapter.py",
)

# Prompt shared by the tokenizer smoke test and the optional generation run.
_SMOKE_PROMPT = "a small neon geometric logo"


def _check_required_files(model_dir: Path) -> None:
    """Raise ``FileNotFoundError`` listing every required file absent from *model_dir*."""
    # is_file() rather than exists(): a directory sitting at one of these paths
    # would otherwise pass this check and only fail later with an unrelated error.
    missing = [item for item in _REQUIRED_FILES if not (model_dir / item).is_file()]
    if missing:
        raise FileNotFoundError(f"Missing ImageGen files: {missing}")


def _read_json(path: Path):
    """Parse and return the UTF-8 encoded JSON document stored at *path*."""
    with path.open("r", encoding="utf-8") as f:
        return json.load(f)


def main(argv: list[str] | None = None) -> int:
    """Validate an exported ImageGen model directory against the current code.

    Checks that the required files are present, loads the pipeline and the saved
    adapter weights, and prints key-level differences between the checkpoint and
    the adapter's own state dict. It also runs a tokenizer smoke test and, with
    ``--generate``, a tiny one-step generation.

    Args:
        argv: Command-line arguments to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, matching the original behaviour.

    Returns:
        0 when the checkpoint keys are all represented by the current architecture.

    Raises:
        FileNotFoundError: If any required file is missing from the model directory.
        TypeError: If ``adapter_model.pt`` does not deserialize to a dict.
        RuntimeError: If the checkpoint holds keys the current adapter does not define.
    """
    parser = argparse.ArgumentParser(
        description="Check saved ImageGen adapter weights against the current architecture."
    )
    parser.add_argument("--model-dir", default="ImageGen",
                        help="directory containing the exported ImageGen files")
    parser.add_argument("--device", default="cpu",
                        help="device passed to ImageGenPipeline.from_pretrained")
    parser.add_argument("--generate", action="store_true",
                        help="also run a 1-step 128x128 generation and check the output is finite")
    args = parser.parse_args(argv)

    model_dir = Path(args.model_dir)
    _check_required_files(model_dir)

    model_index = _read_json(model_dir / "model_index.json")
    training_config = _read_json(model_dir / "training_config.json")

    pipe = ImageGenPipeline.from_pretrained(model_dir, device=args.device)
    # SECURITY: torch.load unpickles the file. On torch versions where
    # ``weights_only`` does not default to True this can execute arbitrary code,
    # so only point --model-dir at trusted checkpoints (consider weights_only=True).
    state = torch.load(model_dir / "adapter_model.pt", map_location="cpu")
    if not isinstance(state, dict):
        raise TypeError(
            f"Expected a state dict in adapter_model.pt, got {type(state).__name__}"
        )

    model_keys = set(pipe.adapter.adapter_state_dict().keys())
    weight_keys = set(state.keys())
    # Keys saved in the checkpoint that the current architecture does not define.
    missing_in_model = sorted(weight_keys - model_keys)
    # Keys the current architecture defines that the checkpoint does not provide.
    missing_in_weights = sorted(model_keys - weight_keys)

    print("model_index_class", model_index.get("_class_name"))
    print("global_step", training_config.get("global_step"))
    print("weight_tensors", len(weight_keys))
    print("adapter_tensors", len(model_keys))
    print("missing_weight_keys_in_model", len(missing_in_model))
    print("missing_model_keys_in_weights", len(missing_in_weights))
    if missing_in_model:
        print("first_missing_weight_keys_in_model", missing_in_model[:10])
    if missing_in_weights:
        print("first_missing_model_keys_in_weights", missing_in_weights[:10])

    encoded = pipe._tokenize(_SMOKE_PROMPT, max_length=32)
    print("tokenized_shape", tuple(encoded["input_ids"].shape))

    if args.generate:
        out = pipe(
            _SMOKE_PROMPT,
            height=128,
            width=128,
            num_inference_steps=1,
            output_type="pt",
        )
        tensor = out.tensors
        print("generated_shape", tuple(tensor.shape))
        print("generated_finite", bool(torch.isfinite(tensor).all()))

    # Only checkpoint tensors with no counterpart in the model are fatal. Model
    # keys absent from the checkpoint are reported above but tolerated
    # (presumably parameters that were not trained/saved - confirm).
    if missing_in_model:
        raise RuntimeError("Some saved trained tensors are not represented by the current architecture.")
    return 0


if __name__ == "__main__":
    # Use main()'s integer result as the process exit status; sys.exit raises
    # SystemExit with that value, so uncaught exceptions still propagate.
    sys.exit(main())