Anthony Liang committed on
Commit
28efb30
·
1 Parent(s): fad52c2

added the functions for now

Browse files
Files changed (1) hide show
  1. trace_inference.py +113 -10
trace_inference.py CHANGED
@@ -10,6 +10,10 @@ import logging
10
  import os
11
  import tempfile
12
  from typing import List, Optional, Tuple
 
 
 
 
13
 
14
  logger = logging.getLogger(__name__)
15
 
@@ -254,6 +258,112 @@ def build_franka_prompt(task: str) -> str:
254
  f'The task is "{task}". Can you predict the trace of the end effector?'
255
  )
256
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
 
258
  def run_inference_qwenvl(
259
  image_path: str,
@@ -275,16 +385,6 @@ def run_inference_qwenvl(
275
  (output_dict, prediction_text, overlay_path, trace_points_text)
276
  output_dict has format: {"id", "image", "conversations": [human_msg, gpt_msg]}
277
  """
278
- try:
279
- from qwenvl.data.data_processor import preprocess_qwen_visual
280
- except ImportError as e:
281
- return (
282
- {},
283
- "",
284
- None,
285
- f"qwenvl package not found: {e}. Install qwen-vl-finetune or add to PYTHONPATH.",
286
- )
287
-
288
  success, msg = load_model(model_id)
289
  if not success:
290
  return {}, msg, None, ""
@@ -315,6 +415,9 @@ def run_inference_qwenvl(
315
  [inference_sample], processor, add_gen_prompt=True
316
  )
317
 
 
 
 
318
  input_ids = processed_data["input_ids"].to(model.device)
319
  pixel_values = (
320
  processed_data["pixel_values"].to(model.device)
 
10
  import os
11
  import tempfile
12
  from typing import List, Optional, Tuple
13
+ import re
14
+ from pathlib import Path
15
+ import torch
16
+ from typing import Dict, Any
17
 
18
  logger = logging.getLogger(__name__)
19
 
 
258
  f'The task is "{task}". Can you predict the trace of the end effector?'
259
  )
260
 
261
+ def _make_abs_paths(base: Path, files: str) -> str:
262
+ return f"{(base / files).resolve()}"
263
+
264
def _build_messages(item: Dict[str, Any], base_path: Path) -> List[Dict[str, Any]]:
    """Turn one conversation record into chat-template messages.

    User-turn text is split on ``<image>`` / ``<video>`` placeholders; each
    placeholder consumes the next media file (made absolute against
    *base_path*) in order. Assistant turns are passed through as plain text.

    Raises:
        ValueError: if placeholders outnumber the provided media, or if any
            media file is left unconsumed.
    """

    def _as_list(value):
        # Missing / falsy -> no media; a bare string -> a single-item list.
        value = value or []
        return [value] if isinstance(value, str) else value

    image_pool = [
        {"type": "image", "image": _make_abs_paths(base_path, name)}
        for name in _as_list(item.get("image"))
    ]
    video_pool = [
        {"type": "video", "video": _make_abs_paths(base_path, name)}
        for name in _as_list(item.get("video"))
    ]

    messages: List[Dict[str, Any]] = []
    next_image = 0
    next_video = 0

    for turn in item["conversations"]:
        text: str = turn["value"]

        if turn["from"] != "human":
            # Assistant turns carry a single text segment, unmodified.
            messages.append(
                {"role": "assistant", "content": [{"type": "text", "text": text}]}
            )
            continue

        content: List[Dict[str, Any]] = []
        # Capturing group keeps the placeholder tokens in the split output.
        for piece in re.split(r"(<image>|<video>)", text):
            if piece == "<image>":
                if next_image >= len(image_pool):
                    raise ValueError(
                        "Number of <image> placeholders exceeds the number of provided images"
                    )
                content.append(image_pool[next_image])
                next_image += 1
            elif piece == "<video>":
                if next_video >= len(video_pool):
                    raise ValueError(
                        "Number of <video> placeholders exceeds the number of provided videos"
                    )
                content.append(video_pool[next_video])
                next_video += 1
            else:
                stripped = piece.strip()
                # Whitespace-only fragments between placeholders are dropped.
                if stripped:
                    content.append({"type": "text", "text": stripped})

        messages.append({"role": "user", "content": content})

    unused_images = len(image_pool) - next_image
    if unused_images:
        raise ValueError(
            f"{unused_images} image(s) remain unused (not consumed by placeholders)"
        )
    unused_videos = len(video_pool) - next_video
    if unused_videos:
        raise ValueError(
            f"{unused_videos} video(s) remain unused (not consumed by placeholders)"
        )

    return messages
324
+
325
IGNORE_INDEX = -100  # label value masked out of the loss (standard HF/PyTorch convention)

def preprocess_qwen_visual(
    sources,
    processor,
) -> Dict:
    """Tokenize a single conversation sample and build training labels.

    Args:
        sources: one-element list holding the sample dict (expects keys like
            "conversations" and optionally "image"/"video"/"data_path" —
            see `_build_messages`).
        processor: a chat processor exposing ``apply_chat_template``
            (presumably a Qwen-VL processor — confirm against caller).

    Returns:
        The processor output dict, with "input_ids" normalized to a batched
        tensor and a "labels" tensor added in which only assistant answer
        spans keep their token ids; everything else is IGNORE_INDEX.

    Raises:
        ValueError: if ``sources`` does not contain exactly one sample.
    """
    if len(sources) != 1:
        raise ValueError(f"Expected 1 source, got {len(sources)}")

    source = sources[0]
    # Media paths inside the sample are resolved relative to "data_path".
    base_path = Path(source.get("data_path", ""))
    messages = _build_messages(source, base_path)

    full_result = processor.apply_chat_template(
        messages, tokenize=True, return_dict=True, return_tensors="pt"
    )

    input_ids = full_result["input_ids"]
    # Some processors return plain lists; normalize to a (1, seq_len) tensor.
    if isinstance(input_ids, list):
        input_ids = torch.tensor(input_ids).unsqueeze(0)

    # Start fully masked; assistant answer spans are unmasked below.
    labels = torch.full_like(input_ids, IGNORE_INDEX)

    input_ids_flat = input_ids[0].tolist()
    L = len(input_ids_flat)
    pos = 0
    while pos < L:
        # 77091: presumably the "assistant" role token id in Qwen's chat
        # template — TODO confirm against the tokenizer vocabulary.
        if input_ids_flat[pos] == 77091:
            # Skip the role token and the one after it (presumably the
            # newline following "assistant") before the answer starts.
            ans_start = pos + 2
            ans_end = ans_start
            # 151645: presumably the <|im_end|> token id — TODO confirm.
            while ans_end < L and input_ids_flat[ans_end] != 151645:
                ans_end += 1
            if ans_end < L:
                # Unmask the answer plus the end token and the token after it;
                # tensor slicing clamps at the sequence end, so no overrun.
                labels[0, ans_start : ans_end + 2] = input_ids[
                    0, ans_start : ans_end + 2
                ]
            # NOTE(review): an answer with no 151645 before EOS stays masked.
            pos = ans_end
        pos += 1

    full_result["labels"] = labels
    full_result["input_ids"] = input_ids
    return full_result
367
 
368
  def run_inference_qwenvl(
369
  image_path: str,
 
385
  (output_dict, prediction_text, overlay_path, trace_points_text)
386
  output_dict has format: {"id", "image", "conversations": [human_msg, gpt_msg]}
387
  """
 
 
 
 
 
 
 
 
 
 
388
  success, msg = load_model(model_id)
389
  if not success:
390
  return {}, msg, None, ""
 
415
  [inference_sample], processor, add_gen_prompt=True
416
  )
417
 
418
+ print("processed_data")
419
+ print(processed_data)
420
+
421
  input_ids = processed_data["input_ids"].to(model.device)
422
  pixel_values = (
423
  processed_data["pixel_values"].to(model.device)