Anthony Liang committed on
Commit
130aa46
·
1 Parent(s): 28efb30

pushing some code

Browse files
Files changed (1) hide show
  1. trace_inference.py +53 -21
trace_inference.py CHANGED
@@ -327,7 +327,18 @@ IGNORE_INDEX = -100
327
  def preprocess_qwen_visual(
328
  sources,
329
  processor,
 
330
  ) -> Dict:
 
 
 
 
 
 
 
 
 
 
331
  if len(sources) != 1:
332
  raise ValueError(f"Expected 1 source, got {len(sources)}")
333
 
@@ -336,33 +347,39 @@ def preprocess_qwen_visual(
336
  messages = _build_messages(source, base_path)
337
 
338
  full_result = processor.apply_chat_template(
339
- messages, tokenize=True, return_dict=True, return_tensors="pt"
 
 
 
 
340
  )
341
 
342
  input_ids = full_result["input_ids"]
343
  if isinstance(input_ids, list):
344
  input_ids = torch.tensor(input_ids).unsqueeze(0)
345
 
346
- labels = torch.full_like(input_ids, IGNORE_INDEX)
347
-
348
- input_ids_flat = input_ids[0].tolist()
349
- L = len(input_ids_flat)
350
- pos = 0
351
- while pos < L:
352
- if input_ids_flat[pos] == 77091:
353
- ans_start = pos + 2
354
- ans_end = ans_start
355
- while ans_end < L and input_ids_flat[ans_end] != 151645:
356
- ans_end += 1
357
- if ans_end < L:
358
- labels[0, ans_start : ans_end + 2] = input_ids[
359
- 0, ans_start : ans_end + 2
360
- ]
361
- pos = ans_end
362
- pos += 1
363
-
364
- full_result["labels"] = labels
365
  full_result["input_ids"] = input_ids
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
366
  return full_result
367
 
368
  def run_inference_qwenvl(
@@ -407,6 +424,13 @@ def run_inference_qwenvl(
407
  "data_path": data_path,
408
  }
409
 
 
 
 
 
 
 
 
410
  try:
411
  import torch
412
  from trajectory_viz import extract_trajectory_from_text, visualize_trajectory_on_image
@@ -415,6 +439,8 @@ def run_inference_qwenvl(
415
  [inference_sample], processor, add_gen_prompt=True
416
  )
417
 
 
 
418
  print("processed_data")
419
  print(processed_data)
420
 
@@ -437,7 +463,7 @@ def run_inference_qwenvl(
437
  inputs["image_grid_thw"] = image_grid_thw
438
 
439
  with torch.no_grad():
440
- generated_ids = model.generate(**inputs, max_new_tokens=1024)
441
  generated_ids_trimmed = [
442
  out_ids[len(in_ids) :] for in_ids, out_ids in zip(input_ids, generated_ids)
443
  ]
@@ -445,6 +471,9 @@ def run_inference_qwenvl(
445
  generated_ids_trimmed[0], skip_special_tokens=True
446
  )
447
 
 
 
 
448
  # Format output like the example: "Trace: [[x,y], [x,y], ...]"
449
  trajectories = extract_trajectory_from_text(prediction)
450
  trace_value = f"Trace: {trajectories}" if trajectories else f"Trace: {prediction}"
@@ -459,6 +488,9 @@ def run_inference_qwenvl(
459
 
460
  trace_points_text = format_trace_points(trajectories)
461
 
 
 
 
462
  overlay_path = None
463
  if trajectories and len(trajectories) >= 2:
464
  _, preprocessed_path = preprocess_image_for_trace(image_path)
 
327
  def preprocess_qwen_visual(
328
  sources,
329
  processor,
330
+ add_gen_prompt: bool = False,
331
  ) -> Dict:
332
+ """
333
+ Preprocess one sample for Qwen-VL.
334
+
335
+ Args:
336
+ sources: List of one dict with keys: image, conversations, data_path.
337
+ processor: Qwen-VL processor.
338
+ add_gen_prompt: If True, add generation prompt so the model generates the
339
+ assistant reply (use for inference). If False, full conversation is
340
+ tokenized and labels are built for training.
341
+ """
342
  if len(sources) != 1:
343
  raise ValueError(f"Expected 1 source, got {len(sources)}")
344
 
 
347
  messages = _build_messages(source, base_path)
348
 
349
  full_result = processor.apply_chat_template(
350
+ messages,
351
+ tokenize=True,
352
+ return_dict=True,
353
+ return_tensors="pt",
354
+ add_generation_prompt=add_gen_prompt,
355
  )
356
 
357
  input_ids = full_result["input_ids"]
358
  if isinstance(input_ids, list):
359
  input_ids = torch.tensor(input_ids).unsqueeze(0)
360
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
361
  full_result["input_ids"] = input_ids
362
+
363
+ # Labels are only needed for training; skip for generation
364
+ if not add_gen_prompt:
365
+ labels = torch.full_like(input_ids, IGNORE_INDEX)
366
+ input_ids_flat = input_ids[0].tolist()
367
+ L = len(input_ids_flat)
368
+ pos = 0
369
+ while pos < L:
370
+ if input_ids_flat[pos] == 77091:
371
+ ans_start = pos + 2
372
+ ans_end = ans_start
373
+ while ans_end < L and input_ids_flat[ans_end] != 151645:
374
+ ans_end += 1
375
+ if ans_end < L:
376
+ labels[0, ans_start : ans_end + 2] = input_ids[
377
+ 0, ans_start : ans_end + 2
378
+ ]
379
+ pos = ans_end
380
+ pos += 1
381
+ full_result["labels"] = labels
382
+
383
  return full_result
384
 
385
  def run_inference_qwenvl(
 
424
  "data_path": data_path,
425
  }
426
 
427
+ print("prompt")
428
+ print(prompt)
429
+ print("image_path")
430
+ print(image_rel)
431
+ print("data_path")
432
+ print(data_path)
433
+
434
  try:
435
  import torch
436
  from trajectory_viz import extract_trajectory_from_text, visualize_trajectory_on_image
 
439
  [inference_sample], processor, add_gen_prompt=True
440
  )
441
 
442
+ print("inference_sample")
443
+ print(inference_sample)
444
  print("processed_data")
445
  print(processed_data)
446
 
 
463
  inputs["image_grid_thw"] = image_grid_thw
464
 
465
  with torch.no_grad():
466
+ generated_ids = model.generate(**inputs, max_new_tokens=512)
467
  generated_ids_trimmed = [
468
  out_ids[len(in_ids) :] for in_ids, out_ids in zip(input_ids, generated_ids)
469
  ]
 
471
  generated_ids_trimmed[0], skip_special_tokens=True
472
  )
473
 
474
+ print("prediction")
475
+ print(prediction)
476
+
477
  # Format output like the example: "Trace: [[x,y], [x,y], ...]"
478
  trajectories = extract_trajectory_from_text(prediction)
479
  trace_value = f"Trace: {trajectories}" if trajectories else f"Trace: {prediction}"
 
488
 
489
  trace_points_text = format_trace_points(trajectories)
490
 
491
+ print("trace_points_text")
492
+ print(trace_points_text)
493
+
494
  overlay_path = None
495
  if trajectories and len(trajectories) >= 2:
496
  _, preprocessed_path = preprocess_image_for_trace(image_path)