---
license: apple-amlr
language:
- en
tags:
- normalizing-flows
- generative-models
- art
- autoregressive-models
---

[STARFlow](https://huggingface.co/apple/starflow) text-to-image (T2I) checkpoint, converted to the safetensors format.
Intended for use with [ComfyUI-STARFlow](https://github.com/RyukoMatoiFan/ComfyUI-STARFlow).

Converted with the following script:

```python
import argparse
from pathlib import Path

import torch
from safetensors.torch import save_file

DEFAULT_SRC = "starflow_3B_t2i_256x256.pth"
DEFAULT_DST = "starflow_3B_t2i_256x256.safetensors"


def _extract_state_dict(obj):
    """Return the mapping of parameter names to values stored in *obj*.

    If *obj* is a dict with a ``"state_dict"`` entry, that entry is unwrapped
    (one level only). The result must be a dict.

    Raises:
        TypeError: if no dict is found after unwrapping.
    """
    if isinstance(obj, dict) and "state_dict" in obj:
        obj = obj["state_dict"]
    if not isinstance(obj, dict):
        raise TypeError(f"Expected a dict/state_dict, got: {type(obj)}")
    return obj


def _make_saveable(tensors):
    """Return a copy of *tensors* that ``safetensors.torch.save_file`` accepts.

    ``save_file`` refuses tensors that share storage with another tensor in the
    same file (e.g. tied weights) as well as non-contiguous tensors. Only those
    tensors are copied into fresh contiguous memory; all others are passed
    through unchanged so peak memory stays close to the checkpoint size.
    """
    seen_storages = set()
    result = {}
    for name, tensor in tensors.items():
        tensor = tensor.detach()
        storage = tensor.untyped_storage()
        # safetensors ignores zero-size storages when checking for sharing
        # (they may all report the same null pointer), so do the same here.
        has_data = storage.nbytes() > 0
        ptr = storage.data_ptr()
        if (has_data and ptr in seen_storages) or not tensor.is_contiguous():
            tensor = tensor.clone(memory_format=torch.contiguous_format)
        elif has_data:
            # Record the storage only when it is kept, so a later tensor that
            # aliases a storage we already copied away can be kept as-is.
            seen_storages.add(ptr)
        result[name] = tensor
    return result


def main(src=DEFAULT_SRC, dst=DEFAULT_DST):
    """Convert the PyTorch checkpoint at *src* into a safetensors file at *dst*.

    Only entries with a ``str`` key and a tensor value are written; every other
    entry is skipped, counted, and listed on stdout.

    Raises:
        TypeError: if the checkpoint does not contain a dict of weights.
        ValueError: if the dict contains no tensors.
    """
    # NOTE(security): torch.load uses pickle. Unless weights_only=True is in
    # effect (the default only in newer PyTorch releases), loading can execute
    # arbitrary code -- only convert checkpoints from a trusted source.
    obj = torch.load(src, map_location="cpu")
    state = _extract_state_dict(obj)

    tensor_dict = {k: v for k, v in state.items() if isinstance(k, str) and torch.is_tensor(v)}
    skipped = [k for k in state if k not in tensor_dict]

    if not tensor_dict:
        raise ValueError("No tensors found to save.")

    save_file(_make_saveable(tensor_dict), dst)
    print(f"saved: {dst} (tensors: {len(tensor_dict)}, skipped non-tensors: {len(skipped)})")
    if skipped:
        print("skipped keys: " + ", ".join(repr(k) for k in skipped))


def _parse_args(argv=None):
    """Parse ``[SRC [DST]]``; DST defaults to SRC with a ``.safetensors`` suffix."""
    parser = argparse.ArgumentParser(description="Convert a .pth checkpoint to safetensors.")
    parser.add_argument("src", nargs="?", default=DEFAULT_SRC,
                        help="input checkpoint (default: %(default)s)")
    parser.add_argument("dst", nargs="?", default=None,
                        help="output file (default: SRC with a .safetensors suffix)")
    args = parser.parse_args(argv)
    # With no arguments this yields DEFAULT_DST, matching main()'s default.
    dst = args.dst if args.dst is not None else str(Path(args.src).with_suffix(".safetensors"))
    return args.src, dst


if __name__ == "__main__":
    main(*_parse_args())
```