cmpatino HF Staff commited on
Commit
40112a6
·
1 Parent(s): 35f1842

Add more efficient processing

Browse files
Files changed (1) hide show
  1. app.py +14 -10
app.py CHANGED
@@ -120,9 +120,7 @@ def process_texts(student_model_id, teacher_model_id, dataset_id, progress=gr.Pr
120
  ds = load_dataset(dataset_id, split="train")
121
  rows = ds.select(range(min(10, len(ds))))
122
 
123
- html_blocks = []
124
- for row_idx, row in enumerate(rows):
125
- progress((row_idx + 1) / 12, desc=f"Processing text {row_idx + 1}/10...")
126
  text = "".join(msg["content"] for msg in row["messages"])
127
 
128
  s_ids = student_tokenizer.encode(text, add_special_tokens=False)
@@ -133,13 +131,19 @@ def process_texts(student_model_id, teacher_model_id, dataset_id, progress=gr.Pr
133
  )
134
 
135
  highlighted = highlight_groups(student_tokenizer, s_ids, s_groups, t_groups)
136
- html_blocks.append(
137
- f'<div style="border:1px solid #ccc; padding:10px; margin:10px 0; '
138
- f'border-radius:5px; white-space:pre-wrap; font-family:monospace; font-size:13px;">'
139
- f"<strong>Text {row_idx + 1}</strong> "
140
- f"(student tokens: {len(s_ids)}, teacher tokens: {len(t_ids)})<br><br>"
141
- f"{highlighted}</div>"
142
- )
 
 
 
 
 
 
143
 
144
  progress(1, desc="Done!")
145
 
 
120
  ds = load_dataset(dataset_id, split="train")
121
  rows = ds.select(range(min(10, len(ds))))
122
 
123
+ def make_html_block(row, idx):
 
 
124
  text = "".join(msg["content"] for msg in row["messages"])
125
 
126
  s_ids = student_tokenizer.encode(text, add_special_tokens=False)
 
131
  )
132
 
133
  highlighted = highlight_groups(student_tokenizer, s_ids, s_groups, t_groups)
134
+ return {
135
+ "html_block": (
136
+ f'<div style="border:1px solid #ccc; padding:10px; margin:10px 0; '
137
+ f'border-radius:5px; white-space:pre-wrap; font-family:monospace; font-size:13px;">'
138
+ f"<strong>Text {idx + 1}</strong> "
139
+ f"(student tokens: {len(s_ids)}, teacher tokens: {len(t_ids)})<br><br>"
140
+ f"{highlighted}</div>"
141
+ )
142
+ }
143
+
144
+ progress(0.2, desc="Processing texts...")
145
+ rows = rows.map(make_html_block, num_proc=4, with_indices=True)
146
+ html_blocks = rows["html_block"]
147
 
148
  progress(1, desc="Done!")
149