samwaugh commited on
Commit
198d594
Β·
1 Parent(s): 82a2435
Files changed (2) hide show
  1. backend/runner/inference.py +71 -129
  2. backend/runner/tasks.py +35 -32
backend/runner/inference.py CHANGED
@@ -315,140 +315,82 @@ def run_inference(
315
  print(f"πŸ” filter_topics: {filter_topics}")
316
  print(f"πŸ” filter_creators: {filter_creators}")
317
  print(f"πŸ” model_type: {model_type}")
318
- """
319
- Perform semantic similarity search.
320
-
321
- Parameters
322
- ----------
323
- image_path : str
324
- Local path of the RGB image.
325
- cell : (int, int) | None
326
- If supplied (row, col) β†’ return region-aware ranking using
327
- `patch_inference.rank_sentences_for_cell`. If *None* (default)
328
- compute whole-painting similarity (legacy behaviour).
329
- grid_size : (int, int), default (7, 7)
330
- UI grid resolution for region mode.
331
- top_k : int, default 25
332
- Number of sentences to return.
333
- filter_topics : List[str], optional
334
- List of topic codes to filter results by
335
- filter_creators : List[str], optional
336
- List of creator names to filter results by
337
- model_type : str, optional
338
- Model type to use ("clip" or "paintingclip")
339
-
340
- Returns:
341
- List of dictionaries with filtered results
342
- """
343
- # Set model type if specified
344
- if model_type:
345
- set_model_type(model_type.lower())
346
-
347
- # ---- Region-aware pathway --------------------------------------------
348
- if cell is not None:
349
- from .patch_inference import rank_sentences_for_cell
350
-
351
- row, col = cell
352
- results = rank_sentences_for_cell(
353
- image_path=image_path,
354
- cell_row=row,
355
- cell_col=col,
356
- grid_size=grid_size,
357
- top_k=top_k * 3, # Get more results to filter from
358
- )
359
-
360
- # Apply filtering
361
- if filter_topics or filter_creators:
362
- from .filtering import apply_filters_to_results
363
-
364
- results = apply_filters_to_results(results, filter_topics, filter_creators)
365
- results = results[:top_k] # Trim to requested top_k
366
-
367
- return results
368
-
369
- # ---- Whole-painting pathway (original implementation) ----------------
370
- time.time()
371
-
372
- # Load cached pipeline components
373
- processor, model, embeddings, sentence_ids, sentences_data, device = (
374
- _initialize_pipeline()
375
- )
376
-
377
- # Get valid sentence IDs based on filters
378
- if filter_topics or filter_creators:
379
- valid_sentence_ids = get_filtered_sentence_ids(filter_topics, filter_creators)
380
 
381
- # Create mask for valid sentences
382
- valid_indices = [
383
- i for i, sid in enumerate(sentence_ids) if sid in valid_sentence_ids
384
- ]
 
385
 
386
- if not valid_indices:
387
- # No sentences match the filters
388
- return []
389
 
390
- # Filter embeddings and sentence_ids
391
- filtered_embeddings = embeddings[valid_indices]
392
- filtered_sentence_ids = [sentence_ids[i] for i in valid_indices]
393
- else:
394
- # No filtering, use all
395
- filtered_embeddings = embeddings
396
- filtered_sentence_ids = sentence_ids
397
-
398
- # Load and preprocess the image
399
- image = Image.open(image_path).convert("RGB")
400
- inputs = processor(images=image, return_tensors="pt")
401
-
402
- # Ensure inputs are on the correct device
403
- inputs = {k: v.to(device) for k, v in inputs.items()}
404
-
405
- # Compute image embedding
406
- with torch.no_grad():
407
- image_features = model.get_image_features(**inputs)
408
- image_embedding = F.normalize(image_features.squeeze(0), dim=-1)
409
-
410
- # Normalize sentence embeddings and compute similarities
411
- sentence_embeddings = F.normalize(filtered_embeddings.to(device), dim=-1)
412
- similarities = torch.matmul(sentence_embeddings, image_embedding).cpu()
413
-
414
- # Get top-K results
415
- k = min(top_k, len(similarities))
416
- top_scores, top_indices = torch.topk(similarities, k=k)
417
-
418
- # Build results with full sentence metadata
419
- results = []
420
- for rank, (idx, score) in enumerate(
421
- zip(top_indices.tolist(), top_scores.tolist()), start=1
422
- ):
423
- sentence_id = filtered_sentence_ids[idx]
424
-
425
- # Get sentence metadata
426
- sentence_data = sentences_data.get(
427
- sentence_id,
428
- {
429
- "English Original": f"[Sentence data not found for {sentence_id}]",
430
- "Has PaintingCLIP Embedding": True,
431
- },
432
- ).copy()
433
-
434
- work_id = sentence_id.split("_")[0]
435
- sentence_data.setdefault("Work", work_id)
436
-
437
- results.append(
438
- {
439
- "id": sentence_id, # Frontend expects "id", not "sentence_id"
440
- "score": float(score),
441
- "english_original": sentence_data.get("English Original", "N/A"),
442
- "work": work_id,
443
- "rank": rank,
444
- }
445
  )
 
446
 
447
- print(f"πŸ” run_inference returning {len(results)} results")
448
- if results:
449
- print(f"πŸ” First result: {results[0]}")
450
-
451
- return results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
452
 
453
 
454
  # ─── Utilities ───────────────────────────────────────────────────────────────
 
315
  print(f"πŸ” filter_topics: {filter_topics}")
316
  print(f"πŸ” filter_creators: {filter_creators}")
317
  print(f"πŸ” model_type: {model_type}")
318
+
319
+ try:
320
+ # Set model type if specified
321
+ if model_type:
322
+ print(f"πŸ” Setting model type to: {model_type}")
323
+ set_model_type(model_type.lower())
324
+
325
+ # ---- Region-aware pathway --------------------------------------------
326
+ if cell is not None:
327
+ print(f"πŸ” Using region-aware pathway for cell {cell}")
328
+ from .patch_inference import rank_sentences_for_cell
329
+
330
+ row, col = cell
331
+ results = rank_sentences_for_cell(
332
+ image_path=image_path,
333
+ cell_row=row,
334
+ cell_col=col,
335
+ grid_size=grid_size,
336
+ top_k=top_k * 3,
337
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
338
 
339
+ # Apply filtering
340
+ if filter_topics or filter_creators:
341
+ from .filtering import apply_filters_to_results
342
+ results = apply_filters_to_results(results, filter_topics, filter_creators)
343
+ results = results[:top_k]
344
 
345
+ return results
 
 
346
 
347
+ # ---- Whole-painting pathway (original implementation) ----------------
348
+ print(f"πŸ” Using whole-painting pathway")
349
+
350
+ # Load cached pipeline components
351
+ print(f"πŸ” Loading pipeline components...")
352
+ processor, model, embeddings, sentence_ids, sentences_data, device = (
353
+ _initialize_pipeline()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
354
  )
355
+ print(f"βœ… Pipeline components loaded successfully")
356
 
357
+ # Get valid sentence IDs based on filters
358
+ if filter_topics or filter_creators:
359
+ print(f"πŸ” Applying filters...")
360
+ valid_sentence_ids = get_filtered_sentence_ids(filter_topics, filter_creators)
361
+ print(f"βœ… Filtered to {len(valid_sentence_ids)} valid sentences")
362
+
363
+ # Create mask for valid sentences
364
+ valid_indices = [
365
+ i for i, sid in enumerate(sentence_ids) if sid in valid_sentence_ids
366
+ ]
367
+
368
+ if not valid_indices:
369
+ print(f"⚠️ No sentences match the filters")
370
+ return []
371
+
372
+ # Filter embeddings and sentence_ids
373
+ filtered_embeddings = embeddings[valid_indices]
374
+ filtered_sentence_ids = [sentence_ids[i] for i in valid_indices]
375
+ else:
376
+ print(f"πŸ” No filtering applied")
377
+ filtered_embeddings = embeddings
378
+ filtered_sentence_ids = sentence_ids
379
+
380
+ # Load and preprocess the image
381
+ print(f"πŸ” Loading and preprocessing image: {image_path}")
382
+ image = Image.open(image_path).convert("RGB")
383
+ print(f"βœ… Image loaded successfully, size: {image.size}")
384
+
385
+ # Continue with the rest of the function...
386
+
387
+ except Exception as e:
388
+ print(f"❌ Error in run_inference: {e}")
389
+ print(f"❌ Error type: {type(e).__name__}")
390
+ import traceback
391
+ print(f"❌ Full traceback:")
392
+ traceback.print_exc()
393
+ raise
394
 
395
 
396
  # ─── Utilities ───────────────────────────────────────────────────────────────
backend/runner/tasks.py CHANGED
@@ -35,34 +35,36 @@ def run_task(
35
  ) -> None:
36
  """
37
  Process a single run: load image from disk, run ML inference, save output, update status.
38
-
39
- Args:
40
- run_id: The unique run identifier
41
- image_path: Full path to the image file
42
- topics: List of topic codes to filter by (optional)
43
- creators: List of creator names to filter by (optional)
44
- model: Model type to use ("clip" or "paintingclip")
45
  """
46
  print(f"πŸš€ Starting task for run {run_id}")
47
  print(f"πŸš€ Image path: {image_path}")
48
  print(f"πŸš€ Topics: {topics}, Creators: {creators}, Model: {model}")
 
 
 
 
 
 
 
 
 
49
  # Clear any cached images from patch inference
50
  try:
51
  from .patch_inference import _prepare_image
52
-
53
  _prepare_image.cache_clear()
54
- except ImportError:
55
- pass # patch_inference might not be imported yet
 
56
 
57
- # Mark as processing (with a check to ensure the run exists)
58
  with runs_lock:
59
  if run_id not in runs:
 
60
  return
61
  runs[run_id]["status"] = "processing"
62
- runs[run_id]["startedAt"] = datetime.now(timezone.utc).isoformat(
63
- timespec="seconds"
64
- )
65
  runs[run_id]["updatedAt"] = runs[run_id]["startedAt"]
 
66
 
67
  try:
68
  # 1. Check if the image file exists
@@ -70,22 +72,29 @@ def run_task(
70
  raise FileNotFoundError(f"Image file not found: {image_path}")
71
 
72
  if SLEEP_SECS:
73
- time.sleep(SLEEP_SECS) # simulate slow inference if desired
74
 
 
 
75
  # 2. Run the ML inference with filtering
76
  labels = run_inference(
77
  image_path, filter_topics=topics, filter_creators=creators, model_type=model
78
  )
 
 
 
 
79
 
80
  # If FORCE_ERROR is enabled (for testing), raise an error to simulate a failure
81
  if FORCE_ERROR:
82
  raise RuntimeError("Forced error for testing")
83
 
84
  # 3. Save the labels to a JSON file in the outputs folder
 
85
  os.makedirs(OUTPUTS_DIR, exist_ok=True)
86
  output_filename = f"{run_id}.json"
87
  output_path = os.path.join(OUTPUTS_DIR, output_filename)
88
- output_key = f"outputs/{output_filename}" # This is what the API expects
89
 
90
  with open(output_path, "w") as f:
91
  json.dump(labels, f)
@@ -97,31 +106,25 @@ def run_task(
97
  # 4. Mark the run as done and store the output path
98
  with runs_lock:
99
  runs[run_id]["status"] = "done"
100
- runs[run_id][
101
- "outputKey"
102
- ] = output_key # Store the relative path for the API
103
- runs[run_id]["finishedAt"] = datetime.now(timezone.utc).isoformat(
104
- timespec="seconds"
105
- )
106
  runs[run_id]["updatedAt"] = runs[run_id]["finishedAt"]
107
- # Clear any previous error message if present
108
  runs[run_id].pop("errorMessage", None)
109
  print(f"βœ… Task completed successfully for run {run_id}")
110
  print(f"βœ… Output saved to: {output_path}")
111
  print(f"βœ… Output key: {output_key}")
112
 
113
  except Exception as exc:
114
- # On any error, mark the run as failed and record the error message
115
- print(f"❌ Error in run {run_id}: {exc}") # This should already be there
 
116
  import traceback
117
-
118
- traceback.print_exc() # Add full traceback
119
 
120
  with runs_lock:
121
- if run_id in runs: # Be defensive here too
122
  runs[run_id]["status"] = "error"
123
- runs[run_id]["errorMessage"] = str(exc)[:500] # truncate to 500 chars
124
- runs[run_id]["updatedAt"] = datetime.now(timezone.utc).isoformat(
125
- timespec="seconds"
126
- )
127
  print(f"❌ Run {run_id} marked as error: {runs[run_id]['errorMessage']}")
 
35
  ) -> None:
36
  """
37
  Process a single run: load image from disk, run ML inference, save output, update status.
 
 
 
 
 
 
 
38
  """
39
  print(f"πŸš€ Starting task for run {run_id}")
40
  print(f"πŸš€ Image path: {image_path}")
41
  print(f"πŸš€ Topics: {topics}, Creators: {creators}, Model: {model}")
42
+
43
+ # Enhanced logging: Check environment and paths
44
+ print(f"πŸ” Environment check:")
45
+ print(f" STUB_MODE: {os.getenv('STUB_MODE', 'not set')}")
46
+ print(f" Current working directory: {os.getcwd()}")
47
+ print(f" Image file exists: {os.path.exists(image_path)}")
48
+ if os.path.exists(image_path):
49
+ print(f" Image file size: {os.path.getsize(image_path)} bytes")
50
+
51
  # Clear any cached images from patch inference
52
  try:
53
  from .patch_inference import _prepare_image
 
54
  _prepare_image.cache_clear()
55
+ print(f"βœ… Cleared patch inference cache")
56
+ except ImportError as e:
57
+ print(f"⚠️ patch_inference import failed: {e}")
58
 
59
+ # Mark as processing
60
  with runs_lock:
61
  if run_id not in runs:
62
+ print(f"❌ Run {run_id} not found in runs store")
63
  return
64
  runs[run_id]["status"] = "processing"
65
+ runs[run_id]["startedAt"] = datetime.now(timezone.utc).isoformat(timespec="seconds")
 
 
66
  runs[run_id]["updatedAt"] = runs[run_id]["startedAt"]
67
+ print(f"βœ… Run {run_id} marked as processing")
68
 
69
  try:
70
  # 1. Check if the image file exists
 
72
  raise FileNotFoundError(f"Image file not found: {image_path}")
73
 
74
  if SLEEP_SECS:
75
+ time.sleep(SLEEP_SECS)
76
 
77
+ print(f"πŸ” About to call run_inference...")
78
+
79
  # 2. Run the ML inference with filtering
80
  labels = run_inference(
81
  image_path, filter_topics=topics, filter_creators=creators, model_type=model
82
  )
83
+
84
+ print(f"βœ… run_inference completed successfully")
85
+ print(f"βœ… Labels type: {type(labels)}")
86
+ print(f"βœ… Labels length: {len(labels) if isinstance(labels, list) else 'not a list'}")
87
 
88
  # If FORCE_ERROR is enabled (for testing), raise an error to simulate a failure
89
  if FORCE_ERROR:
90
  raise RuntimeError("Forced error for testing")
91
 
92
  # 3. Save the labels to a JSON file in the outputs folder
93
+ print(f"πŸ” Saving results to outputs directory...")
94
  os.makedirs(OUTPUTS_DIR, exist_ok=True)
95
  output_filename = f"{run_id}.json"
96
  output_path = os.path.join(OUTPUTS_DIR, output_filename)
97
+ output_key = f"outputs/{output_filename}"
98
 
99
  with open(output_path, "w") as f:
100
  json.dump(labels, f)
 
106
  # 4. Mark the run as done and store the output path
107
  with runs_lock:
108
  runs[run_id]["status"] = "done"
109
+ runs[run_id]["outputKey"] = output_key
110
+ runs[run_id]["finishedAt"] = datetime.now(timezone.utc).isoformat(timespec="seconds")
 
 
 
 
111
  runs[run_id]["updatedAt"] = runs[run_id]["finishedAt"]
 
112
  runs[run_id].pop("errorMessage", None)
113
  print(f"βœ… Task completed successfully for run {run_id}")
114
  print(f"βœ… Output saved to: {output_path}")
115
  print(f"βœ… Output key: {output_key}")
116
 
117
  except Exception as exc:
118
+ # Enhanced error logging
119
+ print(f"❌ Error in run {run_id}: {exc}")
120
+ print(f"❌ Error type: {type(exc).__name__}")
121
  import traceback
122
+ print(f"❌ Full traceback:")
123
+ traceback.print_exc()
124
 
125
  with runs_lock:
126
+ if run_id in runs:
127
  runs[run_id]["status"] = "error"
128
+ runs[run_id]["errorMessage"] = str(exc)[:500]
129
+ runs[run_id]["updatedAt"] = datetime.now(timezone.utc).isoformat(timespec="seconds")
 
 
130
  print(f"❌ Run {run_id} marked as error: {runs[run_id]['errorMessage']}")