aractingi committed on
Commit
a17ffcc
·
1 Parent(s): de93564

Feat/visualize language instruction (#7)

Browse files

* add general support to print `task` as language instruction

* add support to print `language instructions` instead of `task`, even if there are multiple entries, as in DROID

src/app/[org]/[dataset]/[episode]/episode-viewer.tsx CHANGED
@@ -46,6 +46,7 @@ function EpisodeViewerInner({ data, org, dataset }: { data: any; org?: string; d
46
  videosInfo,
47
  chartDataGroups,
48
  episodes,
 
49
  } = data;
50
 
51
  const [videosReady, setVideosReady] = useState(!videosInfo.length);
@@ -227,6 +228,22 @@ function EpisodeViewerInner({ data, org, dataset }: { data: any; org?: string; d
227
  />
228
  )}
229
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
  {/* Graph */}
231
  <div className="mb-4">
232
  <DataRecharts
 
46
  videosInfo,
47
  chartDataGroups,
48
  episodes,
49
+ task,
50
  } = data;
51
 
52
  const [videosReady, setVideosReady] = useState(!videosInfo.length);
 
228
  />
229
  )}
230
 
231
+ {/* Language Instruction */}
232
+ {task && (
233
+ <div className="mb-6 p-4 bg-slate-800 rounded-lg border border-slate-600">
234
+ <p className="text-slate-300">
235
+ <span className="font-semibold text-slate-100">Language Instruction:</span>
236
+ </p>
237
+ <div className="mt-2 text-slate-300">
238
+ {task.split('\n').map((instruction, index) => (
239
+ <p key={index} className="mb-1">
240
+ {instruction}
241
+ </p>
242
+ ))}
243
+ </div>
244
+ </div>
245
+ )}
246
+
247
  {/* Graph */}
248
  <div className="mb-4">
249
  <DataRecharts
src/app/[org]/[dataset]/[episode]/fetch-data.ts CHANGED
@@ -192,6 +192,78 @@ async function getEpisodeDataV2(
192
  );
193
 
194
  const arrayBuffer = await fetchParquetFile(parquetUrl);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  const data = await readParquetColumn(arrayBuffer, filteredColumnNames);
196
  // Flatten and map to array of objects for chartData
