diff --git "a/src/sample-Copy1.ipynb" "b/src/sample-Copy1.ipynb" new file mode 100644--- /dev/null +++ "b/src/sample-Copy1.ipynb" @@ -0,0 +1,685 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "id": "18818aae-6d2e-40e9-bdc6-2e5900f08603", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "f5c73ff7873c4f0d9d654f8972d1b77c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "config.json: 0%| | 0.00/726 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import torch.nn.functional as F\n", + "\n", + "# ====== Помощники ======\n", + "def last_token_pool(last_hidden_states: torch.Tensor,\n", + " attention_mask: torch.Tensor) -> torch.Tensor:\n", + " # Определяем, есть ли left padding\n", + " left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])\n", + " if left_padding:\n", + " return last_hidden_states[:, -1]\n", + " else:\n", + " sequence_lengths = attention_mask.sum(dim=1) - 1\n", + " batch_size = last_hidden_states.shape[0]\n", + " return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]\n", + "\n", + "def encode_texts(texts, max_length=150):\n", + " with torch.inference_mode():\n", + " toks = tokenizer(\n", + " texts, return_tensors=\"pt\", padding=\"max_length\", truncation=True, max_length=max_length\n", + " ).to(device)\n", + " #outs = text_model(**toks)\n", + " #emb = last_token_pool(outs.last_hidden_state, toks[\"attention_mask\"])\n", + " \n", + " # Добавляем размерность sequence_length для совместимости с cross-attention\n", + " # Превращаем (batch_size, hidden_dim) в (batch_size, 1, hidden_dim)\n", + " #emb = emb.unsqueeze(1)\n", + "\n", + " outputs = model.model(**toks, output_hidden_states=True)\n", + "\n", + " # Берем последний слой (эмбеддинги всех токенов)\n", + " hidden_states = outputs.hidden_states[-1] # [B, L, D]\n", + " #hidden_states = 
F.normalize(hidden_states, p=2, dim=-1)\n", + " print(hidden_states.shape)\n", + " return hidden_states\n", + "\n", + "from diffusers import FlowMatchEulerDiscreteScheduler\n", + "\n", + "@torch.no_grad()\n", + "def generate_images(\n", + " prompts,\n", + " neg_prompts=None,\n", + " width=384,\n", + " height=384,\n", + " num_inference_steps=40,\n", + " guidance_scale=1.0,\n", + " generator=None,\n", + "):\n", + " \"\"\"\n", + " Генерация изображений с использованием FlowMatchEulerDiscreteScheduler из diffusers.\n", + " Всё вычисление происходит в half-precision.\n", + " \"\"\"\n", + " vae_scale_factor = 8\n", + " latents_shape = (\n", + " len(prompts),\n", + " unet.config.in_channels,\n", + " height // vae_scale_factor,\n", + " width // vae_scale_factor,\n", + " )\n", + "\n", + " # Инициализация scheduler\n", + "\n", + " scheduler.set_timesteps(num_inference_steps, device=device)\n", + "\n", + " # Инициализация латентов\n", + " latents = torch.randn(\n", + " latents_shape,\n", + " device=device,\n", + " dtype=torch.float16,\n", + " generator=generator,\n", + " )\n", + "\n", + " # Эмбеддинги промптов\n", + " text_emb = encode_texts(prompts).to(dtype=torch.float16, device=device)\n", + " if neg_prompts is None:\n", + " neg_prompts = [\"\"] * len(prompts)\n", + " uncond_emb = encode_texts(neg_prompts).to(dtype=torch.float16, device=device)\n", + "\n", + " # Подготовка латентов и эмбеддингов для guidance\n", + " if guidance_scale != 1.0:\n", + " latent_input = torch.cat([latents, latents])\n", + " text_emb_input = torch.cat([uncond_emb, text_emb])\n", + " else:\n", + " latent_input = latents\n", + " text_emb_input = text_emb\n", + "\n", + " # Интеграция по таймстепам с использованием scheduler\n", + " for t in scheduler.timesteps:\n", + " #print(t_batch)\n", + " flow = unet(\n", + " latent_input.half(),\n", + " t,\n", + " encoder_hidden_states=text_emb_input.half(),\n", + " ).sample.half()\n", + "\n", + " if guidance_scale != 1.0:\n", + " flow_uncond, flow_cond = 
flow.chunk(2)\n", + " flow = flow_uncond + guidance_scale * (flow_cond - flow_uncond)\n", + "\n", + " # Обновление латентов\n", + " latents = scheduler.step(flow, t, latents).prev_sample\n", + "\n", + " # Обновление latent_input для следующего шага\n", + " if guidance_scale != 1.0:\n", + " latent_input = torch.cat([latents, latents])\n", + "\n", + " # Декодирование латентов в изображения\n", + " latents_for_vae = latents / vae.config.scaling_factor\n", + " images = vae.decode(latents_for_vae.half()).sample.half()\n", + "\n", + " # Конвертация в PIL\n", + " images = (images.float() / 2 + 0.5).clamp(0, 1)\n", + " images = images.cpu().permute(0, 2, 3, 1).numpy()\n", + " pil_images = []\n", + " for img in images:\n", + " img = (img * 255).round().astype(\"uint8\")\n", + " pil_images.append(Image.fromarray(img))\n", + "\n", + " return pil_images\n", + "\n", + "def display_grid(images, cols=2):\n", + " rows = math.ceil(len(images) / cols)\n", + " w, h = images[0].size\n", + " grid = Image.new(\"RGB\", (cols * w, rows * h))\n", + " for i, img in enumerate(images):\n", + " grid.paste(img, (i % cols * w, i // cols * h))\n", + " return grid\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import math\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import math\n", + "\n", + "def display_image_grid(images, prompts, cols=4, save_path=None):\n", + " \"\"\"Отображение грида изображений с подписями и возможностью сохранения\"\"\"\n", + " n = len(images)\n", + " rows = math.ceil(n / cols)\n", + "\n", + " # Создаем фигуру\n", + " fig = plt.figure(figsize=(cols * 4, rows * 4))\n", + "\n", + " # Отображаем изображения\n", + " for i, (img, prompt) in enumerate(zip(images, prompts)):\n", + " ax = fig.add_subplot(rows, cols, i + 1)\n", + " ax.imshow(img)\n", + "\n", + " # Обрезаем длинные промпты\n", + " truncated_prompt = prompt[:60] + \" \" if len(prompt) > 60 else prompt\n", + "\n", + " # Разделяем промпт на две строки\n", + " lines = truncated_prompt.split(' ')\n", + " 
half = len(lines) // 2\n", + " line1 = ' '.join(lines[:half])\n", + " line2 = ' '.join(lines[half:])\n", + "\n", + " # Устанавливаем заголовок с двумя строками и увеличенным размером шрифта\n", + " ax.set_title(f'{line1}\\n{line2}', fontsize=12, wrap=True)\n", + " ax.axis('off')\n", + "\n", + " # Улучшаем расположение\n", + " plt.tight_layout(pad=1.5)\n", + "\n", + " # Сохраняем если нужно\n", + " if save_path:\n", + " plt.savefig(save_path, bbox_inches='tight', format='jpeg', dpi=500)\n", + "\n", + " plt.show()\n", + "\n", + "# ====== Пример использования ======\n", + "prompts = [\n", + " \"In the center of the image, a young adult with long silver hair and striking blue eyes is the focal point. The individual is squatting down, exuding a cool vibe, and is actively engaged in a performance. They are holding a microphone in their right hand, suggesting they are either singing or rapping. The person is dressed in a casual yet stylish outfit, featuring a black jacket, black baseball cap, and black thigh-high boots. The background is a dark, indoor setting with a purple hue, and there are beams of light shining down, adding a dramatic effect to the scene. The person's expression is one of concentration, indicating their focus on the performance.\",\n", + " \"A young woman with long, straight black hair and a bright smile stands on a balcony, wearing a white shirt and a blue sailor collar, with a picturesque view of a mountain range and a clear blue sky in the background.\",\n", + " 'аниме девушка, waifu, يبتسم جنسيا , sur le fond de la tour Eiffel'\n", + " ,\"A hauntingly beautiful ethereal figure with incandescent skin stands in a twilight forest, her form partially decaying yet radiating an inner light. Her silhouette blends elements of Art Nouveau elegance and cybernetic futurism, evoking the styles of Sorayama and Beksinski. 
Warm hues of sunset glow through the misty canopy, casting a golden light on the intricate details of her decaying flesh and metallic enhancements. In the background, surreal, ghostly trees with vibrant autumn leaves create a melancholic atmosphere, reminiscent of the works by Rockwell and Parrish.\"\n", + " ,\"1girl, small breast, red dress, sea\"\n", + " ,\"девушка в красном платье на фоне моря\"\n", + " \"1girl, solo, animal ears, bow, teeth, jacket, tail, open mouth, brown hair, orange background, bowtie, orange nails, simple background, cat ears, orange eyes, blue bow, animal ear fluff, cat tail, looking at viewer, upper body, shirt, school uniform, hood, striped bow, striped, white shirt, black jacket, blue bowtie, fingernails, long sleeves, cat girl, bangs, fangs, collared shirt, striped bowtie, short hair, tongue, hoodie, sharp teeth, facial mark, claw pose\"\n", + " ,\"нарядная новогодняя елка, красивые игрушки, звезда сверху, огоньки, на тёмном фоне\"\n", + " ,\"In the center of a dark, smoky background, a figure clad in a vibrant red bodysuit stands out. The suit is adorned with intricate designs and armor plating, giving it a formidable appearance. The helmet, matching the suit, features a visor with glowing red eyes, adding to the mysterious aura of the character.\"\n", + " ,\"A young woman with striking features. her hair, a mix of black and gold, floats around her head, adding a sense of movement to the scene. her eyes, a vibrant yellow, are half-closed, giving her a contemplative expression. she is dressed in a blue sweater, which contrasts with the black background. 
the overall composition of the image is simple yet striking, with the womans profile taking center stage\"\n", + " ,'A fluffy domestic cat with piercing green eyes sits attentively in a sunlit room filled natural light streaming through large windows, its soft fur reflecting warm hues of orange from the golden glow casting across its sleek body and delicate features'\n", + " ,\"A black-and-white photo of a fierce woman with a high ponytail wearing a spiked iron mask. The mask's sharp, metallic spikes add a menacing aura, contrasting with her soft yet intense expression. Her skin is marked with dirt, showcasing the aftermath of her battles, but her eyes reveal determination and strength. The spiked mask captures her mysterious and intimidating nature, adding to the intensity of the moment.\"\n", + " ,\"A close-up image of an astronaut's helmet with a frosted and opaque visor. The visor reflects the cold, frozen texture of space. Resting on the surface of the visor is a butterfly with vibrant, intricately patterned wings. The contrast between the delicate natural beauty of the butterfly and the cold, industrial helmet creates a striking image. The butterfly adds a touch of fragility and life to the otherwise harsh and unfeeling setting. The faint glow of distant stars can be seen through the frost, further enhancing the surreal atmosphere.\"\n", + " ,\"A watercolor painting of a knight standing tall in a field of wildflowers. The knight's armor is a mixture of soft greys and silvers, and he holds a large red rose. The background is a hazy wash of pale blue skies and distant mountains. The knight's strong yet serene posture creates an intriguing contrast between the image of a warrior and the delicate beauty of the flowers.\"\n", + " ,\"A photo of a rustic lantern glowing warmly in the middle of a snow-covered forest trail. The trail is lined with tall, snow-covered trees with faint branches. Soft falling snowflakes are visible, creating a serene atmosphere. 
The background is dark and mysterious. The overall image has a sense of solitude and magic.\"\n", + " ,\"Ein junges Mädchen mit langen braunen Haaren und braunen Augen steht an einer Backsteinwand, trägt ein weißes Hemd mit einem schwarzen Matrosenkragen und einen schwarzen Faltenrock. Sie lächelt und schaut direkt zum Betrachter, während Sonnenlicht durch die grünen Ranken hinter ihr fällt.\"\n", + " ,\"A 3D render of a hyper-realistic digital illustration of a surreal, moonlit scene. A silhouette of a mannequin man is climbing a delicate ladder suspended in an infinite gray twilight. The ladder is attached to a cloud. The background is a vast, starry sky with a crescent moon. The malemannequin is reaching for a glowing star.\"\n", + " ,\"ариец в имперских доспехах будущего \"\n", + " ,\"Эльфиечка Дня желает всем доброго утра и хорошего настроения!\"\n", + " ,\"a young anime girl with long, dark hair styled in two ponytails. She is wearing a white collared shirt and a green bow tie. In the background, there are two additional characters, one of whom is holding a hairbrush and the other is looking at the girl. The girl's expression appears to be one of sadness or disappointment. Pentagon Pixiv\"\n", + " ,\"девушка в толстовке , без капюшона , длинные черные штаны , карие глаза, светлые волосы, anime, bloody theme \"\n", + " ,\"ржавый краб, логотип, очень ржавый, очень краб\"\n", + " ,\"an anime-style illustration of a young girl with blue eyes and white hair, wearing a green dress and a black bow in her hair. She is sitting on a wooden bench, with a full moon in the background and a tree in the distance. The girl is holding a small white object in her hand, and there are two more white objects in the background. Konpaku Youmu Myon Sazanami Mio.jpg, Touhou\"\n", + " ,\"A mesmerizing, hyper-realistic tattoo design on a young woman's back, depicting an alien creature seemingly breaking out of her skin. 
The tattoo is highly detailed, with the alien's body, limbs, and facial features rendered in intricate detail. The woman's skin appears to have a translucent quality, revealing the intricate anatomy of the alien creature beneath. The overall mood of the image is eerie yet captivating, with the alien's sharp, glowing eyes piercing through the darkness.\"\n", + "]\n", + "\n", + "prompts = [\"cat\"]#,\"dog\"]\n", + "neg_prompts = [\"bad quality, low quality, low resolution\"] * len(prompts)\n", + "#neg_prompts = [\"\"] * len(prompts)\n", + "\n", + "# Исправленный вызов функции с правильными параметрами\n", + "images = generate_images(\n", + " prompts, \n", + " neg_prompts=neg_prompts, \n", + " guidance_scale=4.0, \n", + " num_inference_steps=50, \n", + " height=640, \n", + " width=576,\n", + " generator = torch.Generator(device=device).manual_seed(431)\n", + ")\n", + "\n", + "grid = display_image_grid(images,prompts, cols=3, save_path=\"result_grid2.png\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "f4614e4b-2abd-453a-9703-65697c1d56bb", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ok\n" + ] + } + ], + "source": [ + "import torch\n", + "from diffusers import AutoencoderKL, UNet2DConditionModel, FlowMatchEulerDiscreteScheduler\n", + "from transformers import AutoTokenizer, AutoModel\n", + "from PIL import Image\n", + "import math\n", + "\n", + "# ====== Настройки ======\n", + "dtype = torch.float16\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "\n", + "unet = UNet2DConditionModel.from_pretrained( \"/workspace/sdxs3d\"#\"AiArtLab/sdxs3d\"\n", + " , subfolder=\"unet\", torch_dtype=dtype\n", + ").to(device).eval()\n", + "\n", + "unet.save_pretrained(\"/workspace/sdxs3d/unet\")\n", + "print(\"ok\")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "b08fbf66-8bd1-4a20-8715-0e748a07a932", + "metadata": {}, + "outputs": [ + { + "ename": 
"ModuleNotFoundError", + "evalue": "No module named 'gradio'", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mModuleNotFoundError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[6]\u001b[39m\u001b[32m, line 1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mgradio\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mgr\u001b[39;00m\n\u001b[32m 2\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mnumpy\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mnp\u001b[39;00m\n\u001b[32m 3\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mrandom\u001b[39;00m\n", + "\u001b[31mModuleNotFoundError\u001b[39m: No module named 'gradio'" + ] + } + ], + "source": [ + "import gradio as gr\n", + "import numpy as np\n", + "import random\n", + "\n", + "import spaces #[uncomment to use ZeroGPU]\n", + "import torch\n", + "\n", + "from diffusers import DiffusionPipeline, AutoencoderKL, UNet2DConditionModel, FlowMatchEulerDiscreteScheduler\n", + "from transformers import AutoTokenizer, AutoModel\n", + "\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "model_repo_id = \"AiArtLab/sdxs3d\" # Replace to the model you would like to use\n", + "\n", + "if torch.cuda.is_available():\n", + " dtype = torch.float16\n", + "else:\n", + " dtype = torch.float32\n", + "\n", + "\n", + "class SimpleDiffusionPipeline(DiffusionPipeline):\n", + " def __init__(self, vae, text_encoder, tokenizer, unet, scheduler):\n", + " super().__init__()\n", + " self.register_modules(\n", + " vae=vae,\n", + " text_encoder=text_encoder,\n", + " 
tokenizer=tokenizer,\n", + "            unet=unet,\n", + "            scheduler=scheduler,\n", + "        )\n", + "\n", + "    @torch.no_grad()\n", + "    def __call__(\n", + "        self,\n", + "        prompt,\n", + "        negative_prompt=None,\n", + "        height=512,\n", + "        width=512,\n", + "        num_inference_steps=50,\n", + "        guidance_scale=4.0,\n", + "        generator=None,\n", + "        **kwargs,\n", + "    ):\n", + "        batch_size = len(prompt) if isinstance(prompt, list) else 1\n", + "\n", + "        # 1. Токенизация\n", + "        toks = self.tokenizer(\n", + "            prompt,\n", + "            padding=\"max_length\",\n", + "            truncation=True,\n", + "            max_length=512,\n", + "            return_tensors=\"pt\"\n", + "        ).to(self.device)\n", + "\n", + "        outs = self.text_encoder(**toks)\n", + "        text_emb = outs.last_hidden_state[:, -1].unsqueeze(1) # твой last_token_pool\n", + "\n", + "        if negative_prompt is not None:\n", + "            neg_toks = self.tokenizer(\n", + "                negative_prompt,\n", + "                padding=\"max_length\",\n", + "                truncation=True,\n", + "                max_length=512,\n", + "                return_tensors=\"pt\"\n", + "            ).to(self.device)\n", + "            neg_outs = self.text_encoder(**neg_toks)\n", + "            neg_emb = neg_outs.last_hidden_state[:, -1].unsqueeze(1)\n", + "        else:\n", + "            neg_emb = torch.zeros_like(text_emb)\n", + "\n", + "        # guidance\n", + "        if guidance_scale != 1.0:\n", + "            text_emb = torch.cat([neg_emb, text_emb])\n", + "\n", + "        # 2. Латенты\n", + "        latents = torch.randn(\n", + "            (batch_size, self.unet.config.in_channels, height // 8, width // 8),  # VAE downsamples spatially by 8; config.scaling_factor is the latent magnitude scale, not the spatial factor\n", + "            device=self.device,\n", + "            dtype=torch.float16,\n", + "            generator=generator,\n", + "        )\n", + "\n", + "        self.scheduler.set_timesteps(num_inference_steps, device=self.device)\n", + "\n", + "        # 3. 
Диффузия\n", + " for t in self.scheduler.timesteps:\n", + " latent_input = torch.cat([latents, latents]) if guidance_scale != 1.0 else latents\n", + " flow = self.unet(latent_input, t, encoder_hidden_states=text_emb).sample\n", + "\n", + " if guidance_scale != 1.0:\n", + " flow_uncond, flow_cond = flow.chunk(2)\n", + " flow = flow_uncond + guidance_scale * (flow_cond - flow_uncond)\n", + "\n", + " latents = self.scheduler.step(flow, t, latents).prev_sample\n", + "\n", + " # 4. Декод\n", + " latents = latents / self.vae.config.scaling_factor\n", + " images = self.vae.decode(latents).sample\n", + " images = (images / 2 + 0.5).clamp(0, 1)\n", + "\n", + " return images\n", + "\n", + "\n", + "vae = AutoencoderKL.from_pretrained(model_repo_id, subfolder=\"vae\", torch_dtype=dtype).to(device)\n", + "unet = UNet2DConditionModel.from_pretrained(model_repo_id, subfolder=\"unet\", torch_dtype=dtype).to(device)\n", + "tokenizer = AutoTokenizer.from_pretrained(model_repo_id, subfolder=\"tokenizer\")\n", + "text_encoder = AutoModel.from_pretrained(model_repo_id, subfolder=\"text_encoder\", torch_dtype=dtype).to(device)\n", + "scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(model_repo_id, subfolder=\"scheduler\")\n", + "\n", + "pipe = SimpleDiffusionPipeline(\n", + " vae=vae,\n", + " text_encoder=text_encoder,\n", + " tokenizer=tokenizer,\n", + " unet=unet,\n", + " scheduler=scheduler,\n", + ").to(device)\n", + "\n", + "\n", + "MAX_SEED = np.iinfo(np.int32).max\n", + "MAX_IMAGE_SIZE = 384\n", + "\n", + "\n", + "@spaces.GPU #[uncomment to use ZeroGPU]\n", + "def infer(\n", + " prompt,\n", + " negative_prompt,\n", + " seed,\n", + " randomize_seed,\n", + " width,\n", + " height,\n", + " guidance_scale,\n", + " num_inference_steps,\n", + " progress=gr.Progress(track_tqdm=True),\n", + "):\n", + " if randomize_seed:\n", + " seed = random.randint(0, MAX_SEED)\n", + "\n", + " generator = torch.Generator(device=device).manual_seed(seed) # ← используйте seed, а не 42!\n", + 
"\n", + " # Генерация\n", + " images_tensor = pipe(\n", + " prompt=prompt,\n", + " negative_prompt=negative_prompt,\n", + " guidance_scale=guidance_scale,\n", + " num_inference_steps=num_inference_steps,\n", + " width=width,\n", + " height=height,\n", + " generator=generator,\n", + " ) # [B, C, H, W]\n", + "\n", + " # Конвертация в numpy для Gradio\n", + " image = images_tensor[0].cpu().permute(1, 2, 0).numpy()\n", + " image = (image * 255).astype(np.uint8)\n", + "\n", + " return image, seed\n", + "\n", + "\n", + "examples = [\n", + " \"A delicious ceviche cheesecake slice\",\n", + " \"ариец в имперских доспехах будущего\",\n", + " \"A close-up image of an astronaut's helmet with a frosted and opaque visor. The visor reflects the cold, frozen texture of space. Resting on the surface of the visor is a butterfly with vibrant, intricately patterned wings. The contrast between the delicate natural beauty of the butterfly and the cold, industrial helmet creates a striking image. The butterfly adds a touch of fragility and life to the otherwise harsh and unfeeling setting. 
The faint glow of distant stars can be seen through the frost, further enhancing the surreal atmosphere.\", \n", + "]\n", + "\n", + "css = \"\"\"\n", + "#col-container {\n", + " margin: 0 auto;\n", + " max-width: 640px;\n", + "}\n", + "\"\"\"\n", + "\n", + "with gr.Blocks(css=css) as demo:\n", + " with gr.Column(elem_id=\"col-container\"):\n", + " gr.Markdown(\" # Text-to-Image Gradio Template\")\n", + "\n", + " with gr.Row():\n", + " prompt = gr.Text(\n", + " label=\"Prompt\",\n", + " show_label=False,\n", + " max_lines=1,\n", + " placeholder=\"Enter your prompt\",\n", + " container=False,\n", + " )\n", + "\n", + " run_button = gr.Button(\"Run\", scale=0, variant=\"primary\")\n", + "\n", + " result = gr.Image(label=\"Result\", show_label=False)\n", + "\n", + " with gr.Accordion(\"Advanced Settings\", open=False):\n", + " negative_prompt = gr.Text(\n", + " label=\"Negative prompt\",\n", + " max_lines=1,\n", + " placeholder=\"Enter a negative prompt\",\n", + " visible=True,\n", + " value =\"low quality\"\n", + " )\n", + "\n", + " seed = gr.Slider(\n", + " label=\"Seed\",\n", + " minimum=0,\n", + " maximum=MAX_SEED,\n", + " step=1,\n", + " value=0,\n", + " )\n", + "\n", + " randomize_seed = gr.Checkbox(label=\"Randomize seed\", value=True)\n", + "\n", + " with gr.Row():\n", + " width = gr.Slider(\n", + " label=\"Width\",\n", + " minimum=192,\n", + " maximum=MAX_IMAGE_SIZE,\n", + " step=64,\n", + " value=256, # Replace with defaults that work for your model\n", + " )\n", + "\n", + " height = gr.Slider(\n", + " label=\"Height\",\n", + " minimum=192,\n", + " maximum=MAX_IMAGE_SIZE,\n", + " step=64,\n", + " value=384, # Replace with defaults that work for your model\n", + " )\n", + "\n", + " with gr.Row():\n", + " guidance_scale = gr.Slider(\n", + " label=\"Guidance scale\",\n", + " minimum=0.0,\n", + " maximum=10.0,\n", + " step=0.1,\n", + " value=4.0, # Replace with defaults that work for your model\n", + " )\n", + "\n", + " num_inference_steps = gr.Slider(\n", + " 
label=\"Number of inference steps\",\n", + " minimum=1,\n", + " maximum=50,\n", + " step=1,\n", + " value=40, # Replace with defaults that work for your model\n", + " )\n", + "\n", + " gr.Examples(examples=examples, inputs=[prompt])\n", + " gr.on(\n", + " triggers=[run_button.click, prompt.submit],\n", + " fn=infer,\n", + " inputs=[\n", + " prompt,\n", + " negative_prompt,\n", + " seed,\n", + " randomize_seed,\n", + " width,\n", + " height,\n", + " guidance_scale,\n", + " num_inference_steps,\n", + " ],\n", + " outputs=[result, seed],\n", + " )\n", + "\n", + "if __name__ == \"__main__\":\n", + " demo.launch()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eb9d7241-c5f0-43aa-9e63-54bab9beeeb7", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}