File size: 3,656 Bytes
bc275c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
860c112
bc275c2
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from __future__ import annotations

import gc
import json
import random
import threading
from pathlib import Path
from typing import Dict, Tuple

import numpy as np
import torch
from PIL import Image

from .backends import FluxBilaBackend, Ip2pBilaBackend, release_cuda
from .image_io import blend_strength, prepare_image, tensor_to_pil
from .weights import configure_runtime_cache, require_paths, resolve_model_root


# JSON manifest describing the available models; it lives in the parent of
# the directory that contains this module.
MANIFEST_PATH = Path(__file__).resolve().parents[1].joinpath("model_manifest.json")


def load_manifest() -> Dict:
    """Parse the model manifest JSON file at ``MANIFEST_PATH``.

    The file is decoded as UTF-8. Any I/O or JSON decoding error propagates
    to the caller unchanged.
    """
    manifest_text = MANIFEST_PATH.read_text(encoding="utf-8")
    return json.loads(manifest_text)


def seed_everything(seed: int) -> None:
    """Seed Python's, NumPy's and PyTorch's global RNGs from one value.

    Args:
        seed: Any value accepted by ``int()``. Python's ``random`` module and
            PyTorch receive it unchanged. NumPy's legacy ``np.random.seed``
            only accepts integers in ``[0, 2**32 - 1]``, so NumPy receives
            ``seed`` reduced modulo ``2**32``.
    """
    seed = int(seed)
    random.seed(seed)
    # np.random.seed accepts 0..2**32 - 1 inclusive. Reducing modulo 2**32
    # (not 2**32 - 1) leaves every already-valid seed unchanged; the old
    # modulus wrongly mapped the valid seed 2**32 - 1 onto 0.
    np.random.seed(seed % (2**32))
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


class DemoManager:
    """Run image-editing requests, keeping at most one model backend loaded.

    Model definitions are read from the JSON manifest via ``load_manifest``.
    ``generate`` holds ``self._lock`` for the whole load-and-infer sequence,
    so concurrent calls on the same manager are serialized.
    """

    def __init__(self):
        # NOTE(review): configure_runtime_cache comes from .weights; presumably
        # it sets up cache locations for model files — confirm there.
        configure_runtime_cache()
        self.manifest = load_manifest()
        self._backend = None  # currently loaded backend instance, or None
        self._backend_key = None  # manifest key of self._backend, or None
        self._lock = threading.Lock()  # serializes generate() calls

    @property
    def model_choices(self):
        """Return a ``(label, key)`` pair for every model in the manifest."""
        return [
            (cfg["label"], key)
            for key, cfg in self.manifest["models"].items()
        ]

    @property
    def default_model(self):
        """Return the manifest's ``default_model`` entry."""
        return self.manifest["default_model"]

    def _paths_for_model(self, model_key: str) -> Dict[str, Path]:
        """Map each weight name of ``model_key`` to its path under the model root.

        ``require_paths`` is called on every relative weight path first;
        presumably it raises when a file is missing — confirm in .weights.
        """
        model_cfg = self.manifest["models"][model_key]
        root = resolve_model_root(model_key, model_cfg)
        require_paths(root, model_cfg["weights"].values())
        return {name: root / rel for name, rel in model_cfg["weights"].items()}

    def _load_backend(self, model_key: str):
        """Return the backend for ``model_key``, loading it if not already cached.

        Only one backend is kept. Switching models releases the old one before
        constructing the new one.

        Raises:
            ValueError: if the manifest entry's ``kind`` is not ``"ip2p"`` or
                ``"flux"``. In that case the cached backend is left untouched.
        """
        if self._backend_key == model_key and self._backend is not None:
            return self._backend

        model_cfg = self.manifest["models"][model_key]
        kind = model_cfg["kind"]
        # Validate the kind and resolve weight paths *before* evicting the
        # current backend. Otherwise a bad kind or missing weights would throw
        # away a working model and force a reload on the next request.
        if kind == "ip2p":
            backend_cls = Ip2pBilaBackend
        elif kind == "flux":
            backend_cls = FluxBilaBackend
        else:
            raise ValueError(f"Unknown model kind: {kind}")
        paths = self._paths_for_model(model_key)

        # Drop every reference to the old backend and reclaim memory before the
        # new one is built. release_cuda presumably frees cached GPU memory
        # (defined in .backends — confirm).
        self._backend = None
        self._backend_key = None
        gc.collect()
        release_cuda()

        backend = backend_cls(model_cfg, paths)
        self._backend = backend
        self._backend_key = model_key
        return backend

    def generate(
        self,
        image: Image.Image,
        instruction: str,
        model_key: str,
        seed: int,
        max_side: int,
        strength: float,
    ) -> Tuple[Image.Image, Image.Image, Image.Image, str]:
        """Edit ``image`` according to ``instruction`` using the chosen model.

        Args:
            image: Input image. It must not be None.
            instruction: Edit instruction. None becomes "" and surrounding
                whitespace is stripped.
            model_key: Key into the manifest's ``models`` section.
            seed: Seed applied to all RNGs via ``seed_everything`` right
                before inference.
            max_side: Forwarded to ``prepare_image``.
            strength: Forwarded to ``blend_strength``, which combines the
                prepared input with the backend's ``"bila"`` output.

        Returns:
            ``(edited, diff, prepared_input, status)``. ``edited`` is the
            blended result, ``diff`` is the backend's ``"diff"`` output,
            ``prepared_input`` is the prepared full-size input image, and
            ``status`` is a ``"<label> | seed=<seed>"`` string.

        Raises:
            ValueError: if ``image`` is None or ``model_key`` is unknown.
        """
        if image is None:
            raise ValueError("Please upload an image.")
        instruction = (instruction or "").strip()
        if model_key not in self.manifest["models"]:
            raise ValueError(f"Unknown model: {model_key}")

        with self._lock:
            model_cfg = self.manifest["models"][model_key]
            model_size = int(model_cfg["config"].get("model_size", 512))
            prepared = prepare_image(image, max_side=max_side, model_size=model_size)
            backend = self._load_backend(model_key)
            # Seed after loading so any RNG use during backend construction
            # does not shift the inference random stream.
            seed_everything(seed)
            result = backend(
                prepared.model_tensor,
                [instruction],
                prepared.full_tensor,
            )
            edited_tensor = blend_strength(prepared.full_tensor, result["bila"], strength)
            edited = tensor_to_pil(edited_tensor)
            diff = tensor_to_pil(result["diff"])
            status = f"{model_cfg['label']} | seed={int(seed)}"
            return edited, diff, prepared.full_pil, status