stillerman commited on
Commit
75c12e8
·
1 Parent(s): 3fbcf45

mcp is alive and well!

Browse files
Files changed (5) hide show
  1. .gitignore +2 -1
  2. README.md +8 -0
  3. app.py +303 -0
  4. diffusers_lora_finetune.py +323 -2
  5. requirements.txt +2 -0
.gitignore CHANGED
@@ -1,2 +1,3 @@
1
  __pycache__
2
- .venv
 
 
1
  __pycache__
2
+ .venv
3
+ .gradio
README.md CHANGED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ ## [Modal Flux Fintune Tutorial](https://modal.com/docs/examples/diffusers_lora_finetune)
4
+
5
+ ## Setup
6
+ ```
7
+ uv pip install modal
8
+ ```
app.py CHANGED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Gradio MCP server that launches modal finetune
2
+
3
+ import gradio as gr
4
+ import requests
5
+ import json
6
+ import time
7
+ from typing import Optional, Dict, Any
8
+
9
+ # Configuration - Update these URLs to match your deployed Modal app
10
+ MODAL_BASE_URL = "https://stillerman--jason-lora-flux" # Update with your actual Modal app URL
11
+ START_TRAINING_URL = f"{MODAL_BASE_URL}-api-start-training.modal.run"
12
+ JOB_STATUS_URL = f"{MODAL_BASE_URL}-api-job-status.modal.run"
13
+
14
+ def start_training(
15
+ dataset_id: str,
16
+ hf_token: str,
17
+ output_repo: str,
18
+ instance_name: Optional[str] = None,
19
+ class_name: Optional[str] = None,
20
+ max_train_steps: int = 500
21
+ ) -> tuple[str, str]:
22
+ """
23
+ Start a LoRA training job for Flux image generation model.
24
+
25
+ This function initiates a LoRA (Low-Rank Adaptation) training job on a dataset of images.
26
+ It sends a request to a Modal API endpoint to start the training process.
27
+
28
+ Parameters:
29
+ - dataset_id (str, required): The HuggingFace dataset ID containing training images, format: "username/dataset-name"
30
+ - hf_token (str, required): HuggingFace access token with read permissions, format: "hf_xxxxxxxxxxxx"
31
+ - output_repo (str, required): HuggingFace repository where trained LoRA will be uploaded, format: "username/repo-name"
32
+ - instance_name (str, optional): Name of the subject being trained (e.g., 'Fluffy', 'MyDog', 'John')
33
+ - class_name (str, optional): Class category of the subject (e.g., 'person', 'dog', 'cat', 'building')
34
+ - max_train_steps (int, optional): Number of training steps, range 100-2000, default 500
35
+
36
+ Returns:
37
+ - tuple[str, str]: (status_message, job_id)
38
+ - status_message: Human-readable status with training details or error message
39
+ - job_id: Unique identifier for the training job, empty string if failed
40
+
41
+ Example usage:
42
+ status, job_id = start_training(
43
+ dataset_id="myuser/dog-photos",
44
+ hf_token="hf_abcdef123456",
45
+ output_repo="myuser/my-dog-lora",
46
+ instance_name="Fluffy",
47
+ class_name="dog",
48
+ max_train_steps=500
49
+ )
50
+ """
51
+
52
+ if not dataset_id or not hf_token or not output_repo:
53
+ return "❌ Error: Dataset ID, HuggingFace token, and output repo are required", ""
54
+
55
+ payload = {
56
+ "dataset_id": dataset_id,
57
+ "hf_token": hf_token,
58
+ "output_repo": output_repo,
59
+ "max_train_steps": max_train_steps
60
+ }
61
+
62
+ # Add optional parameters if provided
63
+ if instance_name and instance_name.strip():
64
+ payload["instance_name"] = instance_name.strip()
65
+ if class_name and class_name.strip():
66
+ payload["class_name"] = class_name.strip()
67
+
68
+ try:
69
+ response = requests.post(
70
+ START_TRAINING_URL,
71
+ json=payload,
72
+ headers={"Content-Type": "application/json"},
73
+ timeout=30
74
+ )
75
+
76
+ if response.status_code == 200:
77
+ result = response.json()
78
+ if result.get("status") == "started":
79
+ job_id = result.get("job_id", "")
80
+ message = f"✅ Training started successfully!\n\n"
81
+ message += f"**Job ID:** `{job_id}`\n"
82
+ message += f"**Dataset:** {dataset_id}\n"
83
+ message += f"**Output Repo:** {output_repo}\n"
84
+ message += f"**Training Steps:** {max_train_steps}\n\n"
85
+ message += "Copy the Job ID to check status below."
86
+ return message, job_id
87
+ else:
88
+ return f"❌ Error: {result.get('message', 'Unknown error')}", ""
89
+ else:
90
+ return f"❌ HTTP Error {response.status_code}: {response.text}", ""
91
+
92
+ except requests.exceptions.Timeout:
93
+ return "❌ Error: Request timed out. The service might be starting up.", ""
94
+ except requests.exceptions.RequestException as e:
95
+ return f"❌ Error: Failed to connect to training service: {str(e)}", ""
96
+ except json.JSONDecodeError:
97
+ return "❌ Error: Invalid response from server", ""
98
+
99
+ def check_job_status(job_id: str) -> str:
100
+ """
101
+ Check the current status of a LoRA training job.
102
+
103
+ This function queries the Modal API to get the current status of a training job
104
+ using its unique job ID. It returns detailed information about the job progress.
105
+
106
+ Parameters:
107
+ - job_id (str, required): The unique job identifier returned from start_training function
108
+
109
+ Returns:
110
+ - str: Detailed status message containing:
111
+ - Job status (completed, running, failed, error)
112
+ - Training results if completed (dataset used, steps completed, training prompt)
113
+ - Error messages if failed
114
+ - Progress information if still running
115
+
116
+ Possible status values:
117
+ - "completed": Training finished successfully, LoRA model is ready
118
+ - "running": Training is still in progress
119
+ - "failed": Training failed due to an error
120
+ - "error": System error occurred
121
+
122
+ Example usage:
123
+ status_info = check_job_status("job_12345abcdef")
124
+ """
125
+
126
+ if not job_id or not job_id.strip():
127
+ return "❌ Error: Job ID is required"
128
+
129
+ try:
130
+ response = requests.get(
131
+ JOB_STATUS_URL,
132
+ params={"job_id": job_id.strip()},
133
+ timeout=10
134
+ )
135
+
136
+ if response.status_code == 200:
137
+ result = response.json()
138
+ status = result.get("status", "unknown")
139
+
140
+ if status == "completed":
141
+ message = "🎉 **Training Completed!**\n\n"
142
+ training_result = result.get("result", {})
143
+ if isinstance(training_result, dict):
144
+ message += f"**Status:** {training_result.get('status', 'completed')}\n"
145
+ message += f"**Message:** {training_result.get('message', 'Training finished')}\n"
146
+ if training_result.get('dataset_used'):
147
+ message += f"**Dataset Used:** {training_result['dataset_used']}\n"
148
+ if training_result.get('training_steps'):
149
+ message += f"**Training Steps:** {training_result['training_steps']}\n"
150
+ if training_result.get('training_prompt'):
151
+ message += f"**Training Prompt:** {training_result['training_prompt']}\n"
152
+ else:
153
+ message += f"**Result:** {training_result}"
154
+ return message
155
+
156
+ elif status == "running":
157
+ return f"🔄 **Training in Progress**\n\nThe training job is still running. Check back in a few minutes."
158
+
159
+ elif status == "failed":
160
+ error_msg = result.get("message", "Training failed with unknown error")
161
+ return f"❌ **Training Failed**\n\n**Error:** {error_msg}"
162
+
163
+ elif status == "error":
164
+ error_msg = result.get("message", "Unknown error occurred")
165
+ return f"❌ **Error**\n\n**Message:** {error_msg}"
166
+
167
+ else:
168
+ return f"❓ **Unknown Status**\n\n**Status:** {status}\n**Response:** {json.dumps(result, indent=2)}"
169
+
170
+ else:
171
+ return f"❌ HTTP Error {response.status_code}: {response.text}"
172
+
173
+ except requests.exceptions.Timeout:
174
+ return "❌ Error: Request timed out"
175
+ except requests.exceptions.RequestException as e:
176
+ return f"❌ Error: Failed to connect to status service: {str(e)}"
177
+ except json.JSONDecodeError:
178
+ return "❌ Error: Invalid response from server"
179
+
180
+ def check_and_update_status(job_id: str) -> str:
181
+ """
182
+ Wrapper function to check job status for Gradio interface.
183
+
184
+ This is a simple wrapper around check_job_status that provides the same functionality
185
+ but is specifically designed for use with Gradio button callbacks.
186
+
187
+ Parameters:
188
+ - job_id (str, required): The unique job identifier from training
189
+
190
+ Returns:
191
+ - str: Status message from check_job_status function
192
+
193
+ Example usage:
194
+ status = check_and_update_status("job_12345abcdef")
195
+ """
196
+ return check_job_status(job_id)
197
+
198
+ # Create simplified single-page Gradio interface
199
+ with gr.Blocks(title="FluxFoundry LoRA Training", theme=gr.themes.Soft()) as app:
200
+ gr.Markdown("""
201
+ # 🎨 FluxFoundry LoRA Training
202
+
203
+ Train custom LoRA models for Flux image generation and check training status.
204
+ """)
205
+
206
+ # Training Section
207
+ gr.Markdown("## 🚀 Start Training")
208
+
209
+ with gr.Row():
210
+ with gr.Column():
211
+ dataset_id = gr.Textbox(
212
+ label="HuggingFace Dataset ID",
213
+ placeholder="username/dataset-name",
214
+ info="The HuggingFace dataset containing your training images"
215
+ )
216
+ hf_token = gr.Textbox(
217
+ label="HuggingFace Token",
218
+ placeholder="hf_...",
219
+ type="password",
220
+ info="Your HuggingFace access token with read permissions"
221
+ )
222
+ output_repo = gr.Textbox(
223
+ label="Output Repository",
224
+ placeholder="username/my-lora-model",
225
+ info="HuggingFace repository where the trained LoRA will be uploaded"
226
+ )
227
+
228
+ with gr.Column():
229
+ instance_name = gr.Textbox(
230
+ label="Instance Name (Optional)",
231
+ placeholder="subject",
232
+ info="Name of the subject being trained (e.g., 'Fluffy', 'MyDog')"
233
+ )
234
+ class_name = gr.Textbox(
235
+ label="Class Name (Optional)",
236
+ placeholder="person",
237
+ info="Class of the subject (e.g., 'person', 'dog', 'cat')"
238
+ )
239
+ max_train_steps = gr.Slider(
240
+ minimum=100,
241
+ maximum=2000,
242
+ value=500,
243
+ step=50,
244
+ label="Max Training Steps",
245
+ info="Number of training steps (more steps = longer training)"
246
+ )
247
+
248
+ start_btn = gr.Button("🚀 Start Training", variant="primary", size="lg")
249
+
250
+ with gr.Row():
251
+ training_output = gr.Markdown(label="Training Status")
252
+ job_id_output = gr.Textbox(
253
+ label="Job ID",
254
+ placeholder="Copy this ID to check status",
255
+ interactive=False
256
+ )
257
+
258
+ start_btn.click(
259
+ fn=start_training,
260
+ inputs=[dataset_id, hf_token, output_repo, instance_name, class_name, max_train_steps],
261
+ outputs=[training_output, job_id_output]
262
+ )
263
+
264
+ # Status Section
265
+ gr.Markdown("## 📊 Check Status")
266
+
267
+ job_id_input = gr.Textbox(
268
+ label="Job ID",
269
+ placeholder="Paste your job ID here",
270
+ info="The Job ID returned when you started training"
271
+ )
272
+
273
+ with gr.Row():
274
+ status_btn = gr.Button("📊 Check Status", variant="secondary")
275
+ refresh_btn = gr.Button("🔄 Refresh", variant="secondary")
276
+
277
+ status_output = gr.Markdown(label="Job Status")
278
+
279
+ status_btn.click(
280
+ fn=check_and_update_status,
281
+ inputs=[job_id_input],
282
+ outputs=[status_output]
283
+ )
284
+
285
+ refresh_btn.click(
286
+ fn=check_and_update_status,
287
+ inputs=[job_id_input],
288
+ outputs=[status_output]
289
+ )
290
+
291
+ if __name__ == "__main__":
292
+ print("🎨 Starting FluxFoundry Training Interface...")
293
+ print(f"📡 Modal API Base URL: {MODAL_BASE_URL}")
294
+ print("⚠️ Make sure to update the MODAL_BASE_URL in the code with your actual Modal deployment URL")
295
+
296
+ app.launch(
297
+ server_name="0.0.0.0",
298
+ server_port=7860,
299
+ share=True,
300
+ show_error=True,
301
+ mcp_server=True
302
+ )
303
+
diffusers_lora_finetune.py CHANGED
@@ -34,6 +34,7 @@
34
 
