# ---
# deploy: true
# ---

from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import modal

app = modal.App(name="dreambooth-lora-flux")

image = modal.Image.debian_slim(python_version="3.10").pip_install(
    "accelerate==0.31.0",
    "datasets==3.6.0",
    "pillow",
    "fastapi[standard]==0.115.4",
    "ftfy~=6.1.0",
    "gradio~=5.5.0",
    "huggingface-hub==0.32.4",
    "hf_transfer==0.1.8",
    "numpy<2",
    "peft==0.11.1",
    "pydantic==2.9.2",
    "sentencepiece>=0.1.91,!=0.1.92",
    "smart_open~=6.4.0",
    "starlette==0.41.2",
    "transformers~=4.41.2",
    "torch~=2.2.0",
    "torchvision~=0.16",
    "triton~=2.2.0",
    "wandb==0.17.6",
)

GIT_SHA = "e649678bf55aeaa4b60bd1f68b1ee726278c0304"  # specify the commit to fetch

image = (
    image.apt_install("git")
    # Perform a shallow fetch of just the target `diffusers` commit, checking out
    # the commit in the container's home directory, /root. Then install `diffusers`.
    .run_commands(
        "cd /root && git init .",
        "cd /root && git remote add origin https://github.com/huggingface/diffusers",
        f"cd /root && git fetch --depth=1 origin {GIT_SHA} && git checkout {GIT_SHA}",
        "cd /root && pip install -e .",
    )
)
# ### Configuration with `dataclass`es

# Machine learning apps often have a lot of configuration information.
# We collect all of our configuration into `dataclass`es to avoid scattering
# special/magic values throughout the code.


@dataclass
class SharedConfig:
    """Configuration information shared across project components."""

    # The instance name is the "proper noun" we're teaching the model.
    instance_name: str = "Qwerty"
    # That proper noun is usually a member of some class (person, bird),
    # and sharing that information with the model helps it generalize better.
    class_name: str = "Golden Retriever"
    # Identifier for the pretrained model on Hugging Face.
    model_name: str = "black-forest-labs/FLUX.1-dev"
# ### Storing data created by our app

# The tools we've used so far work well for fetching external information,
# which defines the environment our app runs in.
# But what about data that we create or modify during the app's execution?
# A persisted [`modal.Volume`](https://modal.com/docs/guide/volumes) can store and share
# data across Modal Apps and Functions. In this stateless variant, though, we keep Modal
# free of state: training images are read from a Hugging Face dataset and the fine-tuned
# LoRA weights are uploaded back to a Hugging Face model repository.

image = image.env(
    {"HF_HUB_ENABLE_HF_TRANSFER": "1"}  # turn on faster downloads from HF
)
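# If you would rather persist weights on Modal between runs instead of round-tripping
# through the Hugging Face Hub, a `modal.Volume` could be declared and attached to the
# training and inference Functions. A minimal sketch — the Volume name, mount path, and
# helper Function below are illustrative, not part of this app:
#
# ```python
# volume = modal.Volume.from_name("dreambooth-flux-cache", create_if_missing=True)
# MODEL_DIR = "/model"  # mount point inside the container
#
# @app.function(image=image, volumes={MODEL_DIR: volume})
# def cache_weights():
#     ...  # download weights into MODEL_DIR, then call volume.commit()
# ```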
def load_images_from_hf_dataset(dataset_id: str, hf_token: str) -> Path:
    """Load images from a Hugging Face dataset into a local directory."""
    import PIL.Image
    from datasets import load_dataset

    img_path = Path("/img")
    img_path.mkdir(parents=True, exist_ok=True)

    # Load the dataset from the Hugging Face Hub.
    dataset = load_dataset(dataset_id, token=hf_token, split="train")

    for ii, example in enumerate(dataset):
        # Assume the dataset has an 'image' column.
        if "image" in example:
            image = example["image"]
            if isinstance(image, PIL.Image.Image):
                image.save(img_path / f"{ii}.png")
            else:
                # Handle other image formats (e.g. file paths or bytes).
                pil_image = PIL.Image.open(image)
                pil_image.save(img_path / f"{ii}.png")
        else:
            print(f"Warning: No 'image' field found in dataset example {ii}")

    print(f"{len(dataset)} images loaded from HuggingFace dataset")
    return img_path
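# The function above expects the dataset to expose a single `image` column.
# One way to build such a dataset from a local folder of photos and push it to the Hub
# is sketched below — run it on your own machine, not inside this app. The folder and
# repository names are placeholders:
#
# ```python
# from pathlib import Path
# from datasets import Dataset, Image
#
# files = sorted(str(p) for p in Path("my_photos").glob("*.jpg"))
# ds = Dataset.from_dict({"image": files}).cast_column("image", Image())
# ds.push_to_hub("your-username/my-subject-photos", token="hf_...", private=True)
# ```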
# ## Stateless API Training Function


@dataclass
class APITrainConfig:
    """Configuration for the API training function."""

    # Basic model info
    model_name: str = "black-forest-labs/FLUX.1-dev"
    # Training prompt components
    instance_name: str = "subject"
    class_name: str = "person"
    prefix: str = "a photo of"
    postfix: str = ""
    # Training hyperparameters
    resolution: int = 512
    train_batch_size: int = 3
    rank: int = 16  # LoRA rank
    gradient_accumulation_steps: int = 1
    learning_rate: float = 4e-4
    lr_scheduler: str = "constant"
    lr_warmup_steps: int = 0
    max_train_steps: int = 500
    checkpointing_steps: int = 1000
    seed: int = 117
@app.function(
    image=image,
    gpu="A100-80GB",  # FLUX.1-dev fine-tuning needs a large GPU; this choice and the timeout are illustrative
    timeout=3600,
)
def train_lora_stateless(
    dataset_id: str,
    hf_token: str,
    output_repo: str,
    instance_name: Optional[str] = None,
    class_name: Optional[str] = None,
    max_train_steps: int = 500,
):
    """
    Stateless LoRA training function that reads from an HF dataset and uploads to an HF repo.

    Args:
        dataset_id: Hugging Face dataset ID (e.g., "username/dataset-name")
        hf_token: Hugging Face API token
        output_repo: Hugging Face repository to upload the trained LoRA to
        instance_name: Name of the subject (optional, defaults to "subject")
        class_name: Class of the subject (optional, defaults to "person")
        max_train_steps: Number of training steps
    """
    import subprocess
    import tempfile
    from pathlib import Path

    import torch
    from accelerate.utils import write_basic_config
    from diffusers import DiffusionPipeline
    from huggingface_hub import create_repo, login, snapshot_download, upload_folder

    # Log in to Hugging Face.
    login(token=hf_token)

    # Create temporary directories.
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = Path(temp_dir)
        model_dir = temp_path / "model"
        output_dir = temp_path / "output"

        # Download the base model.
        print("📥 Downloading base model...")
        snapshot_download(
            "black-forest-labs/FLUX.1-dev",
            local_dir=str(model_dir),
            ignore_patterns=["*.pt", "*.bin"],  # using safetensors
            token=hf_token,
        )

        # Load the model once to validate the download.
        DiffusionPipeline.from_pretrained(str(model_dir), torch_dtype=torch.bfloat16)
        print("✅ Base model loaded successfully")

        # Load training images from the HF dataset.
        print(f"📥 Loading images from dataset: {dataset_id}")
        img_path = load_images_from_hf_dataset(dataset_id, hf_token)

        # Set up the training configuration.
        config = APITrainConfig(
            instance_name=instance_name or "subject",
            class_name=class_name or "person",
            max_train_steps=max_train_steps,
        )

        # Set up the Hugging Face `accelerate` library for fast training.
        write_basic_config(mixed_precision="bf16")

        # Define the training prompt.
        instance_phrase = f"{config.instance_name} the {config.class_name}"
        prompt = f"{config.prefix} {instance_phrase} {config.postfix}".strip()
        print(f"🎯 Training prompt: {prompt}")
        print(f"🚀 Starting training for {max_train_steps} steps...")

        def _exec_subprocess(cmd: list[str]):
            """Execute a subprocess and stream its output to the terminal while it runs."""
            process = subprocess.Popen(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
            )
            with process.stdout as pipe:
                for line in iter(pipe.readline, b""):
                    print(line.decode(), end="")

            if (exitcode := process.wait()) != 0:
                raise subprocess.CalledProcessError(exitcode, "\n".join(cmd))

        # Run training. The script path is relative to /root, where we checked out `diffusers`.
        _exec_subprocess(
            [
                "accelerate",
                "launch",
                "examples/dreambooth/train_dreambooth_lora_flux.py",
                "--mixed_precision=bf16",
                f"--pretrained_model_name_or_path={model_dir}",
                f"--instance_data_dir={img_path}",
                f"--output_dir={output_dir}",
                f"--instance_prompt={prompt}",
                f"--resolution={config.resolution}",
                f"--train_batch_size={config.train_batch_size}",
                f"--gradient_accumulation_steps={config.gradient_accumulation_steps}",
                f"--learning_rate={config.learning_rate}",
                f"--lr_scheduler={config.lr_scheduler}",
                f"--lr_warmup_steps={config.lr_warmup_steps}",
                f"--rank={config.rank}",
                f"--max_train_steps={config.max_train_steps}",
                f"--checkpointing_steps={config.checkpointing_steps}",
                f"--seed={config.seed}",
            ]
        )
        print("✅ Training completed!")

        # Upload the trained LoRA to the Hugging Face repository.
        print(f"📤 Uploading LoRA to repository: {output_repo}")

        # Create the repository if it doesn't exist.
        create_repo(
            repo_id=output_repo,
            repo_type="model",
            token=hf_token,
            exist_ok=True,
        )

        # Print the contents of output_dir.
        print(f"Contents of {output_dir}:")
        for file in output_dir.iterdir():
            print(file)

        upload_folder(
            folder_path=str(output_dir),
            repo_id=output_repo,
            repo_type="model",
            token=hf_token,
            commit_message=f"Add LoRA trained on {dataset_id}",
        )
        print(f"🎉 Successfully uploaded LoRA to {output_repo}")

        return {
            "status": "success",
            "message": f"LoRA training completed and uploaded to {output_repo}",
            "dataset_used": dataset_id,
            "training_steps": max_train_steps,
            "training_prompt": prompt,
        }
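# For quick experiments you can also kick off a training run from your own machine with
# `modal run`, without going through the HTTP API. A minimal sketch — the entrypoint
# name and the default dataset/repository names below are placeholders:


@app.local_entrypoint()
def train_from_cli(
    dataset_id: str = "your-username/my-subject-photos",
    hf_token: str = "",
    output_repo: str = "your-username/my-flux-lora",
    max_train_steps: int = 500,
):
    """Run training remotely and print the result, e.g.
    `modal run <this file> --dataset-id ... --hf-token hf_... --output-repo ...`"""
    result = train_lora_stateless.remote(
        dataset_id=dataset_id,
        hf_token=hf_token,
        output_repo=output_repo,
        max_train_steps=max_train_steps,
    )
    print(result)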
# ## API Endpoints with Job ID System

# The web endpoints below don't block on training or inference: they `spawn` the
# corresponding Modal Function and return the call's object ID as a job ID, which
# clients can poll.


@app.function(image=image)
@modal.fastapi_endpoint(method="POST")  # on older Modal versions, use modal.web_endpoint
def api_start_training(item: dict):
    """
    Start LoRA training and return a job ID.

    Expected JSON payload:
    {
        "dataset_id": "username/dataset-name",
        "hf_token": "hf_...",
        "output_repo": "username/output-repo",
        "instance_name": "optional_subject_name",
        "class_name": "optional_class_name",
        "max_train_steps": 500
    }
    """
    try:
        # Extract required parameters.
        dataset_id = item["dataset_id"]
        hf_token = item["hf_token"]
        output_repo = item["output_repo"]

        # Extract optional parameters.
        instance_name = item.get("instance_name")
        class_name = item.get("class_name")
        max_train_steps = item.get("max_train_steps", 500)

        # Start training (non-blocking).
        call_handle = train_lora_stateless.spawn(
            dataset_id=dataset_id,
            hf_token=hf_token,
            output_repo=output_repo,
            instance_name=instance_name,
            class_name=class_name,
            max_train_steps=max_train_steps,
        )
        job_id = call_handle.object_id

        return {
            "status": "started",
            "job_id": job_id,
            "message": "Training job started successfully",
            "dataset_id": dataset_id,
            "output_repo": output_repo,
            "max_train_steps": max_train_steps,
        }
    except KeyError as e:
        return {"status": "error", "message": f"Missing required parameter: {e}"}
    except Exception as e:
        return {"status": "error", "message": f"Failed to start training: {str(e)}"}
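# Once the app is deployed with `modal deploy`, each endpoint gets its own URL (shown in
# the deploy output). A request to the training endpoint might look like the client-side
# sketch below — the URL is a placeholder and `requests` is assumed to be installed on
# the client:
#
# ```python
# import requests
#
# resp = requests.post(
#     "https://your-workspace--dreambooth-lora-flux-api-start-training.modal.run",
#     json={
#         "dataset_id": "your-username/my-subject-photos",
#         "hf_token": "hf_...",
#         "output_repo": "your-username/my-flux-lora",
#         "max_train_steps": 500,
#     },
# )
# job_id = resp.json()["job_id"]
# ```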
@app.function(image=image)
@modal.fastapi_endpoint(method="GET")  # on older Modal versions, use modal.web_endpoint
def api_job_status(job_id: str):
    """
    Check the status of a training job.

    Pass job_id as a query parameter: /job_status?job_id=xyz
    """
    try:
        from modal.functions import FunctionCall

        # Get the function call handle.
        call_handle = FunctionCall.from_id(job_id)
        if call_handle is None:
            return {"status": "error", "message": "Job not found"}

        # Check whether the job has finished.
        try:
            result = call_handle.get(timeout=0)  # non-blocking check
            return {"status": "completed", "result": result}
        except TimeoutError:
            return {"status": "running", "message": "Job is still running"}
        except Exception as e:
            return {"status": "failed", "message": f"Job failed: {str(e)}"}
    except Exception as e:
        return {"status": "error", "message": f"Error checking job status: {str(e)}"}
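# A client can then poll the status endpoint with the returned job ID until the job
# finishes. Another client-side sketch (placeholder URL, `requests` assumed installed):
#
# ```python
# import time
# import requests
#
# status_url = "https://your-workspace--dreambooth-lora-flux-api-job-status.modal.run"
# while True:
#     status = requests.get(status_url, params={"job_id": job_id}).json()
#     if status["status"] != "running":
#         print(status)
#         break
#     time.sleep(30)
# ```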
@dataclass
class InferenceConfig:
    """Configuration for inference."""

    num_inference_steps: int = 20
    guidance_scale: float = 7.5
    width: int = 512
    height: int = 512
@app.function(
    image=image,
    gpu="A100-80GB",  # FLUX.1-dev in bf16 needs a large GPU; this choice and the timeout are illustrative
    timeout=1800,
)
def generate_images_stateless(
    hf_token: str,
    lora_repo: str,
    prompts: list[str],
    num_inference_steps: int = 20,
    guidance_scale: float = 7.5,
    width: int = 512,
    height: int = 512,
):
    """
    Stateless function to generate images using a LoRA from Hugging Face.

    Args:
        hf_token: Hugging Face API token
        lora_repo: Hugging Face repository containing the LoRA (e.g., "username/my-lora")
        prompts: List of text prompts to generate images for
        num_inference_steps: Number of denoising steps
        guidance_scale: Classifier-free guidance scale
        width: Image width
        height: Image height

    Returns:
        Dictionary with status and list of generated images (as base64 strings)
    """
    import base64
    import io
    import tempfile
    from pathlib import Path

    import torch
    from diffusers import DiffusionPipeline
    from huggingface_hub import login, snapshot_download

    try:
        # Log in to Hugging Face.
        login(token=hf_token)

        # Create a temporary directory for the model and LoRA weights.
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = Path(temp_dir)
            model_dir = temp_path / "model"
            lora_dir = temp_path / "lora"

            # Download the base model.
            print("📥 Downloading base model...")
            snapshot_download(
                "black-forest-labs/FLUX.1-dev",
                local_dir=str(model_dir),
                ignore_patterns=["*.pt", "*.bin"],  # using safetensors
                token=hf_token,
            )

            # Download the LoRA.
            print(f"📥 Downloading LoRA from {lora_repo}...")
            snapshot_download(
                lora_repo,
                local_dir=str(lora_dir),
                token=hf_token,
            )

            # Load the diffusion pipeline.
            print("🔄 Loading pipeline...")
            pipe = DiffusionPipeline.from_pretrained(
                str(model_dir),
                torch_dtype=torch.bfloat16,
            ).to("cuda")

            # Load the LoRA weights on top of the base model.
            pipe.load_lora_weights(str(lora_dir))

            print(f"🎨 Generating {len(prompts)} images...")
            generated_images = []

            # Generate one image per prompt.
            for i, prompt in enumerate(prompts):
                print(f"  Generating image {i + 1}/{len(prompts)}: {prompt[:50]}...")
                image = pipe(
                    prompt,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    width=width,
                    height=height,
                ).images[0]

                # Convert the PIL Image to a base64 string.
                img_buffer = io.BytesIO()
                image.save(img_buffer, format="PNG")
                img_base64 = base64.b64encode(img_buffer.getvalue()).decode("utf-8")

                generated_images.append({"prompt": prompt, "image": img_base64})

            print("✅ All images generated successfully!")
            return {
                "status": "success",
                "message": f"Generated {len(prompts)} images successfully",
                "lora_repo": lora_repo,
                "images": generated_images,
            }
    except Exception as e:
        return {"status": "error", "message": f"Failed to generate images: {str(e)}"}
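# The function above returns images as base64-encoded PNGs inside JSON, so a client needs
# to decode them before viewing. A small client-side sketch (the field names match the
# dictionaries built above; the output filenames are arbitrary):
#
# ```python
# import base64
#
# for i, item in enumerate(result["images"]):
#     with open(f"generated_{i}.png", "wb") as f:
#         f.write(base64.b64decode(item["image"]))
# ```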
@app.function(image=image)
@modal.fastapi_endpoint(method="POST")  # on older Modal versions, use modal.web_endpoint
def api_generate_images(item: dict):
    """
    Generate images using a LoRA model.

    Expected JSON payload:
    {
        "hf_token": "hf_...",
        "lora_repo": "username/my-lora",
        "prompts": ["prompt1", "prompt2", ...],
        "num_inference_steps": 20,   // optional
        "guidance_scale": 7.5,       // optional
        "width": 512,                // optional
        "height": 512                // optional
    }
    """
    try:
        # Extract required parameters.
        hf_token = item["hf_token"]
        lora_repo = item["lora_repo"]
        prompts = item["prompts"]

        if not isinstance(prompts, list) or len(prompts) == 0:
            return {"status": "error", "message": "prompts must be a non-empty list"}

        # Extract optional parameters.
        num_inference_steps = item.get("num_inference_steps", 20)
        guidance_scale = item.get("guidance_scale", 7.5)
        width = item.get("width", 512)
        height = item.get("height", 512)

        # Start generation (non-blocking).
        call_handle = generate_images_stateless.spawn(
            hf_token=hf_token,
            lora_repo=lora_repo,
            prompts=prompts,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            width=width,
            height=height,
        )
        job_id = call_handle.object_id

        return {
            "status": "started",
            "job_id": job_id,
            "message": "Image generation job started successfully",
            "lora_repo": lora_repo,
            "num_prompts": len(prompts),
        }
    except KeyError as e:
        return {"status": "error", "message": f"Missing required parameter: {e}"}
    except Exception as e:
        return {"status": "error", "message": f"Failed to start image generation: {str(e)}"}