multimodalart HF Staff committed on
Commit
5675b78
·
verified ·
1 Parent(s): 63037c8

attempt local dataset fix

Browse files
Files changed (1) hide show
  1. ui/src/app/api/hf-jobs/route.ts +108 -74
ui/src/app/api/hf-jobs/route.ts CHANGED
@@ -162,104 +162,138 @@ def setup_ai_toolkit():
162
  sys.path.insert(0, os.path.abspath(repo_dir))
163
  return repo_dir
164
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
def download_dataset(dataset_repo: str, local_path: str):
    """Download a dataset from the HF Hub into local_path as file pairs.

    First tries to load ``dataset_repo`` as a structured ``datasets`` dataset;
    on any failure, falls back to snapshot-downloading the raw repo files.
    Images are written as ``image_NNNNNN.jpg`` and captions as
    ``image_NNNNNN.txt`` (paired positionally).
    """
    # Hoisted out of the except branch so all dependencies are visible up front.
    import glob
    import shutil

    from huggingface_hub import snapshot_download

    def _to_rgb(image):
        # JPEG cannot store alpha; composite RGBA onto a white background,
        # and convert any other non-JPEG-safe mode to RGB.
        if image.mode == 'RGBA':
            background = Image.new('RGB', image.size, (255, 255, 255))
            background.paste(image, mask=image.split()[-1])  # alpha channel as mask
            return background
        if image.mode not in ['RGB', 'L']:
            return image.convert('RGB')
        return image

    print(f"Downloading dataset from {dataset_repo}...")
    os.makedirs(local_path, exist_ok=True)

    try:
        # Preferred path: the repo is a structured dataset with image/text columns.
        dataset = load_dataset(dataset_repo, split="train")

        for i, item in enumerate(dataset):
            if "image" in item:
                image_path = os.path.join(local_path, f"image_{i:06d}.jpg")
                _to_rgb(item["image"]).save(image_path, 'JPEG')

            if "text" in item:
                caption_path = os.path.join(local_path, f"image_{i:06d}.txt")
                with open(caption_path, "w", encoding="utf-8") as f:
                    f.write(item["text"])

        print(f"Downloaded {len(dataset)} items to {local_path}")

    except Exception as e:
        print(f"Failed to load as structured dataset: {e}")
        print("Attempting to download raw files...")

        # Fallback: grab the raw repository contents.
        temp_repo_path = snapshot_download(repo_id=dataset_repo, repo_type="dataset")

        print(f"Downloaded repo to: {temp_repo_path}")
        print(f"Contents: {os.listdir(temp_repo_path)}")

        image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.webp', '*.bmp', '*.JPG', '*.JPEG', '*.PNG']
        image_files = []
        for ext in image_extensions:
            pattern = os.path.join(temp_repo_path, "**", ext)
            found_files = glob.glob(pattern, recursive=True)
            image_files.extend(found_files)
            print(f"Pattern {pattern} found {len(found_files)} files")

        # Fix: glob order is filesystem-dependent; sort both lists so
        # image i and caption i pair up deterministically across runs.
        image_files.sort()
        text_files = sorted(glob.glob(os.path.join(temp_repo_path, "**", "*.txt"), recursive=True))

        print(f"Found {len(image_files)} image files and {len(text_files)} text files")

        # Re-encode every image as JPEG; skip unreadable files (best-effort).
        for i, img_file in enumerate(image_files):
            dest_path = os.path.join(local_path, f"image_{i:06d}.jpg")
            try:
                with Image.open(img_file) as image:
                    _to_rgb(image).save(dest_path, 'JPEG')
            except Exception as img_error:
                print(f"Error processing image {img_file}: {img_error}")
                continue

        # Copy at most one caption per image, paired positionally.
        for i, txt_file in enumerate(text_files[:len(image_files)]):
            dest_path = os.path.join(local_path, f"image_{i:06d}.txt")
            try:
                shutil.copy2(txt_file, dest_path)
            except Exception as txt_error:
                print(f"Error copying text file {txt_file}: {txt_error}")
                continue

        print(f"Downloaded {len(image_files)} images and {len(text_files)} captions to {local_path}")
263
 
264
  def create_config(dataset_path: str, output_path: str):
265
  """Create training configuration"""
@@ -759,4 +793,4 @@ async function checkHFJobStatus(token: string, jobId: string): Promise<any> {
759
  reject(new Error(`Process error: ${err.message}`));
760
  });
761
  });
762
- }
 
162
  sys.path.insert(0, os.path.abspath(repo_dir))
163
  return repo_dir
164
 
165
def find_local_dataset_source(dataset_repo: str):
    """Return an existing local path for ``dataset_repo``, or None.

    Candidate locations, checked in order:
      1. the value exactly as given;
      2. its absolute form (only for relative paths);
      3. ``/datasets/<normalized repo id>`` (the mounted-dataset convention).
    """
    if not dataset_repo:
        return None

    repo_stripped = dataset_repo.strip()

    # Build the ordered candidate list. The original if/else appended
    # repo_stripped on both branches, and its extra '/datasets/...' check was
    # dead code (such values are absolute and already first in the list) —
    # both simplified here with identical behavior.
    candidates = [repo_stripped]
    if not os.path.isabs(repo_stripped):
        candidates.append(os.path.abspath(repo_stripped))

    normalized = normalize_repo_id(repo_stripped)
    if normalized:
        candidates.append(os.path.join("/datasets", normalized))

    # De-duplicate while preserving order; return the first path that exists.
    seen = set()
    for candidate in candidates:
        if not candidate or candidate in seen:
            continue
        seen.add(candidate)
        if os.path.exists(candidate):
            return candidate

    return None
