File size: 3,907 Bytes
891e05c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
import shutil
import subprocess
import tempfile
import hydra
import torch

from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.callbacks import ModelCheckpoint, BasePredictionWriter
from pytorch_lightning.utilities.rank_zero import rank_zero_only
from pytorch_lightning.utilities.deepspeed import (
    convert_zero_checkpoint_to_fp32_state_dict,
)
from safetensors.torch import save_file


@rank_zero_only
def convert_deepspeed_checkpoint(
    cfg: DictConfig, checkpoint_callback: ModelCheckpoint, output_dir: str
):
    """
    Convert deepspeed checkpoint to fp32 safetensors format.
    All frozen parameters will be removed.

    On success ``output_dir`` contains ``model.safetensors``, ``config.yaml``
    and ``overrides.yaml``, and the original checkpoint at
    ``checkpoint_callback.best_model_path`` is deleted. The intermediate
    ``fp32_state_dict.pth`` is removed whether or not the conversion succeeds;
    the original checkpoint is only deleted after a successful conversion.

    Args:
        cfg: Run configuration. ``cfg.model.freeze_backbone``, ``freeze_cls``
            and ``freeze_seg`` (each treated as False when absent) select
            which parameter groups are dropped from the exported weights.
        checkpoint_callback: Callback whose ``best_model_path`` points at the
            DeepSpeed checkpoint to convert.
        output_dir: Directory that receives the converted files; created if
            it does not exist.
    """
    os.makedirs(output_dir, exist_ok=True)
    fp32_path = os.path.join(output_dir, "fp32_state_dict.pth")

    # The freeze flags do not depend on the parameter being inspected, so
    # resolve them once into the tuple of key prefixes to drop.
    frozen_prefixes = tuple(
        prefix
        for flag, prefix in (
            ("freeze_backbone", "loupe.backbone"),
            ("freeze_cls", "loupe.classifier"),
            ("freeze_seg", "loupe.segmentor"),
        )
        if getattr(cfg.model, flag, False)
    )

    try:
        convert_zero_checkpoint_to_fp32_state_dict(
            checkpoint_callback.best_model_path,
            fp32_path,
        )
        # `set` is allow-listed for the weights-only unpickler; presumably the
        # consolidated checkpoint carries set objects next to the tensors.
        with torch.serialization.safe_globals([set]):
            ckpt = torch.load(
                fp32_path,
                map_location="cpu",
                weights_only=True,
            )

        # str.startswith accepts a tuple; an empty tuple matches nothing, so
        # every parameter is kept when no group is frozen.
        state_dict = {
            name: tensor
            for name, tensor in ckpt["state_dict"].items()
            if not name.startswith(frozen_prefixes)
        }

        save_file(state_dict, os.path.join(output_dir, "model.safetensors"))
        OmegaConf.save(config=cfg, f=os.path.join(output_dir, "config.yaml"))
        OmegaConf.save(
            config=hydra.core.hydra_config.HydraConfig.get().overrides.task,
            f=os.path.join(output_dir, "overrides.yaml"),
        )
        print(f"Model converted to FP32 and saved to {output_dir}.")
    finally:
        # Never leave the (potentially very large) intermediate file behind,
        # even when loading or exporting fails part-way through.
        if os.path.isfile(fp32_path):
            os.remove(fp32_path)

    # Only reached on success: the original checkpoint is no longer needed.
    shutil.rmtree(checkpoint_callback.best_model_path)


@rank_zero_only
def prepare_output_dir(pred_path, mask_dir):
    """
    Reset the prediction outputs before a new run.

    Deletes ``pred_path`` if it is an existing file. If ``mask_dir`` already
    exists, its contents are cleared by rsync-ing an empty temporary directory
    over it with ``--delete``. When rsync is missing or exits non-zero, a
    notice is printed and the existing contents are left in place. Finally
    ``mask_dir`` is created if it does not exist.

    Args:
        pred_path: Path of the predictions file to remove.
        mask_dir: Directory to empty (best effort) and ensure exists.
    """
    if os.path.isfile(pred_path):
        os.remove(pred_path)

    if os.path.isdir(mask_dir):
        print(f"Removing existing directory: {mask_dir}...")
        failure = None
        try:
            with tempfile.TemporaryDirectory() as empty_dir:
                # Syncing from an empty source with --delete empties the target.
                sync_cmd = ["rsync", "-a", "--delete", empty_dir + "/", mask_dir + "/"]
                completed = subprocess.run(
                    sync_cmd,
                    capture_output=True,
                    text=True,
                )
        except (FileNotFoundError, RuntimeError) as exc:
            # FileNotFoundError is what subprocess raises when rsync is absent.
            failure = exc
        else:
            if completed.returncode != 0:
                failure = RuntimeError(f"rsync failed: {completed.stderr}")

        if failure is not None:
            print(
                f"rsync not available or failed ({failure}), overwriting previous results..."
            )

    os.makedirs(mask_dir, exist_ok=True)


class CustomWriter(BasePredictionWriter):
    """
    Prediction writer that stores per-sample scores and masks on disk.

    Scores are appended to ``<pred_output_dir>/predictions.txt`` as
    ``name,score`` lines and masks are saved under
    ``<pred_output_dir>/masks/`` using each sample's name as the file name.
    """

    def __init__(self, cfg: DictConfig, write_interval):
        """
        Resolve the output locations from ``cfg.stage.pred_output_dir`` and
        reset them via ``prepare_output_dir``.
        """
        super().__init__(write_interval)
        root_dir = cfg.stage.pred_output_dir
        self.mask_dir = os.path.join(root_dir, "masks")
        self.pred_path = os.path.join(root_dir, "predictions.txt")

        prepare_output_dir(self.pred_path, self.mask_dir)

    def write_on_batch_end(
        self,
        trainer,
        pl_module,
        prediction,
        batch_indices,
        batch,
        batch_idx,
        dataloader_idx,
    ):
        """
        Persist the predictions of one batch.

        Reads ``prediction["cls_probs"]`` and ``prediction["pred_masks"]`` and
        pairs them with ``batch["name"]``. Each score is appended to the
        predictions file with four decimals; each mask is written with its own
        ``save`` method (presumably a PIL image; confirm against the model's
        predict step).
        """
        cls_probs = prediction["cls_probs"]
        pred_masks = prediction["pred_masks"]

        # Append mode: every batch adds its lines to the same file.
        with open(self.pred_path, "a") as f:
            f.writelines(
                f"{name},{cls_prob:.4f}\n"
                for name, cls_prob in zip(batch["name"], cls_probs)
            )

        for name, pred_mask in zip(batch["name"], pred_masks):
            mask_path = os.path.join(self.mask_dir, name)
            pred_mask.save(mask_path)