ludocomito commited on
Commit
2b8d80f
·
1 Parent(s): 3a52c48

improved prediction viz

Browse files
Files changed (1) hide show
  1. app.py +49 -22
app.py CHANGED
@@ -290,7 +290,7 @@ def run_inference(prompt_cell, prompt_cond, query_cell):
290
 
291
  time.sleep(0.5)
292
 
293
- # Result state - show predicted cells with the new condition applied
294
  predicted_cells = generate_cell_array(query_cell, prompt_cond, 5, is_prompt=True)
295
 
296
  result_html = f'''
@@ -300,43 +300,58 @@ def run_inference(prompt_cell, prompt_cond, query_cell):
300
  <div style="text-align: center; margin-bottom: 24px;">
301
  <div style="font-family: monospace; font-size: 10px; color: #64748b; letter-spacing: 0.15em; margin-bottom: 8px; text-transform: uppercase;">PREDICTION COMPLETE</div>
302
  <div style="font-size: 14px; font-weight: 600; color: #1e293b;">
303
- {query_cell_data["name"]} gene expression under {prompt_cond_data["name"]}
 
 
 
304
  </div>
305
  </div>
306
 
307
- <!-- Predicted Cells -->
308
- <div style="display: flex; justify-content: center; margin-bottom: 24px;">
309
- <div style="background: {prompt_cond_data["bg"]}20; border: 2px solid {prompt_cond_data["color"]}60; border-radius: 12px; padding: 16px 24px; display: flex; gap: 8px; align-items: center;">
310
- {predicted_cells}
311
- </div>
312
  </div>
313
 
314
- <!-- Output columns visualization -->
315
- <div style="display: flex; justify-content: center; gap: 12px; margin-bottom: 20px;">
316
  <div style="text-align: center;">
317
- <div style="background: rgba(99, 102, 241, 0.2); border: 1px solid rgba(99, 102, 241, 0.4); border-radius: 8px; padding: 8px 12px; display: flex; flex-direction: column; gap: 4px;">
318
  {generate_output_column(prompt_cond_data["color"])}
319
  </div>
320
- <div style="font-size: 10px; color: #64748b; margin-top: 4px;">c₁</div>
 
 
 
321
  </div>
322
  <div style="text-align: center;">
323
- <div style="background: rgba(99, 102, 241, 0.2); border: 1px solid rgba(99, 102, 241, 0.4); border-radius: 8px; padding: 8px 12px; display: flex; flex-direction: column; gap: 4px;">
324
  {generate_output_column(prompt_cond_data["color"])}
325
  </div>
326
- <div style="font-size: 10px; color: #64748b; margin-top: 4px;">c₂</div>
 
 
 
327
  </div>
328
- <div style="display: flex; align-items: center; color: #64748b; font-size: 14px; letter-spacing: 3px;">···</div>
329
  <div style="text-align: center;">
330
- <div style="background: rgba(99, 102, 241, 0.2); border: 1px solid rgba(99, 102, 241, 0.4); border-radius: 8px; padding: 8px 12px; display: flex; flex-direction: column; gap: 4px;">
331
  {generate_output_column(prompt_cond_data["color"])}
332
  </div>
333
- <div style="font-size: 10px; color: #64748b; margin-top: 4px;">cₙ</div>
 
 
 
334
  </div>
335
  </div>
336
 
337
  <!-- Description -->
338
- <div style="text-align: center; font-size: 11px; color: #64748b; max-width: 300px; margin: 0 auto; line-height: 1.5;">
339
- Zero-shot prediction of gene expression counts using in-context learning from {prompt_cell_data["name"]} response to {prompt_cond_data["name"]}.
 
 
 
340
  </div>
341
 
342
  </div>
@@ -345,11 +360,23 @@ def run_inference(prompt_cell, prompt_cond, query_cell):
345
  yield result_html, gr.update(visible=True), gr.update(visible=False)
346
 
347
  def generate_output_column(color):
348
- """Generate a vertical column of gene expression values"""
 
349
  cells = []
350
- for _ in range(5):
351
- opacity = 0.2 + random.random() * 0.6
352
- cells.append(f'<div style="width: 16px; height: 16px; background: {color}; opacity: {opacity:.1f}; border-radius: 3px;"></div>')
 
 
 
 
 
 
 
 
 
 
 
353
  return '\n'.join(cells)
354
 
355
  def reset_inference(prompt_cell, prompt_cond, query_cell):
 
290
 
291
  time.sleep(0.5)
292
 
293
+ # Result state - show predicted gene counts per query cell
294
  predicted_cells = generate_cell_array(query_cell, prompt_cond, 5, is_prompt=True)
295
 
296
  result_html = f'''
 
300
  <div style="text-align: center; margin-bottom: 24px;">
301
  <div style="font-family: monospace; font-size: 10px; color: #64748b; letter-spacing: 0.15em; margin-bottom: 8px; text-transform: uppercase;">PREDICTION COMPLETE</div>
302
  <div style="font-size: 14px; font-weight: 600; color: #1e293b;">
303
+ Predicted gene expression counts for {query_cell_data["name"]}
304
+ </div>
305
+ <div style="font-size: 11px; color: #64748b; margin-top: 4px;">
306
+ under <span style="color: {prompt_cond_data["color"]}; font-weight: 600;">{prompt_cond_data["name"]}</span> condition
307
  </div>
