In [5]:
import torch
from datasets import load_from_disk
from diffusers import AutoencoderKL, UNet2DConditionModel, FlowMatchEulerDiscreteScheduler
from transformers import AutoTokenizer, AutoModel
from PIL import Image
import math
from tqdm.auto import tqdm
import os
import random

import numpy as np
import random

import torch

from diffusers import DiffusionPipeline, AutoencoderKL, UNet2DConditionModel, FlowMatchEulerDiscreteScheduler
from transformers import AutoTokenizer, AutoModel

# Hugging Face repository that holds every pipeline component.
model_repo_id = "AiArtLab/sdxs3d"  # Replace to the model you would like to use

# Pick the execution device and the matching weight precision in one place:
# half precision on a CUDA GPU, full precision on CPU.
if torch.cuda.is_available():
    device, dtype = "cuda", torch.float16
else:
    device, dtype = "cpu", torch.float32


class SimpleDiffusionPipeline(DiffusionPipeline):
    """Minimal text-to-image pipeline built from five registered components.

    A prompt is tokenized and encoded, reduced to a single embedding per
    prompt (the hidden state at the last sequence position), and used to
    condition the UNet while the scheduler integrates the predicted flow.
    The final latents are decoded by the VAE into images in ``[0, 1]``.
    """

    def __init__(self, vae, text_encoder, tokenizer, unet, scheduler):
        """Register the components so the base pipeline can manage them."""
        super().__init__()
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
        )

    def _encode_prompt(self, prompt):
        """Return one pooled embedding per prompt with a length-1 sequence axis."""
        tokens = self.tokenizer(
            prompt,
            padding="max_length",
            truncation=True,
            max_length=512,
            return_tensors="pt",
        ).to(self.device)
        hidden = self.text_encoder(**tokens).last_hidden_state
        # Last-token pooling: keep only the final sequence position.
        # NOTE(review): with padding="max_length" this is a real token only if
        # the tokenizer pads on the left — confirm against the tokenizer config.
        return hidden[:, -1].unsqueeze(1)

    @torch.no_grad()
    def __call__(
        self,
        prompt,
        negative_prompt=None,
        height=512,
        width=512,
        num_inference_steps=50,
        guidance_scale=4.0,
        generator=None,
        **kwargs,
    ):
        """Generate images for ``prompt``.

        Args:
            prompt: A string or a list of strings.
            negative_prompt: Optional string or list of strings used as the
                unconditional branch of classifier-free guidance. A single
                string is repeated for every prompt in the batch. When
                ``None``, an all-zero embedding is used instead.
            height: Output height in pixels; latents use ``height / 8``.
            width: Output width in pixels; latents use ``width / 8``.
            num_inference_steps: Number of scheduler steps.
            guidance_scale: Classifier-free guidance weight; ``1.0`` disables
                guidance and skips the unconditional branch entirely.
            generator: Optional ``torch.Generator`` for the initial noise.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            The decoded image tensor, clamped to ``[0, 1]``.

        Raises:
            ValueError: If ``negative_prompt`` yields a different batch size
                than ``prompt``.
        """
        batch_size = len(prompt) if isinstance(prompt, list) else 1
        use_guidance = guidance_scale != 1.0

        # 1. Text conditioning
        text_emb = self._encode_prompt(prompt)

        if use_guidance:
            if negative_prompt is not None:
                if isinstance(negative_prompt, str):
                    # Repeat a single negative prompt so both guidance
                    # branches have the same batch size.
                    negative_prompt = [negative_prompt] * batch_size
                neg_emb = self._encode_prompt(negative_prompt)
                if neg_emb.shape[0] != text_emb.shape[0]:
                    raise ValueError(
                        "negative_prompt must be a single string or contain "
                        f"one entry per prompt (got {neg_emb.shape[0]} for "
                        f"{text_emb.shape[0]} prompts)"
                    )
            else:
                neg_emb = torch.zeros_like(text_emb)
            # Unconditional half first, conditional half second; the sampling
            # loop splits the prediction in the same order.
            text_emb = torch.cat([neg_emb, text_emb])

        # 2. Initial latents
        # Report the factor of the VAE owned by this pipeline, not a global.
        print("VAE scaling_factor =", self.vae.config.scaling_factor)

        latents = torch.randn(
            (batch_size, self.unet.config.in_channels, int(height / 8), int(width / 8)),
            device=self.device,
            # Follow the UNet's precision so CPU / float32 runs work as well.
            dtype=self.unet.dtype,
            generator=generator,
        )

        self.scheduler.set_timesteps(num_inference_steps, device=self.device)

        # 3. Sampling loop
        for t in self.scheduler.timesteps:
            latent_input = torch.cat([latents, latents]) if use_guidance else latents
            flow = self.unet(latent_input, t, encoder_hidden_states=text_emb).sample

            if use_guidance:
                flow_uncond, flow_cond = flow.chunk(2)
                flow = flow_uncond + guidance_scale * (flow_cond - flow_uncond)

            latents = self.scheduler.step(flow, t, latents).prev_sample

        # 4. Decode
        # NOTE(review): latents are decoded without dividing by
        # vae.config.scaling_factor. The factor printed above is 1.0 for the
        # checkpoint this script was run with, so it makes no difference
        # there — confirm before using a VAE with a different factor.
        images = self.vae.decode(latents).sample
        images = (images / 2 + 0.5).clamp(0, 1)

        return images