194
+
195
+
196
def normalize_repo_id(dataset_repo: str) -> str:
    """Reduce a dataset reference to a bare ``owner/name`` repo id.

    Drops a leading '/datasets/' or 'datasets/' prefix (at most one) plus any
    surrounding whitespace and stray slashes.
    """
    repo_id = dataset_repo.strip()
    for prefix in ("/datasets/", "datasets/"):
        if repo_id.startswith(prefix):
            repo_id = repo_id[len(prefix):]
            break  # strip at most one prefix, mirroring the if/elif chain
    return repo_id.strip("/")
203
+
204
+
205
def copy_dataset_files(source_dir: str, local_path: str):
    """Re-encode images from source_dir into local_path and copy captions.

    Every image found (recursively) becomes ``image_NNNNNN.jpg``; ``.txt``
    files become ``image_NNNNNN.txt``. Pairing is positional, so both lists
    are sorted — raw glob order is filesystem-dependent and would otherwise
    mismatch captions across runs.
    """
    print(f"Collecting data files from {source_dir}")

    image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.webp', '*.bmp', '*.JPG', '*.JPEG', '*.PNG']
    image_files = []
    for ext in image_extensions:
        pattern = os.path.join(source_dir, "**", ext)
        found_files = glob.glob(pattern, recursive=True)
        image_files.extend(found_files)
        print(f"Pattern {pattern} found {len(found_files)} files")

    # Fix: sort so image i lines up with caption i deterministically.
    image_files.sort()
    text_files = sorted(glob.glob(os.path.join(source_dir, "**", "*.txt"), recursive=True))

    print(f"Found {len(image_files)} image files and {len(text_files)} text files")

    for i, img_file in enumerate(image_files):
        dest_path = os.path.join(local_path, f"image_{i:06d}.jpg")
        try:
            with Image.open(img_file) as image:
                if image.mode == 'RGBA':
                    # JPEG has no alpha: composite onto white via the alpha channel.
                    background = Image.new('RGB', image.size, (255, 255, 255))
                    background.paste(image, mask=image.split()[-1])
                    image = background
                elif image.mode not in ['RGB', 'L']:
                    image = image.convert('RGB')

                image.save(dest_path, 'JPEG')
        except Exception as img_error:
            # Best-effort: skip unreadable images instead of aborting the copy.
            print(f"Error processing image {img_file}: {img_error}")
            continue

    # Never copy more captions than there are images.
    captions_to_copy = min(len(text_files), len(image_files))
    for i, txt_file in enumerate(text_files[:captions_to_copy]):
        dest_path = os.path.join(local_path, f"image_{i:06d}.txt")
        try:
            shutil.copy2(txt_file, dest_path)
        except Exception as txt_error:
            print(f"Error copying text file {txt_file}: {txt_error}")
            continue

    print(f"Prepared {len(image_files)} images and {captions_to_copy} captions in {local_path}")
247
+
248
+
249
def download_dataset(dataset_repo: str, local_path: str):
    """Download dataset from HF Hub as files"""
    print(f"Downloading dataset from {dataset_repo}...")

    os.makedirs(local_path, exist_ok=True)

    # A dataset already present on disk wins over any Hub download.
    local_source = find_local_dataset_source(dataset_repo)
    if local_source:
        print(f"Found local dataset at {local_source}")
        copy_dataset_files(local_source, local_path)
        return

    repo_id = normalize_repo_id(dataset_repo)

    try:
        # Preferred: load as a structured dataset with image/text columns.
        dataset = load_dataset(repo_id, split="train")

        for idx, sample in enumerate(dataset):
            stem = os.path.join(local_path, f"image_{idx:06d}")

            if "image" in sample:
                picture = sample["image"]
                # JPEG cannot carry alpha: flatten RGBA onto white first.
                if picture.mode == 'RGBA':
                    flattened = Image.new('RGB', picture.size, (255, 255, 255))
                    flattened.paste(picture, mask=picture.split()[-1])
                    picture = flattened
                elif picture.mode not in ['RGB', 'L']:
                    picture = picture.convert('RGB')
                picture.save(stem + ".jpg", 'JPEG')

            if "text" in sample:
                with open(stem + ".txt", "w", encoding="utf-8") as caption_file:
                    caption_file.write(sample["text"])

        print(f"Downloaded {len(dataset)} items to {local_path}")

    except Exception as e:
        print(f"Failed to load as structured dataset: {e}")
        print("Attempting to download raw files...")

        # Fallback: pull the raw repo files and let the shared copier sort them out.
        temp_repo_path = snapshot_download(repo_id=repo_id, repo_type="dataset")

        print(f"Downloaded repo to: {temp_repo_path}")
        print(f"Contents: {os.listdir(temp_repo_path)}")

        copy_dataset_files(temp_repo_path, local_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
297
 
298
  def create_config(dataset_path: str, output_path: str):
299
  """Create training configuration"""
 
793
  reject(new Error(`Process error: ${err.message}`));
794
  });
795
  });
796
+ }