308
  </div>
309
 
310
+ <!-- Section label -->
311
+ <div style="display: flex; align-items: center; gap: 12px; margin-bottom: 16px; justify-content: center;">
312
+ <div style="height: 1px; width: 60px; background: linear-gradient(90deg, transparent, #cbd5e1);"></div>
313
+ <div style="font-size: 10px; font-weight: 600; color: #475569; text-transform: uppercase; letter-spacing: 0.08em;">Predicted Gene Counts Per Query Cell</div>
314
+ <div style="height: 1px; width: 60px; background: linear-gradient(90deg, #cbd5e1, transparent);"></div>
315
  </div>
316
 
317
+ <!-- Output columns visualization - Gene counts per cell -->
318
+ <div style="display: flex; justify-content: center; gap: 16px; margin-bottom: 20px;">
319
  <div style="text-align: center;">
320
+ <div style="background: white; border: 2px solid {query_cell_data["color"]}40; border-radius: 10px; padding: 10px 14px; display: flex; flex-direction: column; gap: 2px; box-shadow: 0 2px 8px rgba(0,0,0,0.05);">
321
  {generate_output_column(prompt_cond_data["color"])}
322
  </div>
323
+ <div style="display: flex; align-items: center; justify-content: center; gap: 4px; margin-top: 6px;">
324
+ {make_cell_svg(query_cell_data["color"], False, 18)}
325
+ <div style="font-size: 9px; color: #475569; font-weight: 500;">Cell 1</div>
326
+ </div>
327
  </div>
328
  <div style="text-align: center;">
329
+ <div style="background: white; border: 2px solid {query_cell_data["color"]}40; border-radius: 10px; padding: 10px 14px; display: flex; flex-direction: column; gap: 2px; box-shadow: 0 2px 8px rgba(0,0,0,0.05);">
330
  {generate_output_column(prompt_cond_data["color"])}
331
  </div>
332
+ <div style="display: flex; align-items: center; justify-content: center; gap: 4px; margin-top: 6px;">
333
+ {make_cell_svg(query_cell_data["color"], False, 18)}
334
+ <div style="font-size: 9px; color: #475569; font-weight: 500;">Cell 2</div>
335
+ </div>
336
  </div>
337
+ <div style="display: flex; align-items: center; color: #94a3b8; font-size: 16px; font-weight: bold; letter-spacing: 3px; padding-bottom: 24px;">···</div>
338
  <div style="text-align: center;">
339
+ <div style="background: white; border: 2px solid {query_cell_data["color"]}40; border-radius: 10px; padding: 10px 14px; display: flex; flex-direction: column; gap: 2px; box-shadow: 0 2px 8px rgba(0,0,0,0.05);">
340
  {generate_output_column(prompt_cond_data["color"])}
341
  </div>
342
+ <div style="display: flex; align-items: center; justify-content: center; gap: 4px; margin-top: 6px;">
343
+ {make_cell_svg(query_cell_data["color"], False, 18)}
344
+ <div style="font-size: 9px; color: #475569; font-weight: 500;">Cell n</div>
345
+ </div>
346
  </div>
347
  </div>
348
 
349
  <!-- Description -->
350
+ <div style="text-align: center; font-size: 11px; color: #64748b; max-width: 360px; margin: 0 auto; line-height: 1.6;">
351
+ <strong style="color: #475569;">Zero-shot prediction:</strong> Using in-context learning from
352
+ <span style="color: {prompt_cond_data["color"]}; font-weight: 500;">{prompt_cell_data["name"]}</span> response,
353
+ STACK predicts gene counts for each <span style="color: {query_cell_data["color"]}; font-weight: 500;">{query_cell_data["name"]}</span>
354
+ under the same perturbation.
355
  </div>
356
 
357
  </div>
 
360
  yield result_html, gr.update(visible=True), gr.update(visible=False)
361
 
362
  def generate_output_column(color):
363
+ """Generate a vertical column of gene expression counts showing explicit values"""
364
+ gene_names = ["g₁", "g₂", "g₃", "g₄", "g₅"]
365
  cells = []
366
+ for i, gene in enumerate(gene_names):
367
+ # Generate pseudo gene count value
368
+ count = random.randint(10, 500)
369
+ bar_width = min(count / 500 * 40, 40) # Scale to max 40px
370
+ opacity = 0.4 + (count / 500) * 0.5
371
+ cells.append(f'''
372
+ <div style="display: flex; align-items: center; gap: 4px; height: 18px;">
373
+ <div style="font-size: 8px; color: #64748b; width: 14px; text-align: right;">{gene}</div>
374
+ <div style="width: 44px; height: 12px; background: #e2e8f0; border-radius: 2px; overflow: hidden; position: relative;">
375
+ <div style="width: {bar_width}px; height: 100%; background: {color}; opacity: {opacity:.1f}; border-radius: 2px;"></div>
376
+ </div>
377
+ <div style="font-size: 8px; color: #475569; font-family: monospace; width: 22px;">{count}</div>
378
+ </div>
379
+ ''')
380
  return '\n'.join(cells)
381
 
382
  def reset_inference(prompt_cell, prompt_cond, query_cell):