197
  const seriesNames = [
@@ -283,7 +355,7 @@ async function getEpisodeDataV2(
283
  // suffixGroupArr is array of suffix groups (each is array of keys)
284
  const merged = suffixGroupArr.flat();
285
  if (merged.length > 6) {
286
- const subgroups = [];
287
  for (let i = 0; i < merged.length; i += 6) {
288
  subgroups.push(merged.slice(i, i + 6));
289
  }
@@ -337,6 +409,7 @@ async function getEpisodeDataV2(
337
  episodes,
338
  ignoredColumns,
339
  duration,
 
340
  };
341
  }
342
 
@@ -365,7 +438,7 @@ async function getEpisodeDataV3(
365
  const videosInfo = extractVideoInfoV3WithSegmentation(repoId, version, info, episodeMetadata);
366
 
367
  // Load episode data for charts
368
- const { chartDataGroups, ignoredColumns } = await loadEpisodeDataV3(repoId, version, info, episodeMetadata);
369
 
370
  // Calculate duration from episode length and FPS if available
371
  const duration = episodeMetadata.length ? episodeMetadata.length / info.fps :
@@ -379,6 +452,7 @@ async function getEpisodeDataV3(
379
  episodes,
380
  ignoredColumns,
381
  duration,
 
382
  };
383
  }
384
 
@@ -388,7 +462,7 @@ async function loadEpisodeDataV3(
388
  version: string,
389
  info: DatasetMetadata,
390
  episodeMetadata: any,
391
- ): Promise<{ chartDataGroups: any[]; ignoredColumns: string[] }> {
392
  // Build data file path using chunk and file indices
393
  const dataChunkIndex = episodeMetadata.data_chunk_index || 0;
394
  const dataFileIndex = episodeMetadata.data_file_index || 0;
@@ -397,7 +471,7 @@ async function loadEpisodeDataV3(
397
  try {
398
  const dataUrl = buildVersionedUrl(repoId, version, dataPath);
399
  const arrayBuffer = await fetchParquetFile(dataUrl);
400
- const fullData = await readParquetColumn(arrayBuffer, []);
401
 
402
  // Extract the episode-specific data slice
403
  // Convert BigInt to number if needed
@@ -406,15 +480,87 @@ async function loadEpisodeDataV3(
406
  const episodeData = fullData.slice(fromIndex, toIndex);
407
 
408
  if (episodeData.length === 0) {
409
- return { chartDataGroups: [], ignoredColumns: [] };
410
  }
411
 
412
  // Convert to the same format as v2.x for compatibility with existing chart code
413
  const { chartDataGroups, ignoredColumns } = processEpisodeDataForCharts(episodeData, info, episodeMetadata);
414
 
415
- return { chartDataGroups, ignoredColumns };
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
416
  } catch {
417
- return { chartDataGroups: [], ignoredColumns: [] };
418
  }
419
  }
420
 
@@ -465,7 +611,7 @@ function processEpisodeDataForCharts(
465
  }
466
  });
467
 
468
- // Columns to exclude from charts
469
  const excludedColumns = ['index', 'task_index', 'episode_index', 'frame_index', 'next.done'];
470
 
471
  // Create columns structure similar to V2.1 for proper hierarchical naming
@@ -831,7 +977,7 @@ function parseEpisodeRowSimple(row: any): any {
831
  return parseInt(value) || 0;
832
  };
833
 
834
- const episodeData = {
835
  episode_index: toBigIntSafe(row['episode_index']),
836
  data_chunk_index: toBigIntSafe(row['data/chunk_index']),
837
  data_file_index: toBigIntSafe(row['data/file_index']),
 
192
  );
193
 
194
  const arrayBuffer = await fetchParquetFile(parquetUrl);
195
+
196
+ // Extract task - first check for language instructions (preferred), then fallback to task field or tasks.jsonl
197
+ let task: string | undefined;
198
+ let allData: any[] = [];
199
+
200
+ // Load data first
201
+ try {
202
+ allData = await readParquetAsObjects(arrayBuffer, []);
203
+ } catch (error) {
204
+ // Could not read parquet data
205
+ }
206
+
207
+ // First check for language_instruction fields in the data (preferred)
208
+ if (allData.length > 0) {
209
+ const firstRow = allData[0];
210
+ const languageInstructions: string[] = [];
211
+
212
+ // Check for language_instruction field
213
+ if (firstRow.language_instruction) {
214
+ languageInstructions.push(firstRow.language_instruction);
215
+ }
216
+
217
+ // Check for numbered language_instruction fields
218
+ let instructionNum = 2;
219
+ while (firstRow[`language_instruction_${instructionNum}`]) {
220
+ languageInstructions.push(firstRow[`language_instruction_${instructionNum}`]);
221
+ instructionNum++;
222
+ }
223
+
224
+ // Join all instructions with line breaks
225
+ if (languageInstructions.length > 0) {
226
+ task = languageInstructions.join('\n');
227
+ }
228
+ }
229
+
230
+ // If no language instructions found, try direct task field
231
+ if (!task && allData.length > 0 && allData[0].task) {
232
+ task = allData[0].task;
233
+ }
234
+
235
+ // If still no task found, try loading from tasks.jsonl metadata file (v2.x format)
236
+ if (!task && allData.length > 0) {
237
+ try {
238
+ const tasksUrl = buildVersionedUrl(repoId, version, "meta/tasks.jsonl");
239
+ const tasksResponse = await fetch(tasksUrl);
240
+
241
+ if (tasksResponse.ok) {
242
+ const tasksText = await tasksResponse.text();
243
+ // Parse JSONL format (one JSON object per line)
244
+ const tasksData = tasksText
245
+ .split('\n')
246
+ .filter(line => line.trim())
247
+ .map(line => JSON.parse(line));
248
+
249
+ if (tasksData && tasksData.length > 0) {
250
+ const taskIndex = allData[0].task_index;
251
+
252
+ // Convert BigInt to number for comparison
253
+ const taskIndexNum = typeof taskIndex === 'bigint' ? Number(taskIndex) : taskIndex;
254
+
255
+ // Find task by task_index
256
+ const taskData = tasksData.find(t => t.task_index === taskIndexNum);
257
+ if (taskData) {
258
+ task = taskData.task;
259
+ }
260
+ }
261
+ }
262
+ } catch (error) {
263
+ // No tasks metadata file for this v2.x dataset
264
+ }
265
+ }
266
+
267
  const data = await readParquetColumn(arrayBuffer, filteredColumnNames);
268
  // Flatten and map to array of objects for chartData
269
  const seriesNames = [
 
355
  // suffixGroupArr is array of suffix groups (each is array of keys)
356
  const merged = suffixGroupArr.flat();
357
  if (merged.length > 6) {
358
+ const subgroups: string[][] = [];
359
  for (let i = 0; i < merged.length; i += 6) {
360
  subgroups.push(merged.slice(i, i + 6));
361
  }
 
409
  episodes,
410
  ignoredColumns,
411
  duration,
412
+ task,
413
  };
414
  }
415
 
 
438
  const videosInfo = extractVideoInfoV3WithSegmentation(repoId, version, info, episodeMetadata);
439
 
440
  // Load episode data for charts
441
+ const { chartDataGroups, ignoredColumns, task } = await loadEpisodeDataV3(repoId, version, info, episodeMetadata);
442
 
443
  // Calculate duration from episode length and FPS if available
444
  const duration = episodeMetadata.length ? episodeMetadata.length / info.fps :
 
452
  episodes,
453
  ignoredColumns,
454
  duration,
455
+ task,
456
  };
457
  }
458
 
 
462
  version: string,
463
  info: DatasetMetadata,
