Instructions to use roshikhan301/NEWONE1 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use roshikhan301/NEWONE1 with Diffusers:
pip install -U diffusers transformers accelerate
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("roshikhan301/NEWONE1", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
- Notebooks
- Google Colab
- Kaggle
| from __future__ import annotations | |
| import math | |
| from typing import Any, Callable, Optional, Union | |
| import torch | |
| from typing_extensions import TypeAlias | |
| from invokeai.app.services.config.config_default import get_config | |
| from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( | |
| IPAdapterData, | |
| Range, | |
| TextConditioningData, | |
| TextConditioningRegions, | |
| ) | |
| from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import RegionalIPData | |
| from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData | |
| ModelForwardCallback: TypeAlias = Union[ | |
| # x, t, conditioning, Optional[cross-attention kwargs] | |
| Callable[ | |
| [torch.Tensor, torch.Tensor, torch.Tensor, Optional[dict[str, Any]]], | |
| torch.Tensor, | |
| ], | |
| Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor], | |
| ] | |
| class InvokeAIDiffuserComponent: | |
| """ | |
| The aim of this component is to provide a single place for code that can be applied identically to | |
| all InvokeAI diffusion procedures. | |
| At the moment it includes the following features: | |
| * Cross attention control ("prompt2prompt") | |
| * Hybrid conditioning (used for inpainting) | |
| """ | |
| debug_thresholding = False | |
| sequential_guidance = False | |
| def __init__( | |
| self, | |
| model, | |
| model_forward_callback: ModelForwardCallback, | |
| ): | |
| """ | |
| :param model: the unet model to pass through to cross attention control | |
| :param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning) | |
| """ | |
| config = get_config() | |
| self.conditioning = None | |
| self.model = model | |
| self.model_forward_callback = model_forward_callback | |
| self.sequential_guidance = config.sequential_guidance | |
| def do_controlnet_step( | |
| self, | |
| control_data, | |
| sample: torch.Tensor, | |
| timestep: torch.Tensor, | |
| step_index: int, | |
| total_step_count: int, | |
| conditioning_data: TextConditioningData, | |
| ): | |
| down_block_res_samples, mid_block_res_sample = None, None | |
| # control_data should be type List[ControlNetData] | |
| # this loop covers both ControlNet (one ControlNetData in list) | |
| # and MultiControlNet (multiple ControlNetData in list) | |
| for _i, control_datum in enumerate(control_data): | |
| control_mode = control_datum.control_mode | |
| # soft_injection and cfg_injection are the two ControlNet control_mode booleans | |
| # that are combined at higher level to make control_mode enum | |
| # soft_injection determines whether to do per-layer re-weighting adjustment (if True) | |
| # or default weighting (if False) | |
| soft_injection = control_mode == "more_prompt" or control_mode == "more_control" | |
| # cfg_injection = determines whether to apply ControlNet to only the conditional (if True) | |
| # or the default both conditional and unconditional (if False) | |
| cfg_injection = control_mode == "more_control" or control_mode == "unbalanced" | |
| first_control_step = math.floor(control_datum.begin_step_percent * total_step_count) | |
| last_control_step = math.ceil(control_datum.end_step_percent * total_step_count) | |
| # only apply controlnet if current step is within the controlnet's begin/end step range | |
| if step_index >= first_control_step and step_index <= last_control_step: | |
| if cfg_injection: | |
| sample_model_input = sample | |
| else: | |
| # expand the latents input to control model if doing classifier free guidance | |
| # (which I think for now is always true, there is conditional elsewhere that stops execution if | |
| # classifier_free_guidance is <= 1.0 ?) | |
| sample_model_input = torch.cat([sample] * 2) | |
| added_cond_kwargs = None | |
| if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned | |
| if conditioning_data.is_sdxl(): | |
| added_cond_kwargs = { | |
| "text_embeds": conditioning_data.cond_text.pooled_embeds, | |
| "time_ids": conditioning_data.cond_text.add_time_ids, | |
| } | |
| encoder_hidden_states = conditioning_data.cond_text.embeds | |
| encoder_attention_mask = None | |
| else: | |
| if conditioning_data.is_sdxl(): | |
| added_cond_kwargs = { | |
| "text_embeds": torch.cat( | |
| [ | |
| # TODO: how to pad? just by zeros? or even truncate? | |
| conditioning_data.uncond_text.pooled_embeds, | |
| conditioning_data.cond_text.pooled_embeds, | |
| ], | |
| dim=0, | |
| ), | |
| "time_ids": torch.cat( | |
| [ | |
| conditioning_data.uncond_text.add_time_ids, | |
| conditioning_data.cond_text.add_time_ids, | |
| ], | |
| dim=0, | |
| ), | |
| } | |
| ( | |
| encoder_hidden_states, | |
| encoder_attention_mask, | |
| ) = self._concat_conditionings_for_batch( | |
| conditioning_data.uncond_text.embeds, | |
| conditioning_data.cond_text.embeds, | |
| ) | |
| if isinstance(control_datum.weight, list): | |
| # if controlnet has multiple weights, use the weight for the current step | |
| controlnet_weight = control_datum.weight[step_index] | |
| else: | |
| # if controlnet has a single weight, use it for all steps | |
| controlnet_weight = control_datum.weight | |
| # controlnet(s) inference | |
| down_samples, mid_sample = control_datum.model( | |
| sample=sample_model_input, | |
| timestep=timestep, | |
| encoder_hidden_states=encoder_hidden_states, | |
| controlnet_cond=control_datum.image_tensor, | |
| conditioning_scale=controlnet_weight, # controlnet specific, NOT the guidance scale | |
| encoder_attention_mask=encoder_attention_mask, | |
| added_cond_kwargs=added_cond_kwargs, | |
| guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel | |
| return_dict=False, | |
| ) | |
| if cfg_injection: | |
| # Inferred ControlNet only for the conditional batch. | |
| # To apply the output of ControlNet to both the unconditional and conditional batches, | |
| # prepend zeros for unconditional batch | |
| down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples] | |
| mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample]) | |
| if down_block_res_samples is None and mid_block_res_sample is None: | |
| down_block_res_samples, mid_block_res_sample = down_samples, mid_sample | |
| else: | |
| # add controlnet outputs together if have multiple controlnets | |
| down_block_res_samples = [ | |
| samples_prev + samples_curr | |
| for samples_prev, samples_curr in zip(down_block_res_samples, down_samples, strict=True) | |
| ] | |
| mid_block_res_sample += mid_sample | |
| return down_block_res_samples, mid_block_res_sample | |
| def do_unet_step( | |
| self, | |
| sample: torch.Tensor, | |
| timestep: torch.Tensor, | |
| conditioning_data: TextConditioningData, | |
| ip_adapter_data: Optional[list[IPAdapterData]], | |
| step_index: int, | |
| total_step_count: int, | |
| down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet | |
| mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet | |
| down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter | |
| ): | |
| if self.sequential_guidance: | |
| ( | |
| unconditioned_next_x, | |
| conditioned_next_x, | |
| ) = self._apply_standard_conditioning_sequentially( | |
| x=sample, | |
| sigma=timestep, | |
| conditioning_data=conditioning_data, | |
| ip_adapter_data=ip_adapter_data, | |
| step_index=step_index, | |
| total_step_count=total_step_count, | |
| down_block_additional_residuals=down_block_additional_residuals, | |
| mid_block_additional_residual=mid_block_additional_residual, | |
| down_intrablock_additional_residuals=down_intrablock_additional_residuals, | |
| ) | |
| else: | |
| ( | |
| unconditioned_next_x, | |
| conditioned_next_x, | |
| ) = self._apply_standard_conditioning( | |
| x=sample, | |
| sigma=timestep, | |
| conditioning_data=conditioning_data, | |
| ip_adapter_data=ip_adapter_data, | |
| step_index=step_index, | |
| total_step_count=total_step_count, | |
| down_block_additional_residuals=down_block_additional_residuals, | |
| mid_block_additional_residual=mid_block_additional_residual, | |
| down_intrablock_additional_residuals=down_intrablock_additional_residuals, | |
| ) | |
| return unconditioned_next_x, conditioned_next_x | |
| def _concat_conditionings_for_batch(self, unconditioning, conditioning): | |
| def _pad_conditioning(cond, target_len, encoder_attention_mask): | |
| conditioning_attention_mask = torch.ones( | |
| (cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype | |
| ) | |
| if cond.shape[1] < max_len: | |
| conditioning_attention_mask = torch.cat( | |
| [ | |
| conditioning_attention_mask, | |
| torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype), | |
| ], | |
| dim=1, | |
| ) | |
| cond = torch.cat( | |
| [ | |
| cond, | |
| torch.zeros( | |
| (cond.shape[0], max_len - cond.shape[1], cond.shape[2]), | |
| device=cond.device, | |
| dtype=cond.dtype, | |
| ), | |
| ], | |
| dim=1, | |
| ) | |
| if encoder_attention_mask is None: | |
| encoder_attention_mask = conditioning_attention_mask | |
| else: | |
| encoder_attention_mask = torch.cat( | |
| [ | |
| encoder_attention_mask, | |
| conditioning_attention_mask, | |
| ] | |
| ) | |
| return cond, encoder_attention_mask | |
| encoder_attention_mask = None | |
| if unconditioning.shape[1] != conditioning.shape[1]: | |
| max_len = max(unconditioning.shape[1], conditioning.shape[1]) | |
| unconditioning, encoder_attention_mask = _pad_conditioning(unconditioning, max_len, encoder_attention_mask) | |
| conditioning, encoder_attention_mask = _pad_conditioning(conditioning, max_len, encoder_attention_mask) | |
| return torch.cat([unconditioning, conditioning]), encoder_attention_mask | |
| # methods below are called from do_diffusion_step and should be considered private to this class. | |
| def _apply_standard_conditioning( | |
| self, | |
| x: torch.Tensor, | |
| sigma: torch.Tensor, | |
| conditioning_data: TextConditioningData, | |
| ip_adapter_data: Optional[list[IPAdapterData]], | |
| step_index: int, | |
| total_step_count: int, | |
| down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet | |
| mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet | |
| down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter | |
| ) -> tuple[torch.Tensor, torch.Tensor]: | |
| """Runs the conditioned and unconditioned UNet forward passes in a single batch for faster inference speed at | |
| the cost of higher memory usage. | |
| """ | |
| x_twice = torch.cat([x] * 2) | |
| sigma_twice = torch.cat([sigma] * 2) | |
| cross_attention_kwargs = {} | |
| if ip_adapter_data is not None: | |
| ip_adapter_conditioning = [ipa.ip_adapter_conditioning for ipa in ip_adapter_data] | |
| # Note that we 'stack' to produce tensors of shape (batch_size, num_ip_images, seq_len, token_len). | |
| image_prompt_embeds = [ | |
| torch.stack([ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds]) | |
| for ipa_conditioning in ip_adapter_conditioning | |
| ] | |
| scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data] | |
| ip_masks = [ipa.mask for ipa in ip_adapter_data] | |
| regional_ip_data = RegionalIPData( | |
| image_prompt_embeds=image_prompt_embeds, scales=scales, masks=ip_masks, dtype=x.dtype, device=x.device | |
| ) | |
| cross_attention_kwargs["regional_ip_data"] = regional_ip_data | |
| added_cond_kwargs = None | |
| if conditioning_data.is_sdxl(): | |
| added_cond_kwargs = { | |
| "text_embeds": torch.cat( | |
| [ | |
| # TODO: how to pad? just by zeros? or even truncate? | |
| conditioning_data.uncond_text.pooled_embeds, | |
| conditioning_data.cond_text.pooled_embeds, | |
| ], | |
| dim=0, | |
| ), | |
| "time_ids": torch.cat( | |
| [ | |
| conditioning_data.uncond_text.add_time_ids, | |
| conditioning_data.cond_text.add_time_ids, | |
| ], | |
| dim=0, | |
| ), | |
| } | |
| if conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None: | |
| # TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings | |
| # and masks are not changing from step-to-step, so this really only needs to be done once. While this seems | |
| # painfully inefficient, the time spent is typically negligible compared to the forward inference pass of | |
| # the UNet. The main reason that this hasn't been moved up to eliminate redundancy is that it is slightly | |
| # awkward to handle both standard conditioning and sequential conditioning further up the stack. | |
| regions = [] | |
| for c, r in [ | |
| (conditioning_data.uncond_text, conditioning_data.uncond_regions), | |
| (conditioning_data.cond_text, conditioning_data.cond_regions), | |
| ]: | |
| if r is None: | |
| # Create a dummy mask and range for text conditioning that doesn't have region masks. | |
| _, _, h, w = x.shape | |
| r = TextConditioningRegions( | |
| masks=torch.ones((1, 1, h, w), dtype=x.dtype), | |
| ranges=[Range(start=0, end=c.embeds.shape[1])], | |
| ) | |
| regions.append(r) | |
| cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData( | |
| regions=regions, device=x.device, dtype=x.dtype | |
| ) | |
| cross_attention_kwargs["percent_through"] = step_index / total_step_count | |
| both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch( | |
| conditioning_data.uncond_text.embeds, conditioning_data.cond_text.embeds | |
| ) | |
| both_results = self.model_forward_callback( | |
| x_twice, | |
| sigma_twice, | |
| both_conditionings, | |
| cross_attention_kwargs=cross_attention_kwargs, | |
| encoder_attention_mask=encoder_attention_mask, | |
| down_block_additional_residuals=down_block_additional_residuals, | |
| mid_block_additional_residual=mid_block_additional_residual, | |
| down_intrablock_additional_residuals=down_intrablock_additional_residuals, | |
| added_cond_kwargs=added_cond_kwargs, | |
| ) | |
| unconditioned_next_x, conditioned_next_x = both_results.chunk(2) | |
| return unconditioned_next_x, conditioned_next_x | |
| def _apply_standard_conditioning_sequentially( | |
| self, | |
| x: torch.Tensor, | |
| sigma, | |
| conditioning_data: TextConditioningData, | |
| ip_adapter_data: Optional[list[IPAdapterData]], | |
| step_index: int, | |
| total_step_count: int, | |
| down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet | |
| mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet | |
| down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter | |
| ): | |
| """Runs the conditioned and unconditioned UNet forward passes sequentially for lower memory usage at the cost of | |
| slower execution speed. | |
| """ | |
| # Since we are running the conditioned and unconditioned passes sequentially, we need to split the ControlNet | |
| # and T2I-Adapter residuals into two chunks. | |
| uncond_down_block, cond_down_block = None, None | |
| if down_block_additional_residuals is not None: | |
| uncond_down_block, cond_down_block = [], [] | |
| for down_block in down_block_additional_residuals: | |
| _uncond_down, _cond_down = down_block.chunk(2) | |
| uncond_down_block.append(_uncond_down) | |
| cond_down_block.append(_cond_down) | |
| uncond_down_intrablock, cond_down_intrablock = None, None | |
| if down_intrablock_additional_residuals is not None: | |
| uncond_down_intrablock, cond_down_intrablock = [], [] | |
| for down_intrablock in down_intrablock_additional_residuals: | |
| _uncond_down, _cond_down = down_intrablock.chunk(2) | |
| uncond_down_intrablock.append(_uncond_down) | |
| cond_down_intrablock.append(_cond_down) | |
| uncond_mid_block, cond_mid_block = None, None | |
| if mid_block_additional_residual is not None: | |
| uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2) | |
| ##################### | |
| # Unconditioned pass | |
| ##################### | |
| cross_attention_kwargs = {} | |
| # Prepare IP-Adapter cross-attention kwargs for the unconditioned pass. | |
| if ip_adapter_data is not None: | |
| ip_adapter_conditioning = [ipa.ip_adapter_conditioning for ipa in ip_adapter_data] | |
| # Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len). | |
| image_prompt_embeds = [ | |
| torch.unsqueeze(ipa_conditioning.uncond_image_prompt_embeds, dim=0) | |
| for ipa_conditioning in ip_adapter_conditioning | |
| ] | |
| scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data] | |
| ip_masks = [ipa.mask for ipa in ip_adapter_data] | |
| regional_ip_data = RegionalIPData( | |
| image_prompt_embeds=image_prompt_embeds, scales=scales, masks=ip_masks, dtype=x.dtype, device=x.device | |
| ) | |
| cross_attention_kwargs["regional_ip_data"] = regional_ip_data | |
| # Prepare SDXL conditioning kwargs for the unconditioned pass. | |
| added_cond_kwargs = None | |
| if conditioning_data.is_sdxl(): | |
| added_cond_kwargs = { | |
| "text_embeds": conditioning_data.uncond_text.pooled_embeds, | |
| "time_ids": conditioning_data.uncond_text.add_time_ids, | |
| } | |
| # Prepare prompt regions for the unconditioned pass. | |
| if conditioning_data.uncond_regions is not None: | |
| cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData( | |
| regions=[conditioning_data.uncond_regions], device=x.device, dtype=x.dtype | |
| ) | |
| cross_attention_kwargs["percent_through"] = step_index / total_step_count | |
| # Run unconditioned UNet denoising (i.e. negative prompt). | |
| unconditioned_next_x = self.model_forward_callback( | |
| x, | |
| sigma, | |
| conditioning_data.uncond_text.embeds, | |
| cross_attention_kwargs=cross_attention_kwargs, | |
| down_block_additional_residuals=uncond_down_block, | |
| mid_block_additional_residual=uncond_mid_block, | |
| down_intrablock_additional_residuals=uncond_down_intrablock, | |
| added_cond_kwargs=added_cond_kwargs, | |
| ) | |
| ################### | |
| # Conditioned pass | |
| ################### | |
| cross_attention_kwargs = {} | |
| if ip_adapter_data is not None: | |
| ip_adapter_conditioning = [ipa.ip_adapter_conditioning for ipa in ip_adapter_data] | |
| # Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len). | |
| image_prompt_embeds = [ | |
| torch.unsqueeze(ipa_conditioning.cond_image_prompt_embeds, dim=0) | |
| for ipa_conditioning in ip_adapter_conditioning | |
| ] | |
| scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data] | |
| ip_masks = [ipa.mask for ipa in ip_adapter_data] | |
| regional_ip_data = RegionalIPData( | |
| image_prompt_embeds=image_prompt_embeds, scales=scales, masks=ip_masks, dtype=x.dtype, device=x.device | |
| ) | |
| cross_attention_kwargs["regional_ip_data"] = regional_ip_data | |
| # Prepare SDXL conditioning kwargs for the conditioned pass. | |
| added_cond_kwargs = None | |
| if conditioning_data.is_sdxl(): | |
| added_cond_kwargs = { | |
| "text_embeds": conditioning_data.cond_text.pooled_embeds, | |
| "time_ids": conditioning_data.cond_text.add_time_ids, | |
| } | |
| # Prepare prompt regions for the conditioned pass. | |
| if conditioning_data.cond_regions is not None: | |
| cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData( | |
| regions=[conditioning_data.cond_regions], device=x.device, dtype=x.dtype | |
| ) | |
| cross_attention_kwargs["percent_through"] = step_index / total_step_count | |
| # Run conditioned UNet denoising (i.e. positive prompt). | |
| conditioned_next_x = self.model_forward_callback( | |
| x, | |
| sigma, | |
| conditioning_data.cond_text.embeds, | |
| cross_attention_kwargs=cross_attention_kwargs, | |
| down_block_additional_residuals=cond_down_block, | |
| mid_block_additional_residual=cond_mid_block, | |
| down_intrablock_additional_residuals=cond_down_intrablock, | |
| added_cond_kwargs=added_cond_kwargs, | |
| ) | |
| return unconditioned_next_x, conditioned_next_x | |
| def _combine(self, unconditioned_next_x, conditioned_next_x, guidance_scale): | |
| # to scale how much effect conditioning has, calculate the changes it does and then scale that | |
| scaled_delta = (conditioned_next_x - unconditioned_next_x) * guidance_scale | |
| combined_next_x = unconditioned_next_x + scaled_delta | |
| return combined_next_x | |