Leonardo commited on
Commit
876a77e
·
verified ·
1 Parent(s): 45a3e73

Create flux_lora_tool.py

Browse files
Files changed (1) hide show
  1. flux_lora_tool.py +614 -0
flux_lora_tool.py ADDED
@@ -0,0 +1,614 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Zhou Protocol FLUX-LoRA Integration Tool
3
+
4
+ This module provides a Smolagents Tool implementation for interacting with FLUX-LoRA-DLC API.
5
+ It enables agents to generate high-quality images with customizable LoRA models.
6
+
7
+ Usage:
8
+ flux_tool = FluxLoRATool()
9
+ agent = CodeAgent(tools=[flux_tool], ...)
10
+ """
11
+
12
+ import os
13
+ import uuid
14
+ import tempfile
15
+ import logging
16
+ from typing import Dict, Any, Optional, List, Union, Tuple
17
+ from dataclasses import dataclass
18
+ import contextlib
19
+ from pathlib import Path
20
+
21
+ # Third-party
22
+ import requests
23
+ from PIL import Image
24
+ from gradio_client import Client
25
+
26
+ # Smolagents
27
+ from smolagents import Tool
28
+
29
+ # -----------------------------------------------------------------------------
30
+ # CONSTANTS AND TYPE DEFINITIONS
31
+ # -----------------------------------------------------------------------------
32
+
33
+ @dataclass
34
+ class LoRAModelInfo:
35
+ """Value object representing LoRA model information."""
36
+ name: str
37
+ description: Optional[str] = None
38
+ example_image_url: Optional[str] = None
39
+
40
+
41
+ @dataclass
42
+ class ImageGenerationResult:
43
+ """Value object representing a generated image result."""
44
+ image_path: str
45
+ seed: int
46
+ metadata: Optional[Dict[str, Any]] = None
47
+
48
+
49
+ # -----------------------------------------------------------------------------
50
+ # CORE TOOL IMPLEMENTATION
51
+ # -----------------------------------------------------------------------------
52
+
53
+ class FluxLoRATool(Tool):
54
+ """
55
+ Tool for generating images using FLUX-LoRA-DLC API.
56
+
57
+ This tool implements the Zhou Protocol integration patterns to provide
58
+ a clean, efficient interface for image generation using LoRA models.
59
+ """
60
+
61
+ name = "flux_lora_generator"
62
+ description = """
63
+ Generates high-quality images using FLUX-LoRA models.
64
+ Can use custom LoRA models, adjust image parameters, and handle image inputs.
65
+ """
66
+ inputs = {
67
+ "prompt": {
68
+ "type": "string",
69
+ "description": "Detailed description of the desired image."
70
+ },
71
+ "image_input": {
72
+ "type": "string",
73
+ "description": "Optional URL or file path to input image for img2img generation.",
74
+ "optional": True
75
+ },
76
+ "image_strength": {
77
+ "type": "float",
78
+ "description": "Strength of input image influence (0.0-1.0), where 1.0 maintains more of original image.",
79
+ "optional": True,
80
+ "default": 0.75
81
+ },
82
+ "cfg_scale": {
83
+ "type": "float",
84
+ "description": "Guidance scale for prompt adherence (1.0-30.0).",
85
+ "optional": True,
86
+ "default": 3.5
87
+ },
88
+ "steps": {
89
+ "type": "integer",
90
+ "description": "Number of sampling steps (10-100).",
91
+ "optional": True,
92
+ "default": 28
93
+ },
94
+ "seed": {
95
+ "type": "integer",
96
+ "description": "Random seed for reproducibility. Use -1 for random seed.",
97
+ "optional": True,
98
+ "default": -1
99
+ },
100
+ "width": {
101
+ "type": "integer",
102
+ "description": "Image width in pixels.",
103
+ "optional": True,
104
+ "default": 1024
105
+ },
106
+ "height": {
107
+ "type": "integer",
108
+ "description": "Image height in pixels.",
109
+ "optional": True,
110
+ "default": 1024
111
+ },
112
+ "lora_scale": {
113
+ "type": "float",
114
+ "description": "LoRA influence scale (0.0-1.0).",
115
+ "optional": True,
116
+ "default": 0.95
117
+ },
118
+ "custom_lora": {
119
+ "type": "string",
120
+ "description": "Custom LoRA model to use. Leave empty for default.",
121
+ "optional": True
122
+ }
123
+ }
124
+ output_type = "string"
125
+
126
+ def __init__(
127
+ self,
128
+ api_url: str = "xkerser/FLUX-LoRA-DLC",
129
+ image_save_dir: Optional[str] = None,
130
+ connection_timeout: int = 60,
131
+ verbose: bool = False
132
+ ):
133
+ """
134
+ Initialize the FLUX-LoRA Tool with Zhou Protocol connection patterns.
135
+
136
+ Args:
137
+ api_url: URL or endpoint ID for the FLUX-LoRA-DLC API
138
+ image_save_dir: Directory to save generated images (created if doesn't exist)
139
+ connection_timeout: API connection timeout in seconds
140
+ verbose: Enable detailed logging
141
+ """
142
+ super().__init__()
143
+
144
+ # Initialize logging
145
+ self.logger = logging.getLogger("flux_lora_tool")
146
+ self.logger.setLevel(logging.DEBUG if verbose else logging.INFO)
147
+
148
+ # Set up client and storage directories
149
+ self.api_url = api_url
150
+ self.connection_timeout = connection_timeout
151
+ self._client = None # Lazy initialization
152
+
153
+ # Set up image storage directory
154
+ self.image_save_dir = image_save_dir or os.path.join(tempfile.gettempdir(), "flux_lora_images")
155
+ os.makedirs(self.image_save_dir, exist_ok=True)
156
+ self.logger.info(f"FluxLoRATool initialized. Images will be saved to: {self.image_save_dir}")
157
+
158
+ @property
159
+ def client(self) -> Client:
160
+ """
161
+ Get or initialize the Gradio client with proper connection handling.
162
+
163
+ Returns:
164
+ Initialized Gradio client
165
+
166
+ Raises:
167
+ ConnectionError: If client initialization fails
168
+ """
169
+ if self._client is None:
170
+ try:
171
+ self._client = Client(
172
+ self.api_url,
173
+ timeout=self.connection_timeout
174
+ )
175
+ self.logger.debug(f"Gradio client initialized for: {self.api_url}")
176
+ except Exception as e:
177
+ error_msg = f"Failed to initialize FLUX-LoRA client: {str(e)}"
178
+ self.logger.error(error_msg)
179
+ raise ConnectionError(error_msg) from e
180
+
181
+ return self._client
182
+
183
+ def _validate_inputs(self, **kwargs) -> Dict[str, Any]:
184
+ """
185
+ Validate and normalize input parameters with Zhou Protocol validation patterns.
186
+
187
+ Args:
188
+ **kwargs: Input parameters
189
+
190
+ Returns:
191
+ Validated and normalized parameters
192
+
193
+ Raises:
194
+ ValueError: If input validation fails
195
+ """
196
+ validated = {}
197
+
198
+ # Required parameter: prompt
199
+ if not kwargs.get("prompt"):
200
+ raise ValueError("Prompt is required for image generation")
201
+ validated["prompt"] = kwargs["prompt"]
202
+
203
+ # Image input handling
204
+ if "image_input" in kwargs and kwargs["image_input"]:
205
+ input_image = kwargs["image_input"]
206
+ # Handle URL vs. local file
207
+ if input_image.startswith(("http://", "https://")):
208
+ # We'll need to download and process this
209
+ validated["image_input"] = self._download_image(input_image)
210
+ else:
211
+ # Check if file exists
212
+ if not os.path.exists(input_image):
213
+ raise ValueError(f"Image file not found: {input_image}")
214
+ validated["image_input"] = input_image
215
+
216
+ # Numeric parameter validation with constraints
217
+ numeric_params = {
218
+ "image_strength": {"min": 0.0, "max": 1.0, "default": 0.75},
219
+ "cfg_scale": {"min": 1.0, "max": 30.0, "default": 3.5},
220
+ "steps": {"min": 10, "max": 100, "default": 28},
221
+ "width": {"min": 128, "max": 2048, "default": 1024},
222
+ "height": {"min": 128, "max": 2048, "default": 1024},
223
+ "lora_scale": {"min": 0.0, "max": 1.0, "default": 0.95}
224
+ }
225
+
226
+ for param, constraints in numeric_params.items():
227
+ if param in kwargs and kwargs[param] is not None:
228
+ value = kwargs[param]
229
+
230
+ # Type conversion if needed
231
+ if param in ["steps", "width", "height"]:
232
+ try:
233
+ value = int(value)
234
+ except (ValueError, TypeError):
235
+ raise ValueError(f"Parameter '{param}' must be an integer")
236
+ else:
237
+ try:
238
+ value = float(value)
239
+ except (ValueError, TypeError):
240
+ raise ValueError(f"Parameter '{param}' must be a number")
241
+
242
+ # Range validation
243
+ if value < constraints["min"] or value > constraints["max"]:
244
+ raise ValueError(
245
+ f"Parameter '{param}' must be between {constraints['min']} and {constraints['max']}"
246
+ )
247
+
248
+ validated[param] = value
249
+ else:
250
+ validated[param] = constraints["default"]
251
+
252
+ # Special handling for seed
253
+ if "seed" in kwargs and kwargs["seed"] is not None:
254
+ try:
255
+ seed = int(kwargs["seed"])
256
+ # -1 indicates random seed
257
+ if seed == -1:
258
+ try:
259
+ seed = self._get_random_seed()
260
+ except Exception as e:
261
+ self.logger.warning(f"Failed to get random seed from API: {e}")
262
+ # Fallback to Python's random
263
+ import random
264
+ seed = random.randint(0, 2**32 - 1)
265
+ validated["seed"] = seed
266
+ except (ValueError, TypeError):
267
+ raise ValueError("Seed must be an integer")
268
+ else:
269
+ # Default to random seed
270
+ validated["seed"] = self._get_random_seed()
271
+
272
+ # Custom LoRA handling
273
+ if "custom_lora" in kwargs and kwargs["custom_lora"]:
274
+ validated["custom_lora"] = kwargs["custom_lora"]
275
+
276
+ return validated
277
+
278
+ def _download_image(self, url: str) -> str:
279
+ """
280
+ Download image from URL and save to local file.
281
+
282
+ Args:
283
+ url: Image URL
284
+
285
+ Returns:
286
+ Local file path
287
+
288
+ Raises:
289
+ ConnectionError: If download fails
290
+ """
291
+ try:
292
+ response = requests.get(url, stream=True, timeout=30)
293
+ response.raise_for_status()
294
+
295
+ # Generate temporary file path
296
+ file_ext = self._guess_extension(response.headers.get("Content-Type", ""))
297
+ temp_path = os.path.join(
298
+ self.image_save_dir,
299
+ f"input_{uuid.uuid4().hex}{file_ext}"
300
+ )
301
+
302
+ # Save image
303
+ with open(temp_path, "wb") as f:
304
+ for chunk in response.iter_content(chunk_size=8192):
305
+ f.write(chunk)
306
+
307
+ self.logger.debug(f"Downloaded image from {url} to {temp_path}")
308
+ return temp_path
309
+
310
+ except Exception as e:
311
+ error_msg = f"Failed to download image from {url}: {str(e)}"
312
+ self.logger.error(error_msg)
313
+ raise ConnectionError(error_msg) from e
314
+
315
+ def _guess_extension(self, content_type: str) -> str:
316
+ """
317
+ Guess file extension from content type.
318
+
319
+ Args:
320
+ content_type: HTTP Content-Type header
321
+
322
+ Returns:
323
+ File extension (with dot)
324
+ """
325
+ content_type = content_type.lower()
326
+ if "jpeg" in content_type or "jpg" in content_type:
327
+ return ".jpg"
328
+ elif "png" in content_type:
329
+ return ".png"
330
+ elif "webp" in content_type:
331
+ return ".webp"
332
+ elif "gif" in content_type:
333
+ return ".gif"
334
+ else:
335
+ return ".png" # Default to PNG
336
+
337
+ def _get_random_seed(self) -> int:
338
+ """
339
+ Get a random seed from the API.
340
+
341
+ Returns:
342
+ Random seed value
343
+
344
+ Raises:
345
+ RuntimeError: If random seed retrieval fails
346
+ """
347
+ try:
348
+ result = self.client.predict(api_name="/get_random_value")
349
+ if isinstance(result, (int, float)):
350
+ return int(result)
351
+ else:
352
+ raise ValueError(f"Unexpected result type: {type(result)}")
353
+ except Exception as e:
354
+ # Just log and re-raise as we have fallback in the validation method
355
+ self.logger.warning(f"Failed to get random seed: {e}")
356
+ raise
357
+
358
+ def _handle_custom_lora(self, custom_lora: Optional[str]) -> None:
359
+ """
360
+ Add or remove custom LoRA model.
361
+
362
+ Args:
363
+ custom_lora: Custom LoRA model string
364
+
365
+ Raises:
366
+ RuntimeError: If LoRA handling fails
367
+ """
368
+ if not custom_lora:
369
+ # Remove any existing custom LoRA
370
+ try:
371
+ self.client.predict(api_name="/remove_custom_lora")
372
+ self.logger.debug("Removed custom LoRA")
373
+ except Exception as e:
374
+ error_msg = f"Failed to remove custom LoRA: {str(e)}"
375
+ self.logger.error(error_msg)
376
+ raise RuntimeError(error_msg) from e
377
+ else:
378
+ # Add custom LoRA
379
+ try:
380
+ self.client.predict(
381
+ custom_lora=custom_lora,
382
+ api_name="/add_custom_lora"
383
+ )
384
+ self.logger.debug(f"Added custom LoRA: {custom_lora}")
385
+ except Exception as e:
386
+ error_msg = f"Failed to add custom LoRA '{custom_lora}': {str(e)}"
387
+ self.logger.error(error_msg)
388
+ raise RuntimeError(error_msg) from e
389
+
390
+ def forward(
391
+ self,
392
+ prompt: str,
393
+ image_input: Optional[str] = None,
394
+ image_strength: Optional[float] = None,
395
+ cfg_scale: Optional[float] = None,
396
+ steps: Optional[int] = None,
397
+ seed: Optional[int] = None,
398
+ width: Optional[int] = None,
399
+ height: Optional[int] = None,
400
+ lora_scale: Optional[float] = None,
401
+ custom_lora: Optional[str] = None
402
+ ) -> str:
403
+ """
404
+ Generate an image with FLUX-LoRA.
405
+
406
+ Args:
407
+ prompt: Text description of the desired image
408
+ image_input: Optional path or URL to input image for img2img
409
+ image_strength: Strength of input image influence (0.0-1.0)
410
+ cfg_scale: Guidance scale (1.0-30.0)
411
+ steps: Number of sampling steps (10-100)
412
+ seed: Random seed (-1 for random)
413
+ width: Image width in pixels (128-2048)
414
+ height: Image height in pixels (128-2048)
415
+ lora_scale: LoRA influence scale (0.0-1.0)
416
+ custom_lora: Custom LoRA model to use
417
+
418
+ Returns:
419
+ Formatted string with image generation results
420
+
421
+ Raises:
422
+ ValueError: If input validation fails
423
+ ConnectionError: If API communication fails
424
+ RuntimeError: If image generation fails
425
+ """
426
+ # Step 1: Validate and normalize inputs
427
+ try:
428
+ params = self._validate_inputs(
429
+ prompt=prompt,
430
+ image_input=image_input,
431
+ image_strength=image_strength,
432
+ cfg_scale=cfg_scale,
433
+ steps=steps,
434
+ seed=seed,
435
+ width=width,
436
+ height=height,
437
+ lora_scale=lora_scale,
438
+ custom_lora=custom_lora
439
+ )
440
+ self.logger.debug(f"Validated parameters: {params}")
441
+ except ValueError as e:
442
+ return f"Parameter validation failed: {str(e)}"
443
+
444
+ # Step 2: Handle custom LoRA if specified
445
+ if "custom_lora" in params:
446
+ try:
447
+ custom_lora_value = params.pop("custom_lora")
448
+ self._handle_custom_lora(custom_lora_value)
449
+ except RuntimeError as e:
450
+ return f"Custom LoRA setup failed: {str(e)}"
451
+
452
+ # Step 3: Generate image
453
+ try:
454
+ # Prepare image input if provided
455
+ img_param = None
456
+ if "image_input" in params and params["image_input"]:
457
+ from gradio_client import handle_file
458
+ img_param = handle_file(params.pop("image_input"))
459
+
460
+ # Call the API
461
+ generation_args = {
462
+ "prompt": params["prompt"],
463
+ "image_strength": params["image_strength"],
464
+ "cfg_scale": params["cfg_scale"],
465
+ "steps": params["steps"],
466
+ "randomize_seed": False, # We handle seed explicitly
467
+ "seed": params["seed"],
468
+ "width": params["width"],
469
+ "height": params["height"],
470
+ "lora_scale": params["lora_scale"],
471
+ }
472
+
473
+ # Add image input if available
474
+ if img_param:
475
+ generation_args["image_input"] = img_param
476
+
477
+ self.logger.info(f"Generating image with params: {generation_args}")
478
+ result = self.client.predict(
479
+ api_name="/run_lora",
480
+ **generation_args
481
+ )
482
+
483
+ # Process result
484
+ if isinstance(result, tuple) and len(result) >= 2:
485
+ image_path, actual_seed = result[0], result[1]
486
+
487
+ # Save image to our directory
488
+ try:
489
+ output_path = self._save_image(image_path)
490
+ image_result = ImageGenerationResult(
491
+ image_path=output_path,
492
+ seed=int(actual_seed)
493
+ )
494
+ return self._format_result(image_result, params["prompt"])
495
+ except Exception as e:
496
+ self.logger.error(f"Failed to save generated image: {e}")
497
+ return f"Image generated but failed to save: {str(e)}"
498
+ else:
499
+ raise ValueError(f"Unexpected API response format: {result}")
500
+
501
+ except Exception as e:
502
+ error_msg = f"Image generation failed: {str(e)}"
503
+ self.logger.error(error_msg)
504
+ return error_msg
505
+
506
+ def _save_image(self, image_path: str) -> str:
507
+ """
508
+ Save generated image to specified directory.
509
+
510
+ Args:
511
+ image_path: Path to generated image from API
512
+
513
+ Returns:
514
+ Path to saved image
515
+
516
+ Raises:
517
+ IOError: If image saving fails
518
+ """
519
+ try:
520
+ # Load the image
521
+ img = Image.open(image_path)
522
+
523
+ # Generate timestamp-based filename
524
+ timestamp = uuid.uuid4().hex[:8]
525
+ output_filename = f"flux_lora_{timestamp}.png"
526
+ output_path = os.path.join(self.image_save_dir, output_filename)
527
+
528
+ # Save to our directory
529
+ img.save(output_path)
530
+ self.logger.debug(f"Saved image to {output_path}")
531
+
532
+ return output_path
533
+
534
+ except Exception as e:
535
+ error_msg = f"Failed to save image: {str(e)}"
536
+ self.logger.error(error_msg)
537
+ raise IOError(error_msg) from e
538
+
539
+ def _format_result(self, result: ImageGenerationResult, prompt: str) -> str:
540
+ """
541
+ Format the image generation result as a string.
542
+
543
+ Args:
544
+ result: Image generation result
545
+ prompt: Original prompt
546
+
547
+ Returns:
548
+ Formatted string with generation details
549
+ """
550
+ lines = [
551
+ f"📷 Image generated successfully!",
552
+ f"🖼️ Image saved to: {result.image_path}",
553
+ f"🌱 Seed used: {result.seed}",
554
+ f"📝 Original prompt: {prompt}",
555
+ ]
556
+
557
+ # Add metadata if available
558
+ if result.metadata:
559
+ lines.append("📊 Additional metadata:")
560
+ for key, value in result.metadata.items():
561
+ lines.append(f" - {key}: {value}")
562
+
563
+ return "\n".join(lines)
564
+
565
+
566
+ # -----------------------------------------------------------------------------
567
+ # UTILITY FUNCTIONS
568
+ # -----------------------------------------------------------------------------
569
+
570
+ def download_image(url: str, output_dir: Optional[str] = None) -> str:
571
+ """
572
+ Standalone utility to download an image from a URL.
573
+
574
+ Args:
575
+ url: Image URL
576
+ output_dir: Directory to save image (created if doesn't exist)
577
+
578
+ Returns:
579
+ Path to downloaded image
580
+
581
+ Raises:
582
+ ValueError: If URL is invalid
583
+ ConnectionError: If download fails
584
+ IOError: If saving fails
585
+ """
586
+ if not url.startswith(("http://", "https://")):
587
+ raise ValueError(f"Invalid URL: {url}")
588
+
589
+ # Setup output directory
590
+ if output_dir is None:
591
+ output_dir = os.path.join(tempfile.gettempdir(), "flux_lora_images")
592
+ os.makedirs(output_dir, exist_ok=True)
593
+
594
+ try:
595
+ # Download image
596
+ response = requests.get(url, stream=True, timeout=30)
597
+ response.raise_for_status()
598
+
599
+ # Determine file extension
600
+ content_type = response.headers.get("Content-Type", "")
601
+ ext = ".jpg" if "jpeg" in content_type.lower() else ".png"
602
+
603
+ # Save image
604
+ output_path = os.path.join(output_dir, f"download_{uuid.uuid4().hex}{ext}")
605
+ with open(output_path, "wb") as f:
606
+ for chunk in response.iter_content(chunk_size=8192):
607
+ f.write(chunk)
608
+
609
+ return output_path
610
+
611
+ except requests.RequestException as e:
612
+ raise ConnectionError(f"Failed to download image: {str(e)}")
613
+ except IOError as e:
614
+ raise IOError(f"Failed to save image: {str(e)}")