|
|
from transformers import PretrainedConfig, PreTrainedModel, Pipeline
|
|
|
import torch
|
|
|
|
|
|
from BeamDiffusionModel.beamInference import beam_inference
|
|
|
from BeamDiffusionModel.models.diffusionModel.StableDiffusion import StableDiffusion
|
|
|
from BeamDiffusionModel.models.diffusionModel.Flux import Flux
|
|
|
|
|
|
class BeamDiffusionConfig(PretrainedConfig):
    """Configuration for the beam-diffusion model.

    Holds the beam-search hyperparameters and instantiates the underlying
    diffusion backbone ("SD-2.1" or "flux") into ``self.sd``.

    NOTE(review): storing a live model object on a ``PretrainedConfig`` will
    likely break ``to_dict()``/``save_pretrained()`` serialization — confirm
    this config is never saved with ``self.sd`` populated.
    """

    model_type = "beam_diffusion"

    def __init__(self, sd="SD-2.1", latents_idx=None, n_seeds=4, seeds=None,
                 steps_back=2, beam_width=4, window_size=2, use_rand=True,
                 **kwargs):
        """Build the config and eagerly instantiate the diffusion backbone.

        Args:
            sd: Name of the diffusion backbone, ``"SD-2.1"`` or ``"flux"``.
            latents_idx: Latent indices to use; defaults to ``[0, 1, 2, 3]``.
            n_seeds: Number of seeds explored per step.
            seeds: Explicit seeds to use; defaults to an empty list.
            steps_back: How many denoising steps the beam search may back up.
            beam_width: Number of beams kept at each step.
            window_size: Sliding-window size used by the beam search.
            use_rand: Whether random sampling is enabled.
            **kwargs: Forwarded to ``PretrainedConfig``.
        """
        super().__init__(**kwargs)
        self.sd_name = sd
        self.sd = None  # populated by get_model() below
        self.get_model(sd)
        # Compare against None so an explicitly-passed empty list is honored
        # instead of being silently replaced by the default (truthiness bug
        # in the original: `latents_idx if latents_idx else [0, 1, 2, 3]`).
        self.latents_idx = latents_idx if latents_idx is not None else [0, 1, 2, 3]
        self.n_seeds = n_seeds
        self.seeds = seeds if seeds is not None else []
        self.steps_back = steps_back
        self.beam_width = beam_width
        self.window_size = window_size
        self.use_rand = use_rand

    def get_model(self, sd):
        """Instantiate the diffusion backbone named by *sd* into ``self.sd``.

        Falls back to ``self.sd_name`` when *sd* is None (the original body
        ignored the parameter entirely and always read ``self.sd_name``).
        Unrecognized names leave ``self.sd`` untouched (None).
        """
        name = sd if sd is not None else self.sd_name
        if name == "flux":
            self.sd = Flux()
        elif name == "SD-2.1":
            self.sd = StableDiffusion()
|
|
|
|
|
|
import torch.nn as nn
|
|
|
from huggingface_hub import ModelHubMixin
|
|
|
|
|
|
class BeamDiffusionModel(PreTrainedModel, ModelHubMixin):
    """Hugging Face model wrapper that delegates to ``beam_inference``."""

    config_class = BeamDiffusionConfig
    model_type = "beam_diffusion"

    def __init__(self, config):
        """Keep a reference to *config* and register one dummy parameter.

        The dummy parameter presumably exists so the module owns at least
        one tensor for the HF save/load machinery — TODO confirm.
        """
        super().__init__(config)
        self.config = config
        self.dummy_param = nn.Parameter(torch.zeros(1))

    def forward(self, input_data):
        """Run beam-search diffusion inference.

        Args:
            input_data: Mapping whose optional ``"steps"`` entry is forwarded
                to ``beam_inference`` (defaults to an empty list).

        Returns:
            Dict with a single ``"images"`` key holding the generated images.
        """
        cfg = self.config
        generated = beam_inference(
            cfg.sd,
            steps=input_data.get('steps', []),
            latents_idx=cfg.latents_idx,
            n_seeds=cfg.n_seeds,
            seeds=cfg.seeds,
            steps_back=cfg.steps_back,
            beam_width=cfg.beam_width,
            window_size=cfg.window_size,
            use_rand=cfg.use_rand,
        )
        return {"images": generated}
|
|
|
|
|
|
|
|
|
|
|
|
class BeamDiffusionPipeline(Pipeline, ModelHubMixin):
    """Pipeline wrapper that runs ``BeamDiffusionModel`` end to end."""

    def __init__(self, model, tokenizer=None, device="cuda", framework="pt"):
        """Forward construction to ``transformers.Pipeline``."""
        super().__init__(model=model, tokenizer=tokenizer, device=device, framework=framework)

    def __call__(self, inputs):
        """Run the full chain: preprocess -> _forward -> postprocess.

        BUG FIX: the original called ``self._forward(inputs)`` directly,
        skipping ``preprocess`` and ``postprocess`` entirely — callers got
        the raw ``{"images": ...}`` dict instead of the images that
        ``postprocess`` is written to extract.
        """
        model_inputs = self.preprocess(inputs)
        model_outputs = self._forward(model_inputs)
        return self.postprocess(model_outputs)

    def preprocess(self, inputs):
        """Converts raw input data into model-ready format (pass-through)."""
        return inputs

    def postprocess(self, model_outputs):
        """Extract the generated images from the model's output dict."""
        return model_outputs["images"]

    def _sanitize_parameters(self, **kwargs):
        """No extra call-time parameters are supported; ignore them all."""
        return {}, {}, {}

    def _forward(self, model_inputs):
        """Invoke the wrapped model on the preprocessed inputs."""
        return self.model(model_inputs)
|
|
|
|