Instructions to use BiliSakura/IntrisicWeather-diffusers with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use BiliSakura/IntrisicWeather-diffusers with Diffusers:
pip install -U diffusers transformers accelerate
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("BiliSakura/IntrisicWeather-diffusers", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
- Notebooks
- Google Colab
- Kaggle
| # Copyright 2026 The HuggingFace Team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from __future__ import annotations | |
| from typing import Optional, Tuple | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
def extract_patch_tokens_min_windows(
    images: torch.Tensor,
    model: nn.Module,
    processor,
    window_size: int = 224,
    device: str | torch.device = "cuda",
) -> torch.Tensor:
    r"""
    Tile each image with a minimal window set and return averaged DINO patch tokens.

    Every image is covered by `ceil(H / window_size) x ceil(W / window_size)` square windows. The last window
    along each axis is shifted back so that it ends exactly at the image border, which means it may overlap its
    neighbour; token positions covered by several windows are averaged.

    Args:
        images (`torch.Tensor`): Batch of RGB images `(B, C, H, W)`. Images whose maximum is `<= 1.0` are treated
            as being in `[0, 1]` and rescaled to `[0, 255]`; otherwise values are used as-is. Any floating dtype
            (including `bfloat16`) and tensors that require grad are accepted.
        model: DINO vision transformer. Must expose `config.hidden_size` and `config.patch_size`, and its output
            must have a `last_hidden_state` whose first token is dropped (assumed to be the CLS token) and whose
            remaining tokens form a `(window_size // patch_size)`-sided square grid.
        processor: Hugging Face image processor for DINO, called as `processor(images=..., return_tensors="pt")`.
        window_size (`int`): Sliding-window size in pixels.
        device: Device for intermediate tensors.

    Returns:
        `torch.Tensor` of shape `(B, H//patch, W//patch, hidden_size)`.

    Raises:
        ValueError: If an image side is smaller than `window_size`, so no full window fits.
    """
    _, _, height, width = images.shape
    if height < window_size or width < window_size:
        # A negative window origin would silently wrap around and tile the wrong pixels.
        raise ValueError(
            f"`images` of size {height}x{width} are smaller than `window_size`={window_size}; "
            "resize or pad them before extracting patch tokens."
        )

    hidden_size = model.config.hidden_size
    patch_size = model.config.patch_size
    grid_height, grid_width = height // patch_size, width // patch_size
    tokens_per_side = window_size // patch_size

    # Window origins are identical for every image in the batch, so compute them once.
    num_y = (height + window_size - 1) // window_size
    num_x = (width + window_size - 1) // window_size
    y_positions = [index * window_size for index in range(num_y - 1)] + [height - window_size]
    x_positions = [index * window_size for index in range(num_x - 1)] + [width - window_size]

    token_avgs = []
    for image in images:
        # Detach and upcast first: NumPy has no bfloat16 and refuses tensors that require grad.
        image_hwc = image.detach().permute(1, 2, 0).to(device="cpu", dtype=torch.float32)
        if image_hwc.max() <= 1.0:
            image_hwc = image_hwc * 255
        image_np = image_hwc.numpy().clip(0, 255).astype("uint8")

        token_sum = torch.zeros((grid_height, grid_width, hidden_size), device=device)
        token_count = torch.zeros((grid_height, grid_width, 1), device=device)

        for y in y_positions:
            for x in x_positions:
                patch = image_np[y : y + window_size, x : x + window_size, :]
                inputs = processor(images=patch, return_tensors="pt").to(device)
                with torch.no_grad():
                    outputs = model(**inputs)

                # Drop the leading token and fold the remaining sequence back into a 2D grid.
                patch_tokens = outputs.last_hidden_state[:, 1:, :]
                patch_tokens = patch_tokens.reshape(tokens_per_side, tokens_per_side, hidden_size)

                y0, x0 = y // patch_size, x // patch_size
                y1, x1 = y0 + tokens_per_side, x0 + tokens_per_side
                token_sum[y0:y1, x0:x1, :] += patch_tokens
                token_count[y0:y1, x0:x1, 0] += 1

        token_avgs.append(token_sum / token_count)

    return torch.stack(token_avgs, dim=0)
class LayerNorm2d(nn.Module):
    r"""
    Layer normalization over the channel axis of a channels-first `(B, C, H, W)` feature map.

    The input is permuted to channels-last, normalized with `nn.LayerNorm` over its last dimension, and
    permuted back to channels-first.
    """

    def __init__(self, channels: int) -> None:
        super().__init__()
        self.norm = nn.LayerNorm((channels,))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        channels_last = x.permute(0, 2, 3, 1)
        normalized = self.norm(channels_last)
        return normalized.permute(0, 3, 1, 2)
class IMAA(nn.Module):
    r"""
    Intrinsic Map-Aware Attention (IMAA) gating module.

    Produces per-map gating maps from DINO patch tokens and learnable map embeddings: the patch-token grid and
    the selected map embedding are projected to a common width, concatenated, fused, and reduced to a single
    channel that is squashed with a sigmoid.

    Args:
        dino_model (`nn.Module`, *optional*): DINO backbone used when `forward` receives an `image` instead of
            precomputed `patch_tokens`. It is switched to eval mode and its parameters are frozen.
        processor (*optional*): Image processor passed along to the patch-token extractor.
        num_maps (`int`): Number of learnable map embeddings.
        map_embedding_dim (`int`): Width of each map embedding.
        common_dim (`int`): Channel width both inputs are projected to before fusion.
        conv_channels (`list[int]`, *optional*): Output channels of the 3x3 conv layers in the head. Defaults to
            `[128, 64]` (also used when an empty list is passed).
        dino_patch_dim (`int`): Hidden size of the DINO patch tokens.
    """

    def __init__(
        self,
        dino_model: Optional[nn.Module] = None,
        processor=None,
        num_maps: int = 5,
        map_embedding_dim: int = 256,
        common_dim: int = 128,
        conv_channels: Optional[list[int]] = None,
        dino_patch_dim: int = 768,
    ) -> None:
        super().__init__()
        conv_channels = conv_channels or [128, 64]

        self.dino = dino_model
        self.processor = processor
        if self.dino is not None:
            # The backbone is a fixed feature extractor: no dropout/BN updates, no gradients.
            self.dino.eval()
            for param in self.dino.parameters():
                param.requires_grad = False

        self.num_maps = num_maps
        self.map_embedding_dim = map_embedding_dim
        self.common_dim = common_dim
        self.dino_patch_dim = dino_patch_dim

        self.map_embedding = nn.Parameter(torch.randn(num_maps, map_embedding_dim))
        self.dino_proj = nn.Conv2d(dino_patch_dim, common_dim, kernel_size=1)
        self.map_proj = nn.Linear(map_embedding_dim, common_dim)

        self.fusion_layer = nn.Sequential(
            nn.Conv2d(common_dim * 2, common_dim, 1),
            LayerNorm2d(common_dim),
            nn.ReLU(),
            nn.Conv2d(common_dim, common_dim, 3, padding=1),
        )

        conv_layers: list[nn.Module] = []
        in_channels = common_dim
        for out_channels in conv_channels:
            conv_layers.extend([nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), nn.ReLU()])
            in_channels = out_channels
        conv_layers.append(nn.Conv2d(in_channels, 1, kernel_size=1))
        self.conv_head = nn.Sequential(*conv_layers)

    def forward(
        self,
        image: Optional[torch.Tensor] = None,
        patch_tokens: Optional[torch.Tensor] = None,
        output_size: Optional[Tuple[int, int]] = None,
        map_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""
        Compute a sigmoid gating map for each (image, map id) pair.

        Args:
            image (`torch.Tensor`, *optional*): Image batch `(B, C, H, W)`; only used to extract patch tokens
                with the DINO backbone when `patch_tokens` is not given.
            patch_tokens (`torch.Tensor`, *optional*): Precomputed token grid `(B, h, w, dino_patch_dim)`.
            output_size (`Tuple[int, int]`, *optional*): If set, the gating logits are bilinearly resized to
                this `(height, width)` before the sigmoid.
            map_ids (`torch.Tensor`): Indices into the map-embedding table, one per batch element. A Python
                int, a sequence of ints or a 0-dim tensor is also accepted. A single id is broadcast over the
                batch, and a single image is broadcast over several ids.

        Returns:
            `torch.Tensor` of shape `(B, 1, h, w)` (or `(B, 1, *output_size)`) with values in `(0, 1)`.

        Raises:
            ValueError: If neither `patch_tokens` nor (`image` and a DINO model) is available, or if `map_ids`
                is missing.
        """
        if patch_tokens is None and (self.dino is None or image is None):
            raise ValueError("Either `patch_tokens` or (`image` and a frozen DINO model) must be provided.")
        if map_ids is None:
            # Indexing the embedding table with `None` only adds an axis and fails much later in `expand`.
            raise ValueError("`map_ids` must be provided to select the map embeddings.")

        if patch_tokens is None:
            patch_tokens = extract_patch_tokens_min_windows(
                image, self.dino, self.processor, window_size=224, device=image.device
            )

        # Channels-first for the convs; match the projection dtype because extracted tokens are float32
        # even when this module runs in half precision.
        dino_feat_map = patch_tokens.permute(0, 3, 1, 2).to(dtype=self.dino_proj.weight.dtype)
        dino_proj = self.dino_proj(dino_feat_map)

        if not torch.is_tensor(map_ids):
            map_ids = torch.as_tensor(map_ids, dtype=torch.long, device=self.map_embedding.device)
        map_emb = self.map_embedding[map_ids]
        if map_emb.dim() == 1:
            # A scalar id selects a single embedding row; restore the batch axis.
            map_emb = map_emb.unsqueeze(0)
        map_proj = self.map_proj(map_emb).unsqueeze(-1).unsqueeze(-1)

        # Broadcast a lone image or a lone map id against the other batch size.
        batch_size = max(dino_proj.size(0), map_proj.size(0))
        if dino_proj.size(0) == 1 and batch_size > 1:
            dino_proj = dino_proj.expand(batch_size, -1, -1, -1)
        map_proj = map_proj.expand(
            batch_size if map_proj.size(0) == 1 else -1, -1, dino_proj.size(2), dino_proj.size(3)
        )

        fused_map = self.fusion_layer(torch.cat([dino_proj, map_proj], dim=1))
        raw_gating_map = self.conv_head(fused_map)

        aligned_map = (
            F.interpolate(raw_gating_map, size=output_size, mode="bilinear", align_corners=False)
            if output_size is not None
            else raw_gating_map
        )
        return torch.sigmoid(aligned_map)
def build_attn_mask(
    w_gating: torch.Tensor,
    text_len: int,
    img_len: int,
    lam: float,
) -> torch.Tensor:
    r"""
    Build an additive attention mask from IMAA gating weights.

    The gating weights are flattened per sample, scaled by `lam`, and written into the image-token columns of a
    zero-initialized bias; the `text_len` leading (text) columns stay zero. If the flattened gating has more
    entries than `img_len` the excess is dropped, and if it has fewer the remaining columns stay zero.

    Args:
        w_gating (`torch.Tensor`): Gating map `[B, 1, H, W]` or flattened `[B, img_len]`. Any tensor with more
            than two dimensions is flattened over its non-batch dimensions; it does not need to be contiguous.
        text_len (`int`): Number of text tokens prepended to image tokens.
        img_len (`int`): Expected number of image tokens.
        lam (`float`): Mask scaling factor.

    Returns:
        Attention bias tensor of shape `[B, 1, 1, text_len + img_len]`, shaped for SD3 joint attention, with
        the device and dtype of `w_gating`.
    """
    batch_size = w_gating.shape[0]
    total_len = text_len + img_len

    # `flatten` copies only when needed, so permuted/strided maps work where `.view` would raise.
    gating = lam * w_gating.flatten(start_dim=1)

    col_bias = torch.zeros(batch_size, total_len, device=w_gating.device, dtype=w_gating.dtype)
    # One slice assignment covers exact fit, truncation and zero-padding.
    usable_len = min(gating.shape[1], img_len)
    col_bias[:, text_len : text_len + usable_len] = gating[:, :usable_len]

    return col_bias.view(batch_size, 1, 1, total_len)