multimodalart HF Staff commited on
Commit
4400ddc
·
verified ·
1 Parent(s): f1ba847

Upload 103 files

Browse files
ui/src/app/api/datasets/create/route.tsx CHANGED
@@ -3,22 +3,80 @@ import fs from 'fs';
3
  import path from 'path';
4
  import { getDatasetsRoot } from '@/server/settings';
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  export async function POST(request: Request) {
7
  try {
8
  const body = await request.json();
9
- let { name } = body;
10
- // clean name by making lower case, removing special characters, and replacing spaces with underscores
11
- name = name.toLowerCase().replace(/[^a-z0-9]+/g, '_');
12
-
13
- let datasetsPath = await getDatasetsRoot();
14
- let datasetPath = path.join(datasetsPath, name);
15
 
16
- // if folder doesnt exist, create it
17
- if (!fs.existsSync(datasetPath)) {
18
- fs.mkdirSync(datasetPath, { recursive: true });
19
  }
20
 
21
- return NextResponse.json({ success: true, name: name, path: datasetPath });
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  } catch (error: any) {
23
  console.error('Dataset create error:', error);
24
  return NextResponse.json({ error: error?.message || 'Failed to create dataset' }, { status: 500 });
 
3
  import path from 'path';
4
  import { getDatasetsRoot } from '@/server/settings';
5
 
6
+ const sanitizeSegment = (value: string): string => {
7
+ if (!value || typeof value !== 'string') {
8
+ return '';
9
+ }
10
+ return value
11
+ .toLowerCase()
12
+ .replace(/[^a-z0-9]+/g, '_')
13
+ .replace(/^_+|_+$/g, '');
14
+ };
15
+
16
+ const ensureDirectory = (dirPath: string) => {
17
+ if (!fs.existsSync(dirPath)) {
18
+ fs.mkdirSync(dirPath, { recursive: true });
19
+ }
20
+ };
21
+
22
+ const resolveDatasetName = (rootPath: string, desiredName: string, namespace?: string | null) => {
23
+ const baseName = sanitizeSegment(desiredName) || 'dataset';
24
+ const namespaceSuffix = sanitizeSegment(namespace || '');
25
+
26
+ const datasetExists = (candidate: string) => fs.existsSync(path.join(rootPath, candidate));
27
+
28
+ if (!datasetExists(baseName)) {
29
+ return { name: baseName, path: path.join(rootPath, baseName) };
30
+ }
31
+
32
+ if (namespaceSuffix) {
33
+ let candidate = sanitizeSegment(`${baseName}_${namespaceSuffix}`) || `${baseName}_${namespaceSuffix}`;
34
+ let attempts = 0;
35
+ while (datasetExists(candidate)) {
36
+ attempts += 1;
37
+ if (attempts > 50) {
38
+ throw new Error('Unable to allocate unique dataset name');
39
+ }
40
+ candidate = sanitizeSegment(`${candidate}_${namespaceSuffix}`) || `${candidate}_${namespaceSuffix}`;
41
+ }
42
+ return { name: candidate, path: path.join(rootPath, candidate) };
43
+ }
44
+
45
+ let counter = 2;
46
+ while (counter < 1000) {
47
+ const candidate = sanitizeSegment(`${baseName}_${counter}`) || `${baseName}_${counter}`;
48
+ if (!datasetExists(candidate)) {
49
+ return { name: candidate, path: path.join(rootPath, candidate) };
50
+ }
51
+ counter += 1;
52
+ }
53
+
54
+ throw new Error('Unable to allocate unique dataset name');
55
+ };
56
+
57
  export async function POST(request: Request) {
58
  try {
59
  const body = await request.json();
60
+ const { name, namespace } = body ?? {};
 
 
 
 
 
61
 
62
+ if (!name || typeof name !== 'string') {
63
+ throw new Error('Dataset name is required');
 
64
  }
65
 
66
+ const datasetsPath = await getDatasetsRoot();
67
+ const { name: resolvedName, path: datasetPath } = resolveDatasetName(
68
+ datasetsPath,
69
+ name,
70
+ typeof namespace === 'string' ? namespace : null,
71
+ );
72
+
73
+ ensureDirectory(datasetPath);
74
+
75
+ return NextResponse.json({
76
+ success: true,
77
+ name: resolvedName,
78
+ path: datasetPath,
79
+ });
80
  } catch (error: any) {
81
  console.error('Dataset create error:', error);
82
  return NextResponse.json({ error: error?.message || 'Failed to create dataset' }, { status: 500 });
ui/src/app/api/hf-jobs/route.ts CHANGED
@@ -401,14 +401,13 @@ def upload_results(output_path: str, model_name: str, namespace: str, token: str
401
  import tempfile
402
  import shutil
403
  import glob
404
- import re
405
- import yaml
406
  from datetime import datetime
407
  from huggingface_hub import create_repo, upload_file, HfApi
408
-
 
409
  try:
410
  repo_id = f"{namespace}/{model_name}"
411
-
412
  # Create repository
413
  create_repo(repo_id=repo_id, token=token, exist_ok=True)
414
 
@@ -453,30 +452,86 @@ def upload_results(output_path: str, model_name: str, namespace: str, token: str
453
  uploaded_files.append(filename)
454
  config_files_uploaded.append(filename)
455
 
456
- # 2. Handle sample images
457
- samples_uploaded = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
458
  samples_dir = os.path.join(output_path, "samples")
459
- if os.path.isdir(samples_dir):
 
 
 
 
460
  print("Uploading sample images...")
461
- # Create samples directory in repo
462
- for filename in os.listdir(samples_dir):
463
- if filename.lower().endswith(('.jpg', '.jpeg', '.png', '.webp')):
464
- file_path = os.path.join(samples_dir, filename)
465
- repo_path = f"samples/{filename}"
466
- api.upload_file(
467
- path_or_fileobj=file_path,
468
- path_in_repo=repo_path,
469
- repo_id=repo_id,
470
- token=token
471
- )
472
- samples_uploaded.append(repo_path)
473
-
474
  # 3. Generate and upload README.md
475
  readme_content = generate_model_card_readme(
476
  repo_id=repo_id,
477
  config=config,
478
  model_name=model_name,
479
- samples_dir=samples_dir if os.path.isdir(samples_dir) else None,
480
  uploaded_files=uploaded_files
481
  )
482
 
@@ -500,12 +555,11 @@ def upload_results(output_path: str, model_name: str, namespace: str, token: str
500
  print(f"Failed to upload model: {e}")
501
  raise e
502
 
503
- def generate_model_card_readme(repo_id: str, config: dict, model_name: str, samples_dir: str = None, uploaded_files: list = None) -> str:
504
  """Generate README.md content for the model card based on AI Toolkit's implementation"""
505
- import re
506
  import yaml
507
  import os
508
-
509
  try:
510
  # Extract configuration details
511
  process_config = config.get("config", {}).get("process", [{}])[0]
@@ -545,40 +599,27 @@ def generate_model_card_readme(repo_id: str, config: dict, model_name: str, samp
545
  # Add LoRA-specific tags
546
  tags.extend(["lora", "diffusers", "template:sd-lora", "ai-toolkit"])
547
 
548
- # Generate widgets from sample images and prompts
 
 
549
  widgets = []
550
- if samples_dir and os.path.isdir(samples_dir):
551
- sample_prompts = sample_config.get("samples", [])
552
- if not sample_prompts:
553
- # Fallback to old format
554
- sample_prompts = [{"prompt": p} for p in sample_config.get("prompts", [])]
555
-
556
- # Get sample image files
557
- sample_files = []
558
- if os.path.isdir(samples_dir):
559
- for filename in os.listdir(samples_dir):
560
- if filename.lower().endswith(('.jpg', '.jpeg', '.png', '.webp')):
561
- # Parse filename pattern: timestamp__steps_index.jpg
562
- match = re.search(r"__(\d+)_(\d+)\.jpg$", filename)
563
- if match:
564
- steps, index = int(match.group(1)), int(match.group(2))
565
- # Only use samples from final training step
566
- final_steps = train_config.get("steps", 1000)
567
- if steps == final_steps:
568
- sample_files.append((index, f"samples/{filename}"))
569
-
570
- # Sort by index and create widgets
571
- sample_files.sort(key=lambda x: x[0])
572
-
573
- for i, prompt_obj in enumerate(sample_prompts):
574
- prompt = prompt_obj.get("prompt", "") if isinstance(prompt_obj, dict) else str(prompt_obj)
575
- if i < len(sample_files):
576
- _, image_path = sample_files[i]
577
- widgets.append({
578
- "text": prompt,
579
- "output": {"url": image_path}
580
- })
581
-
582
  # Determine torch dtype based on model
583
  dtype = "torch.bfloat16" if "flux" in arch.lower() else "torch.float16"
584
 
@@ -598,6 +639,16 @@ def generate_model_card_readme(repo_id: str, config: dict, model_name: str, samp
598
 
599
  if widgets:
600
  frontmatter["widget"] = widgets
 
 
 
 
 
 
 
 
 
 
601
 
602
  if trigger_word:
603
  frontmatter["instance_prompt"] = trigger_word
@@ -623,7 +674,7 @@ def generate_model_card_readme(repo_id: str, config: dict, model_name: str, samp
623
 
624
  Model trained with [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit)
625
 
626
- <Gallery />
627
 
628
  ## Trigger words
629
 
 
401
  import tempfile
402
  import shutil
403
  import glob
 
 
404
  from datetime import datetime
405
  from huggingface_hub import create_repo, upload_file, HfApi
406
+ from collections import deque
407
+
408
  try:
409
  repo_id = f"{namespace}/{model_name}"
410
+
411
  # Create repository
412
  create_repo(repo_id=repo_id, token=token, exist_ok=True)
413
 
 
452
  uploaded_files.append(filename)
453
  config_files_uploaded.append(filename)
454
 
455
+ def prepare_sample_metadata(samples_directory: str, sample_conf: dict):
456
+ if not samples_directory or not os.path.isdir(samples_directory):
457
+ return [], []
458
+
459
+ allowed_ext = {'.jpg', '.jpeg', '.png', '.webp'}
460
+ image_records = []
461
+ for root, _, files in os.walk(samples_directory):
462
+ for filename in files:
463
+ ext = os.path.splitext(filename)[1].lower()
464
+ if ext not in allowed_ext:
465
+ continue
466
+ abs_path = os.path.join(root, filename)
467
+ try:
468
+ mtime = os.path.getmtime(abs_path)
469
+ except Exception:
470
+ mtime = 0
471
+ image_records.append((abs_path, mtime))
472
+
473
+ if not image_records:
474
+ return [], []
475
+
476
+ image_records.sort(key=lambda item: (-item[1], item[0]))
477
+ image_queue = deque(image_records)
478
+
479
+ samples_list = sample_conf.get("samples", []) if sample_conf else []
480
+ if not samples_list:
481
+ legacy = sample_conf.get("prompts", []) if sample_conf else []
482
+ samples_list = [{"prompt": prompt} for prompt in legacy if prompt]
483
+
484
+ curated_samples = []
485
+ for sample in samples_list:
486
+ prompt = None
487
+ if isinstance(sample, dict):
488
+ prompt = sample.get("prompt")
489
+ elif isinstance(sample, str):
490
+ prompt = sample
491
+
492
+ if not prompt:
493
+ continue
494
+
495
+ if not image_queue:
496
+ break
497
+
498
+ image_path, _ = image_queue.popleft()
499
+ repo_rel_path = f"images/{os.path.basename(image_path)}"
500
+ curated_samples.append({
501
+ "prompt": prompt,
502
+ "local_path": image_path,
503
+ "repo_path": repo_rel_path,
504
+ })
505
+
506
+ all_files = [record[0] for record in image_records]
507
+ return curated_samples, all_files
508
+
509
  samples_dir = os.path.join(output_path, "samples")
510
+ sample_config = config.get("config", {}).get("process", [{}])[0].get("sample", {})
511
+ curated_samples, sample_files = prepare_sample_metadata(samples_dir, sample_config)
512
+
513
+ samples_uploaded = []
514
+ if sample_files:
515
  print("Uploading sample images...")
516
+ for file_path in sample_files:
517
+ if not os.path.isfile(file_path):
518
+ continue
519
+ filename = os.path.basename(file_path)
520
+ repo_path = f"images/{filename}"
521
+ api.upload_file(
522
+ path_or_fileobj=file_path,
523
+ path_in_repo=repo_path,
524
+ repo_id=repo_id,
525
+ token=token
526
+ )
527
+ samples_uploaded.append(repo_path)
528
+
529
  # 3. Generate and upload README.md
530
  readme_content = generate_model_card_readme(
531
  repo_id=repo_id,
532
  config=config,
533
  model_name=model_name,
534
+ curated_samples=curated_samples,
535
  uploaded_files=uploaded_files
536
  )
537
 
 
555
  print(f"Failed to upload model: {e}")
556
  raise e
557
 
558
+ def generate_model_card_readme(repo_id: str, config: dict, model_name: str, curated_samples: list = None, uploaded_files: list = None) -> str:
559
  """Generate README.md content for the model card based on AI Toolkit's implementation"""
 
560
  import yaml
561
  import os
562
+
563
  try:
564
  # Extract configuration details
565
  process_config = config.get("config", {}).get("process", [{}])[0]
 
599
  # Add LoRA-specific tags
600
  tags.extend(["lora", "diffusers", "template:sd-lora", "ai-toolkit"])
601
 
602
+ # Generate widgets and gallery section from sample images
603
+ curated_samples = curated_samples or []
604
+
605
  widgets = []
606
+ prompt_bullets = []
607
+ for sample in curated_samples:
608
+ prompt_text = str(sample.get("prompt", "")).strip()
609
+ repo_path = sample.get("repo_path")
610
+ if not prompt_text or not repo_path:
611
+ continue
612
+ widgets.append({
613
+ "text": prompt_text,
614
+ "output": {"url": repo_path}
615
+ })
616
+ prompt_md = prompt_text.replace("`", "\`")
617
+ prompt_bullets.append(f"- `{prompt_md}`")
618
+
619
+ gallery_section = "<Gallery />\n\n"
620
+ if prompt_bullets:
621
+ gallery_section += "### Prompts\n\n" + "\n".join(prompt_bullets) + "\n\n"
622
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
623
  # Determine torch dtype based on model
624
  dtype = "torch.bfloat16" if "flux" in arch.lower() else "torch.float16"
625
 
 
639
 
640
  if widgets:
641
  frontmatter["widget"] = widgets
642
+
643
+ inference_params = {}
644
+ sample_width = sample_config.get("width") if isinstance(sample_config, dict) else None
645
+ sample_height = sample_config.get("height") if isinstance(sample_config, dict) else None
646
+ if sample_width:
647
+ inference_params["width"] = sample_width
648
+ if sample_height:
649
+ inference_params["height"] = sample_height
650
+ if inference_params:
651
+ frontmatter["inference"] = {"parameters": inference_params}
652
 
653
  if trigger_word:
654
  frontmatter["instance_prompt"] = trigger_word
 
674
 
675
  Model trained with [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit)
676
 
677
+ {gallery_section}
678
 
679
  ## Trigger words
680
 
ui/src/app/datasets/page.tsx CHANGED
@@ -22,7 +22,7 @@ export default function Datasets() {
22
  const { datasets, status, refreshDatasets } = useDatasetList();
23
  const [newDatasetName, setNewDatasetName] = useState('');
24
  const [isNewDatasetModalOpen, setIsNewDatasetModalOpen] = useState(false);
25
- const { status: authStatus } = useAuth();
26
  const isAuthenticated = authStatus === 'authenticated';
27
 
28
  // Transform datasets array into rows with objects
@@ -85,7 +85,9 @@ export default function Datasets() {
85
  return;
86
  }
87
  try {
88
- const data = await apiClient.post('/api/datasets/create', { name: newDatasetName }).then(res => res.data);
 
 
89
  console.log('New dataset created:', data);
90
  if (usingBrowserDb && data?.name) {
91
  addUserDataset(data.name, data?.path || '');
@@ -117,7 +119,9 @@ export default function Datasets() {
117
  return;
118
  }
119
  try {
120
- const data = await apiClient.post('/api/datasets/create', { name }).then(res => res.data);
 
 
121
  console.log('New dataset created:', data);
122
  if (usingBrowserDb && data?.name) {
123
  addUserDataset(data.name, data?.path || '');
 
22
  const { datasets, status, refreshDatasets } = useDatasetList();
23
  const [newDatasetName, setNewDatasetName] = useState('');
24
  const [isNewDatasetModalOpen, setIsNewDatasetModalOpen] = useState(false);
25
+ const { status: authStatus, namespace } = useAuth();
26
  const isAuthenticated = authStatus === 'authenticated';
27
 
28
  // Transform datasets array into rows with objects
 
85
  return;
86
  }
87
  try {
88
+ const data = await apiClient
89
+ .post('/api/datasets/create', { name: newDatasetName, namespace })
90
+ .then(res => res.data);
91
  console.log('New dataset created:', data);
92
  if (usingBrowserDb && data?.name) {
93
  addUserDataset(data.name, data?.path || '');
 
119
  return;
120
  }
121
  try {
122
+ const data = await apiClient
123
+ .post('/api/datasets/create', { name, namespace })
124
+ .then(res => res.data);
125
  console.log('New dataset created:', data);
126
  if (usingBrowserDb && data?.name) {
127
  addUserDataset(data.name, data?.path || '');
ui/src/app/jobs/new/jobConfig.ts CHANGED
@@ -74,7 +74,7 @@ export const defaultJobConfig: JobConfig = {
74
  use_ema: false,
75
  ema_decay: 0.99,
76
  },
77
- skip_first_sample: false,
78
  disable_sampling: false,
79
  dtype: 'bf16',
80
  diff_output_preservation: false,
@@ -94,7 +94,7 @@ export const defaultJobConfig: JobConfig = {
94
  },
95
  sample: {
96
  sampler: 'flowmatch',
97
- sample_every: 250,
98
  width: 1024,
99
  height: 1024,
100
  samples: [
 
74
  use_ema: false,
75
  ema_decay: 0.99,
76
  },
77
+ skip_first_sample: true,
78
  disable_sampling: false,
79
  dtype: 'bf16',
80
  diff_output_preservation: false,
 
94
  },
95
  sample: {
96
  sampler: 'flowmatch',
97
+ sample_every: 1500,
98
  width: 1024,
99
  height: 1024,
100
  samples: [
ui/src/app/jobs/new/page.tsx CHANGED
@@ -31,7 +31,7 @@ export default function TrainingForm() {
31
  const router = useRouter();
32
  const searchParams = useSearchParams();
33
  const runId = searchParams.get('id');
34
- const { status: authStatus } = useAuth();
35
  const isAuthenticated = authStatus === 'authenticated';
36
  const [gpuIDs, setGpuIDs] = useState<string | null>(null);
37
  const { settings, isSettingsLoaded } = useSettings();
@@ -67,7 +67,7 @@ export default function TrainingForm() {
67
  } else {
68
  try {
69
  const response = await apiClient
70
- .post('/api/datasets/create', { name })
71
  .then(res => res.data);
72
  if (response?.path) {
73
  datasetPath = response.path;
 
31
  const router = useRouter();
32
  const searchParams = useSearchParams();
33
  const runId = searchParams.get('id');
34
+ const { status: authStatus, namespace: authNamespace } = useAuth();
35
  const isAuthenticated = authStatus === 'authenticated';
36
  const [gpuIDs, setGpuIDs] = useState<string | null>(null);
37
  const { settings, isSettingsLoaded } = useSettings();
 
67
  } else {
68
  try {
69
  const response = await apiClient
70
+ .post('/api/datasets/create', { name, namespace: authNamespace })
71
  .then(res => res.data);
72
  if (response?.path) {
73
  datasetPath = response.path;