krystv committed on
Commit
ad55fd7
·
verified ·
1 Parent(s): eba54b1

Upload lrf/pipeline.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. lrf/pipeline.py +331 -0
lrf/pipeline.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LatentRecurrentFlow (LRF) - HuggingFace-Compatible Pipeline
3
+
4
+ Provides:
5
+ - LRFPipeline: Full text-to-image and image-editing pipeline
6
+ - Model save/load compatible with HF Hub
7
+ - Diffusers-style API for easy integration
8
+ """
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import json
13
+ import os
14
+ from typing import Optional, List, Union
15
+ from pathlib import Path
16
+
17
+ from lrf.model import LatentRecurrentFlow
18
+ from lrf.training import RectifiedFlowScheduler
19
+
20
+
21
class LRFPipeline:
    """
    LatentRecurrentFlow Pipeline for inference.

    Wraps a ``LatentRecurrentFlow`` model together with a
    ``RectifiedFlowScheduler`` and an optional tokenizer, and exposes
    load / save helpers plus a callable text-to-image entry point.

    Usage:
        pipe = LRFPipeline.from_pretrained("path/to/model")
        images = pipe("a photo of a cat", num_steps=20)

        # Or for editing:
        images = pipe("make the cat blue", image=source_image, num_steps=20)
    """

    def __init__(
        self,
        model: "LatentRecurrentFlow",
        tokenizer=None,
        device: torch.device = torch.device('cpu'),
    ):
        """
        Args:
            model: The model to run; it is moved to ``device`` and put in eval mode.
            tokenizer: Optional callable tokenizer. When ``None`` a built-in
                word-hash fallback is used (see ``_simple_tokenize``).
            device: Device on which the model and all inputs are placed.
        """
        self.model = model.to(device)
        self.model.eval()
        self.device = device
        self.scheduler = RectifiedFlowScheduler(shift=1.0)
        self.tokenizer = tokenizer

    @classmethod
    def from_pretrained(cls, path: str, device: str = 'cpu'):
        """
        Build a pipeline from a local model directory.

        Reads ``config.json`` if present (otherwise uses
        ``LatentRecurrentFlow.default_config()``), then loads weights from
        ``model.safetensors`` or, failing that, ``model.pt``. If neither
        weights file exists the model keeps its freshly initialised weights.

        Args:
            path: Directory containing the saved model files.
            device: Device string passed to ``torch.device``.

        Returns:
            A new pipeline instance (without a tokenizer).
        """
        path = Path(path)
        device = torch.device(device)

        # Load config
        config_path = path / 'config.json'
        if config_path.exists():
            with open(config_path, encoding='utf-8') as f:
                config = json.load(f)
        else:
            config = LatentRecurrentFlow.default_config()

        # Create model
        model = LatentRecurrentFlow(config)

        # Load weights if available (safetensors takes precedence over .pt)
        weights_path = path / 'model.safetensors'
        pt_path = path / 'model.pt'

        if weights_path.exists():
            from safetensors.torch import load_file
            state_dict = load_file(str(weights_path))
            model.load_state_dict(state_dict)
        elif pt_path.exists():
            state_dict = torch.load(str(pt_path), map_location='cpu', weights_only=True)
            # A .pt file may be a bare state dict or a checkpoint that
            # stores the weights under the 'model_state' key.
            if 'model_state' in state_dict:
                model.load_state_dict(state_dict['model_state'])
            else:
                model.load_state_dict(state_dict)

        return cls(model=model, device=device)

    def save_pretrained(self, path: str):
        """
        Save config, weights and a generated README to a directory.

        Weights are written as ``model.safetensors`` when the ``safetensors``
        package can be imported, otherwise as ``model.pt`` via ``torch.save``.

        Args:
            path: Target directory; created (with parents) if missing.
        """
        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)

        # Save config
        with open(path / 'config.json', 'w', encoding='utf-8') as f:
            json.dump(self.model.config, f, indent=2)

        # Save weights
        try:
            from safetensors.torch import save_file
            save_file(self.model.state_dict(), str(path / 'model.safetensors'))
        except ImportError:
            torch.save(self.model.state_dict(), str(path / 'model.pt'))

        # Save README. Explicit UTF-8: the text contains non-ASCII characters
        # and must not depend on the platform's default encoding.
        readme = self._generate_readme()
        with open(path / 'README.md', 'w', encoding='utf-8') as f:
            f.write(readme)

    def _generate_readme(self):
        """Return the model-card markdown, filled in from the model's parameter counts and config."""
        counts = self.model.count_parameters()
        config = self.model.config
        num_blocks = config.get('num_blocks', 4)
        effective_depth = config.get('T_outer', 2) * config.get('T_inner', 4) * num_blocks
        return f"""---
tags:
- image-generation
- latent-recurrent-flow
- lrf
- mobile-first
- flow-matching
- recursive-reasoning
library_name: lrf
pipeline_tag: text-to-image
---

# LatentRecurrentFlow (LRF)

A novel mobile-first image generation architecture combining:
- **Recursive Latent Refinement (RLR)** core — HRM-inspired iterative reasoning
- **Gated Linear Diffusion (GLD)** blocks — O(N) subquadratic spatial mixing
- **Compact f=16 VAE** with tiny decoder
- **Rectified flow** training objective

## Model Details

| Component | Parameters |
|-----------|-----------|
| VAE Encoder | {counts['vae_encoder']:,} |
| VAE Decoder | {counts['vae_decoder']:,} |
| Text Encoder | {counts['text_encoder']:,} |
| Denoising Core | {counts['core']:,} |
| **Total** | **{counts['total']:,}** |

## Architecture Innovations

1. **Recursive Latent Refinement**: Same parameter blocks applied T_outer × T_inner times,
giving effective depth of {effective_depth} layers
from only {num_blocks} unique parameter sets.

2. **Gated Linear Attention**: O(N) bidirectional scan with token-differential operators
and 2D locality injection — replaces quadratic self-attention.

3. **IFT Training**: O(1) memory backpropagation through arbitrary recursion depth.

## Usage

```python
from lrf.pipeline import LRFPipeline

pipe = LRFPipeline.from_pretrained("path/to/model")
images = pipe("a beautiful sunset over the ocean", num_steps=20)
```
"""

    def _simple_tokenize(self, text: str, max_length: int = 77) -> tuple:
        """
        Tokenize a single prompt into ``(token_ids, attention_mask)``.

        If ``self.tokenizer`` is set it is called with max-length padding and
        truncation and its ``input_ids`` / ``attention_mask`` are returned.
        Otherwise a prototype fallback is used: the text is lower-cased, split
        on whitespace, and each word is mapped to an id in ``[1, 31998]`` via a
        CRC32 checksum; id ``0`` is used for padding.

        Args:
            text: The prompt to tokenize.
            max_length: Sequence length to pad / truncate to.

        Returns:
            Tuple of tensors, each of shape ``[1, max_length]`` in the fallback
            path: long token ids and a float mask (1.0 for real tokens).
        """
        if self.tokenizer is not None:
            tokens = self.tokenizer(text, max_length=max_length, padding='max_length',
                                    truncation=True, return_tensors='pt')
            return tokens['input_ids'], tokens['attention_mask']

        # Fallback: word-level hash tokenization. The builtin hash() is salted
        # per interpreter process for strings, so it would give different ids
        # for the same prompt on every run; CRC32 is stable across runs.
        import zlib

        token_ids = [
            zlib.crc32(word.encode('utf-8')) % 31998 + 1
            for word in text.lower().split()
        ][:max_length]

        # Pad to max_length; the mask marks which positions are real tokens.
        pad = max_length - len(token_ids)
        attention_mask = [1.0] * len(token_ids) + [0.0] * pad
        token_ids = token_ids + [0] * pad

        return (
            torch.tensor([token_ids], dtype=torch.long),
            torch.tensor([attention_mask], dtype=torch.float),
        )

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        image: Optional[torch.Tensor] = None,
        num_steps: int = 20,
        cfg_scale: float = 7.5,
        height: int = 256,
        width: int = 256,
        seed: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Generate images from text prompts.

        Args:
            prompt: Text prompt or list of prompts
            image: Optional source image for editing [B, 3, H, W] in [-1, 1].
                NOTE: the image is encoded, but the resulting latent is not
                currently passed to the sampler (see the note in the body).
            num_steps: Number of sampling steps (4-50, default 20)
            cfg_scale: Classifier-free guidance scale
            height, width: Output image size; the latent grid is
                ``height // 16`` by ``width // 16``
            seed: Random seed for reproducibility (sets the global torch seed)

        Returns:
            images: Tensor [B, 3, H, W] in [-1, 1]
        """
        if seed is not None:
            torch.manual_seed(seed)

        # Handle string input
        if isinstance(prompt, str):
            prompt = [prompt]

        B = len(prompt)

        # Tokenize each prompt separately, then stack along the batch axis
        tokenized = [self._simple_tokenize(p) for p in prompt]
        token_ids = torch.cat([ids for ids, _ in tokenized], dim=0).to(self.device)
        attention_mask = torch.cat([mask for _, mask in tokenized], dim=0).to(self.device)

        # Encode text
        text_emb, text_global = self.model.encode_text(token_ids, attention_mask)

        # Compute latent size (spatial downsampling factor of 16)
        latent_h = height // 16
        latent_w = width // 16
        C = self.model.config['latent_channels']

        # Handle editing: encode source image.
        # NOTE(review): image_cond is computed here but never forwarded to
        # scheduler.sample below, so the source image has no effect on the
        # output. Wiring it through requires a conditioning argument on the
        # sampler, which is not visible from this file.
        image_cond = None
        if image is not None:
            image = image.to(self.device)
            # No inner torch.no_grad() needed: the method is already decorated.
            image_cond, _, _ = self.model.encode_image(image)

        # Sample
        shape = (B, C, latent_h, latent_w)
        z = self.scheduler.sample(
            self.model, shape, text_emb, text_global,
            num_steps=num_steps, cfg_scale=cfg_scale, device=self.device,
        )

        # Decode
        images = self.model.decode_latent(z)

        return images.clamp(-1, 1)

    def to(self, device):
        """Move pipeline to device and return ``self`` for chaining."""
        self.device = torch.device(device)
        self.model = self.model.to(self.device)
        return self
256
+
257
+
258
class LRFTrainingPipeline:
    """
    Complete training pipeline with staged curriculum.

    Stages:
    1. VAE pre-training (or use pre-trained DC-AE)
    2. Flow matching denoiser training
    3. Consistency distillation for few-step
    4. Editing fine-tuning
    """

    # Per-stage settings, keyed by stage name. Insertion order is the order
    # in which the stages are run (see get_curriculum).
    STAGE_CONFIGS = {
        'vae': {
            'description': 'Train VAE for image compression',
            'freeze': [],
            'train': ['vae'],
            'lr': 1e-4,
            'min_steps': 50000,
        },
        'flow_lowres': {
            'description': 'Flow matching at 64x64 (composition learning)',
            'freeze': ['vae'],
            'train': ['core', 'text_encoder'],
            'lr': 1e-4,
            'resolution': 64,
            'min_steps': 100000,
        },
        'flow_midres': {
            'description': 'Flow matching at 256x256 (texture learning)',
            'freeze': ['vae'],
            'train': ['core', 'text_encoder'],
            'lr': 5e-5,
            'resolution': 256,
            'min_steps': 200000,
        },
        'flow_highres': {
            'description': 'Flow matching at 512x512 (detail learning)',
            'freeze': ['vae'],
            'train': ['core', 'text_encoder'],
            'lr': 2e-5,
            'resolution': 512,
            'min_steps': 100000,
        },
        'consistency': {
            'description': 'Consistency distillation for 4-step generation',
            'freeze': ['vae', 'text_encoder'],
            'train': ['core'],
            'lr': 1e-5,
            'min_steps': 50000,
        },
        'editing': {
            'description': 'Fine-tune for editing tasks',
            'freeze': ['vae'],
            'train': ['core', 'text_encoder'],
            'lr': 1e-5,
            'min_steps': 50000,
        },
    }

    @classmethod
    def get_stage_config(cls, stage_name: str) -> dict:
        """
        Return the settings for one training stage.

        Args:
            stage_name: Key into ``STAGE_CONFIGS``.

        Returns:
            A deep copy of the stage's settings, or an empty dict if the
            stage name is unknown. A copy is returned so that callers can
            modify the result (including its nested lists) without altering
            the shared class-level table.
        """
        import copy

        return copy.deepcopy(cls.STAGE_CONFIGS.get(stage_name, {}))

    @classmethod
    def get_curriculum(cls) -> list:
        """Return the full training curriculum as a new list of stage names, in run order."""
        # Derived from STAGE_CONFIGS so the two cannot drift apart; dicts
        # preserve insertion order, which is the intended stage order.
        return list(cls.STAGE_CONFIGS)