35
  from dataclasses import dataclass
36
  from pathlib import Path
 
37
 
38
  import modal
39
 
@@ -52,11 +53,12 @@ app = modal.App(name="jason-lora-flux")
52
 
53
  image = modal.Image.debian_slim(python_version="3.10").pip_install(
54
  "accelerate==0.31.0",
55
- "datasets~=2.13.0",
 
56
  "fastapi[standard]==0.115.4",
57
  "ftfy~=6.1.0",
58
  "gradio~=5.5.0",
59
- "huggingface-hub==0.26.2",
60
  "hf_transfer==0.1.8",
61
  "numpy<2",
62
  "peft==0.11.1",
@@ -184,6 +186,325 @@ def load_images(image_urls: list[str]) -> Path:
184
  return img_path
185
 
186
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
  # ## Low-Rank Adapation (LoRA) fine-tuning for a text-to-image model
188
 
189
  # The base model we start from is trained to do a sort of "reverse [ekphrasis](https://en.wikipedia.org/wiki/Ekphrasis)":
 
34
 
35
  from dataclasses import dataclass
36
  from pathlib import Path
37
+ from typing import Optional
38
 
39
  import modal
40
 
 
53
 
54
  image = modal.Image.debian_slim(python_version="3.10").pip_install(
55
  "accelerate==0.31.0",
56
+ "datasets==3.6.0",
57
+ "pillow",
58
  "fastapi[standard]==0.115.4",
59
  "ftfy~=6.1.0",
60
  "gradio~=5.5.0",
61
+ "huggingface-hub==0.32.4",
62
  "hf_transfer==0.1.8",
63
  "numpy<2",
64
  "peft==0.11.1",
 
186
  return img_path
187
 
188
 
189
+ def load_images_from_hf_dataset(dataset_id: str, hf_token: str) -> Path:
190
+ """Load images from a HuggingFace dataset."""
191
+ import PIL.Image
192
+ from datasets import load_dataset
193
+
194
+ img_path = Path("/img")
195
+ img_path.mkdir(parents=True, exist_ok=True)
196
+
197
+ # Load dataset from HuggingFace
198
+ dataset = load_dataset(dataset_id, token=hf_token, split="train")
199
+
200
+ for ii, example in enumerate(dataset):
201
+ # Assume the dataset has an 'image' column
202
+ if 'image' in example:
203
+ image = example['image']
204
+ if isinstance(image, PIL.Image.Image):
205
+ image.save(img_path / f"{ii}.png")
206
+ else:
207
+ # Handle other image formats
208
+ pil_image = PIL.Image.open(image)
209
+ pil_image.save(img_path / f"{ii}.png")
210
+ else:
211
+ print(f"Warning: No 'image' field found in dataset example {ii}")
212
+
213
+ print(f"{len(dataset)} images loaded from HuggingFace dataset")
214
+ return img_path
215
+
216
+
217
+ # ## Stateless API Training Function
218
+
219
+ @dataclass
220
+ class APITrainConfig:
221
+ """Configuration for the API training function."""
222
+
223
+ # Basic model info
224
+ model_name: str = "black-forest-labs/FLUX.1-dev"
225
+
226
+ # Training prompt components
227
+ instance_name: str = "subject"
228
+ class_name: str = "person"
229
+ prefix: str = "a photo of"
230
+ postfix: str = ""
231
+
232
+ # Training hyperparameters
233
+ resolution: int = 512
234
+ train_batch_size: int = 3
235
+ rank: int = 16 # lora rank
236
+ gradient_accumulation_steps: int = 1
237
+ learning_rate: float = 4e-4
238
+ lr_scheduler: str = "constant"
239
+ lr_warmup_steps: int = 0
240
+ max_train_steps: int = 500
241
+ checkpointing_steps: int = 1000
242
+ seed: int = 117
243
+
244
+
245
+ @app.function(
246
+ image=image,
247
+ gpu="A100-80GB", # fine-tuning is VRAM-heavy and requires a high-VRAM GPU
248
+ timeout=3600, # 60 minutes
249
+ )
250
+ def train_lora_stateless(
251
+ dataset_id: str,
252
+ hf_token: str,
253
+ output_repo: str,
254
+ instance_name: Optional[str] = None,
255
+ class_name: Optional[str] = None,
256
+ max_train_steps: int = 500,
257
+ ):
258
+ """
259
+ Stateless LoRA training function that reads from HF dataset and uploads to HF repo.
260
+
261
+ Args:
262
+ dataset_id: HuggingFace dataset ID (e.g., "username/dataset-name")
263
+ hf_token: HuggingFace API token
264
+ output_repo: HuggingFace repository to upload the trained LoRA to
265
+ instance_name: Name of the subject (optional, defaults to "subject")
266
+ class_name: Class of the subject (optional, defaults to "person")
267
+ max_train_steps: Number of training steps
268
+ """
269
+ import subprocess
270
+ import tempfile
271
+ from pathlib import Path
272
+
273
+ import torch
274
+ from accelerate.utils import write_basic_config
275
+ from diffusers import DiffusionPipeline
276
+ from huggingface_hub import snapshot_download, upload_folder, login, create_repo
277
+
278
+ # Login to HuggingFace
279
+ login(token=hf_token)
280
+
281
+ # Create temporary directories
282
+ with tempfile.TemporaryDirectory() as temp_dir:
283
+ temp_path = Path(temp_dir)
284
+ model_dir = temp_path / "model"
285
+ output_dir = temp_path / "output"
286
+
287
+ # Download base model
288
+ print("📥 Downloading base model...")
289
+ snapshot_download(
290
+ "black-forest-labs/FLUX.1-dev",
291
+ local_dir=str(model_dir),
292
+ ignore_patterns=["*.pt", "*.bin"], # using safetensors
293
+ token=hf_token
294
+ )
295
+
296
+ # Load and validate model
297
+ DiffusionPipeline.from_pretrained(str(model_dir), torch_dtype=torch.bfloat16)
298
+ print("✅ Base model loaded successfully")
299
+
300
+ # Load training images from HF dataset
301
+ print(f"📥 Loading images from dataset: {dataset_id}")
302
+ img_path = load_images_from_hf_dataset(dataset_id, hf_token)
303
+
304
+ # Set up training configuration
305
+ config = APITrainConfig(
306
+ instance_name=instance_name or "subject",
307
+ class_name=class_name or "person",
308
+ max_train_steps=max_train_steps
309
+ )
310
+
311
+ # Set up hugging face accelerate library for fast training
312
+ write_basic_config(mixed_precision="bf16")
313
+
314
+ # Define the training prompt
315
+ instance_phrase = f"{config.instance_name} the {config.class_name}"
316
+ prompt = f"{config.prefix} {instance_phrase} {config.postfix}".strip()
317
+
318
+ print(f"🎯 Training prompt: {prompt}")
319
+ print(f"🚀 Starting training for {max_train_steps} steps...")
320
+
321
+ # Execute training subprocess
322
+ def _exec_subprocess(cmd: list[str]):
323
+ """Executes subprocess and prints log to terminal while subprocess is running."""
324
+ process = subprocess.Popen(
325
+ cmd,
326
+ stdout=subprocess.PIPE,
327
+ stderr=subprocess.STDOUT,
328
+ )
329
+ with process.stdout as pipe:
330
+ for line in iter(pipe.readline, b""):
331
+ line_str = line.decode()
332
+ print(f"{line_str}", end="")
333
+
334
+ if exitcode := process.wait() != 0:
335
+ raise subprocess.CalledProcessError(exitcode, "\n".join(cmd))
336
+
337
+ # Run training
338
+ _exec_subprocess([
339
+ "accelerate",
340
+ "launch",
341
+ "examples/dreambooth/train_dreambooth_lora_flux.py",
342
+ "--mixed_precision=bf16",
343
+ f"--pretrained_model_name_or_path={model_dir}",
344
+ f"--instance_data_dir={img_path}",
345
+ f"--output_dir={output_dir}",
346
+ f"--instance_prompt={prompt}",
347
+ f"--resolution={config.resolution}",
348
+ f"--train_batch_size={config.train_batch_size}",
349
+ f"--gradient_accumulation_steps={config.gradient_accumulation_steps}",
350
+ f"--learning_rate={config.learning_rate}",
351
+ f"--lr_scheduler={config.lr_scheduler}",
352
+ f"--lr_warmup_steps={config.lr_warmup_steps}",
353
+ f"--max_train_steps={config.max_train_steps}",
354
+ f"--checkpointing_steps={config.checkpointing_steps}",
355
+ f"--seed={config.seed}",
356
+ ])
357
+
358
+ print("✅ Training completed!")
359
+
360
+ # Upload trained LoRA to HuggingFace repository
361
+
362
+ print(f"📤 Uploading LoRA to repository: {output_repo}")
363
+
364
+ # Create repository if it doesn't exist
365
+ create_repo(
366
+ repo_id=output_repo,
367
+ repo_type="model",
368
+ token=hf_token
369
+ )
370
+
371
+ # print contents of output_dir
372
+ print(f"Contents of {output_dir}:")
373
+ for file in output_dir.iterdir():
374
+ print(file)
375
+
376
+ upload_folder(
377
+ folder_path=str(output_dir),
378
+ repo_id=output_repo,
379
+ repo_type="model",
380
+ token=hf_token,
381
+ commit_message=f"Add LoRA trained on {dataset_id}",
382
+ )
383
+
384
+ print(f"🎉 Successfully uploaded LoRA to {output_repo}")
385
+
386
+ return {
387
+ "status": "success",
388
+ "message": f"LoRA training completed and uploaded to {output_repo}",
389
+ "dataset_used": dataset_id,
390
+ "training_steps": max_train_steps,
391
+ "training_prompt": prompt
392
+ }
393
+
394
+
395
+ # ## API Endpoints with Job ID System
396
+
397
+ @app.function(
398
+ image=image,
399
+ keep_warm=1, # Keep one container warm for faster response
400
+ )
401
+ @modal.fastapi_endpoint(method="POST")
402
+ def api_start_training(item: dict):
403
+ """
404
+ Start LoRA training and return a job ID.
405
+
406
+ Expected JSON payload:
407
+ {
408
+ "dataset_id": "username/dataset-name",
409
+ "hf_token": "hf_...",
410
+ "output_repo": "username/output-repo",
411
+ "instance_name": "optional_subject_name",
412
+ "class_name": "optional_class_name",
413
+ "max_train_steps": 500
414
+ }
415
+ """
416
+ try:
417
+ # Extract required parameters
418
+ dataset_id = item["dataset_id"]
419
+ hf_token = item["hf_token"]
420
+ output_repo = item["output_repo"]
421
+
422
+ # Extract optional parameters
423
+ instance_name = item.get("instance_name")
424
+ class_name = item.get("class_name")
425
+ max_train_steps = item.get("max_train_steps", 500)
426
+
427
+ # Start training (non-blocking)
428
+ call_handle = train_lora_stateless.spawn(
429
+ dataset_id=dataset_id,
430
+ hf_token=hf_token,
431
+ output_repo=output_repo,
432
+ instance_name=instance_name,
433
+ class_name=class_name,
434
+ max_train_steps=max_train_steps
435
+ )
436
+
437
+ job_id = call_handle.object_id
438
+
439
+ return {
440
+ "status": "started",
441
+ "job_id": job_id,
442
+ "message": "Training job started successfully",
443
+ "dataset_id": dataset_id,
444
+ "output_repo": output_repo,
445
+ "max_train_steps": max_train_steps
446
+ }
447
+
448
+ except KeyError as e:
449
+ return {
450
+ "status": "error",
451
+ "message": f"Missing required parameter: {e}"
452
+ }
453
+ except Exception as e:
454
+ return {
455
+ "status": "error",
456
+ "message": f"Failed to start training: {str(e)}"
457
+ }
458
+
459
+
460
+ @app.function(
461
+ image=image,
462
+ keep_warm=1,
463
+ )
464
+ @modal.fastapi_endpoint(method="GET")
465
+ def api_job_status(job_id: str):
466
+ """
467
+ Check the status of a training job.
468
+ Pass job_id as a query parameter: /job_status?job_id=xyz
469
+ """
470
+ try:
471
+ from modal.functions import FunctionCall
472
+
473
+ # Get the function call handle
474
+ call_handle = FunctionCall.from_id(job_id)
475
+
476
+ if call_handle is None:
477
+ return {
478
+ "status": "error",
479
+ "message": "Job not found"
480
+ }
481
+
482
+ # Check if the job is finished
483
+ try:
484
+ result = call_handle.get(timeout=0) # Non-blocking check
485
+ return {
486
+ "status": "completed",
487
+ "result": result
488
+ }
489
+ except TimeoutError:
490
+ return {
491
+ "status": "running",
492
+ "message": "Job is still running"
493
+ }
494
+ except Exception as e:
495
+ return {
496
+ "status": "failed",
497
+ "message": f"Job failed: {str(e)}"
498
+ }
499
+
500
+ except Exception as e:
501
+ return {
502
+ "status": "error",
503
+ "message": f"Error checking job status: {str(e)}"
504
+ }
505
+
506
+
507
+
508
  # ## Low-Rank Adapation (LoRA) fine-tuning for a text-to-image model
509
 
510
  # The base model we start from is trained to do a sort of "reverse [ekphrasis](https://en.wikipedia.org/wiki/Ekphrasis)":
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ gradio>=4.0.0
2
+ requests>=2.25.0