464
  episodeMetadata: any,
465
+ ): Promise<{ chartDataGroups: any[]; ignoredColumns: string[]; task?: string }> {
466
  // Build data file path using chunk and file indices
467
  const dataChunkIndex = episodeMetadata.data_chunk_index || 0;
468
  const dataFileIndex = episodeMetadata.data_file_index || 0;
 
471
  try {
472
  const dataUrl = buildVersionedUrl(repoId, version, dataPath);
473
  const arrayBuffer = await fetchParquetFile(dataUrl);
474
+ const fullData = await readParquetAsObjects(arrayBuffer, []);
475
 
476
  // Extract the episode-specific data slice
477
  // Convert BigInt to number if needed
 
480
  const episodeData = fullData.slice(fromIndex, toIndex);
481
 
482
  if (episodeData.length === 0) {
483
+ return { chartDataGroups: [], ignoredColumns: [], task: undefined };
484
  }
485
 
486
  // Convert to the same format as v2.x for compatibility with existing chart code
487
  const { chartDataGroups, ignoredColumns } = processEpisodeDataForCharts(episodeData, info, episodeMetadata);
488
 
489
+ // First check for language_instruction fields in the data (preferred)
490
+ let task: string | undefined;
491
+ if (episodeData.length > 0) {
492
+ const firstRow = episodeData[0];
493
+ const languageInstructions: string[] = [];
494
+
495
+ // Check for language_instruction field
496
+ if (firstRow.language_instruction) {
497
+ languageInstructions.push(firstRow.language_instruction);
498
+ }
499
+
500
+ // Check for numbered language_instruction fields
501
+ let instructionNum = 2;
502
+ while (firstRow[`language_instruction_${instructionNum}`]) {
503
+ languageInstructions.push(firstRow[`language_instruction_${instructionNum}`]);
504
+ instructionNum++;
505
+ }
506
+
507
+ // If no instructions found in first row, check a few more rows
508
+ if (languageInstructions.length === 0 && episodeData.length > 1) {
509
+ const middleIndex = Math.floor(episodeData.length / 2);
510
+ const lastIndex = episodeData.length - 1;
511
+
512
+ [middleIndex, lastIndex].forEach((idx) => {
513
+ const row = episodeData[idx];
514
+
515
+ if (row.language_instruction && languageInstructions.length === 0) {
516
+ // Use this row's instructions
517
+ if (row.language_instruction) {
518
+ languageInstructions.push(row.language_instruction);
519
+ }
520
+ let num = 2;
521
+ while (row[`language_instruction_${num}`]) {
522
+ languageInstructions.push(row[`language_instruction_${num}`]);
523
+ num++;
524
+ }
525
+ }
526
+ });
527
+ }
528
+
529
+ // Join all instructions with line breaks
530
+ if (languageInstructions.length > 0) {
531
+ task = languageInstructions.join('\n');
532
+ }
533
+ }
534
+
535
+ // If no language instructions found, fall back to tasks metadata
536
+ if (!task) {
537
+ try {
538
+ // Load tasks metadata
539
+ const tasksUrl = buildVersionedUrl(repoId, version, "meta/tasks.parquet");
540
+ const tasksArrayBuffer = await fetchParquetFile(tasksUrl);
541
+ const tasksData = await readParquetAsObjects(tasksArrayBuffer, []);
542
+
543
+ if (episodeData.length > 0 && tasksData && tasksData.length > 0) {
544
+ const taskIndex = episodeData[0].task_index;
545
+
546
+ // Convert BigInt to number for comparison
547
+ const taskIndexNum = typeof taskIndex === 'bigint' ? Number(taskIndex) : taskIndex;
548
+
549
+ // Look up task by index
550
+ if (taskIndexNum !== undefined && taskIndexNum < tasksData.length) {
551
+ const taskData = tasksData[taskIndexNum];
552
+ // Extract task from __index_level_0__ field
553
+ task = taskData.__index_level_0__ || taskData.task || taskData['task'] || taskData[0];
554
+ }
555
+ }
556
+ } catch (error) {
557
+ // Could not load tasks metadata - dataset might not have language tasks
558
+ }
559
+ }
560
+
561
+ return { chartDataGroups, ignoredColumns, task };
562
  } catch {
563
+ return { chartDataGroups: [], ignoredColumns: [], task: undefined };
564
  }
565
  }
566
 
 
611
  }
612
  });
613
 
614
+ // Columns to exclude from charts (note: 'task' is intentionally not excluded as we want to access it)
615
  const excludedColumns = ['index', 'task_index', 'episode_index', 'frame_index', 'next.done'];
616
 
617
  // Create columns structure similar to V2.1 for proper hierarchical naming
 
977
  return parseInt(value) || 0;
978
  };
979
 
980
+ const episodeData: any = {
981
  episode_index: toBigIntSafe(row['episode_index']),
982
  data_chunk_index: toBigIntSafe(row['data/chunk_index']),
983
  data_file_index: toBigIntSafe(row['data/file_index']),