def get_random_samples(dataset_path, num_samples=20):
    """Return random caption/size records from a dataset saved on disk.

    Rows whose text contains the substring "vector" (case-insensitive) are
    dropped before sampling.

    Args:
        dataset_path: Directory passed to ``load_from_disk``.
        num_samples: Maximum number of records to return. When fewer rows
            survive the filter, all of them are returned in random order
            instead of raising ``ValueError``.

    Returns:
        A list of dicts with the keys ``"text"``, ``"width"`` and ``"height"``.
    """
    dataset = load_from_disk(dataset_path)

    # Drop every row whose caption mentions "vector" (e.g. "vectorstock").
    kept_rows = [
        row_idx
        for row_idx, text in enumerate(dataset["text"])
        if "vector" not in text.lower()
    ]
    dataset = dataset.select(kept_rows)

    # random.sample raises if asked for more items than exist, so cap the count.
    sample_count = min(num_samples, len(dataset))
    chosen = random.sample(range(len(dataset)), sample_count)

    samples = []
    for row_idx in chosen:
        # Fetch the row once instead of once per field.
        row = dataset[row_idx]
        samples.append(
            {
                "text": row["text"],
                "width": row["width"],
                "height": row["height"],
            }
        )

    return samples

# Load each component from its own subfolder of the model repository.
# Components with weights are loaded in the configured dtype and then moved
# to the configured device.
vae = AutoencoderKL.from_pretrained(
    model_repo_id, subfolder="vae", torch_dtype=dtype
)
vae = vae.to(device)

unet = UNet2DConditionModel.from_pretrained(
    model_repo_id, subfolder="unet", torch_dtype=dtype
)
unet = unet.to(device)

tokenizer = AutoTokenizer.from_pretrained(model_repo_id, subfolder="tokenizer")

text_encoder = AutoModel.from_pretrained(
    model_repo_id, subfolder="text_encoder", torch_dtype=dtype
)
text_encoder = text_encoder.to(device)

scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    model_repo_id, subfolder="scheduler"
)

# Assemble the components into the pipeline used by the generation code below.
pipe = SimpleDiffusionPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    scheduler=scheduler,
)
pipe = pipe.to(device)

def generate_and_save_images(
    dataset_path,
    output_folder="samples",
    project_name="sdxs",
    num_samples=100,
    negative_prompt="bad quality, low quality, low resolution",
):
    """Generate one image per sampled caption and save each as a JPEG.

    Captions and target sizes come from ``get_random_samples``. Images are
    produced by the module-level ``pipe`` and written to
    ``<output_folder>/<project_name>_<index>.jpg``.

    Args:
        dataset_path: Dataset directory forwarded to ``get_random_samples``.
        output_folder: Directory for the images; created if it is missing.
        project_name: File-name prefix for the saved images.
        num_samples: Number of captions to sample and render.
        negative_prompt: Negative prompt passed to every pipeline call.
    """
    samples = get_random_samples(dataset_path, num_samples=num_samples)

    os.makedirs(output_folder, exist_ok=True)
    # A single seeded generator is shared by all calls, so a full run is
    # reproducible while each image still gets different noise.
    generator = torch.Generator(device=device).manual_seed(42)

    for idx, sample in enumerate(samples):
        prompt = sample["text"]
        height, width = sample["height"], sample["width"]
        print(height, width, prompt)
        images_tensor = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            width=width,
            height=height,
            generator=generator,
        )  # [B, C, H, W]

        # Convert the first image to an HxWxC uint8 array for PIL. Cast to
        # float32 first and round, so half-precision output is not truncated
        # towards zero when scaled to 0..255.
        pixels = images_tensor[0].float().cpu().permute(1, 2, 0).numpy()
        pixels = (pixels * 255).round().astype(np.uint8)
        image = Image.fromarray(pixels)
        image.save(os.path.join(output_folder, f"{project_name}_{idx}.jpg"))

    print("Images generated and saved to:", output_folder)

# Example usage: render samples for the dataset stored at this path, using
# the default output folder and file-name prefix.
dataset_path = "/workspace/sdxs3d/datasets/mjnj"
generate_and_save_images(dataset_path=dataset_path)


Loading dataset from disk:   0%|          | 0/128 [00:00<?, ?it/s]

384 384 An ancient-looking book cover adorned with intricate symbols, a pentagram, and a crescent moon, evoking mysticism and the occult, with worn edges suggesting a well-read or aged tome.
VAE scaling_factor = 1.0
256 384 A group of four men with long hair and tattoos are standing on a rooftop, laughing and looking at something with expressions of shock and excitement. They are dressed in casual clothing, with one man wearing a black t-shirt and another in a blue shirt. The men are positioned in the foreground, with the city skyline visible in the background. The lighting suggests it is either early morning or late afternoon, and the atmosphere appears to be one of camaraderie and shared joy.
VAE scaling_factor = 1.0
384 256 A sleek, black Audi sports car is on display in a modern showroom with a white floor and a blurred background, showcasing the vehicle's futuristic design and luxurious interior.
VAE scaling_factor = 1.0
192 384 In a bustling tavern, two individuals are sharing a 