Adrian Gabriel commited on
Commit
e6d3534
·
1 Parent(s): 20d014e

fix multiple bugs

Browse files
Files changed (4) hide show
  1. app.py +25 -2
  2. instrumentation.py +55 -41
  3. static/index.html +220 -6
  4. tracer.py +4 -0
app.py CHANGED
@@ -2,8 +2,10 @@ from __future__ import annotations
2
 
3
  import ast
4
  import asyncio
 
5
  import os
6
  import queue
 
7
  import traceback
8
  from pathlib import Path
9
  from typing import Any, Dict
@@ -205,6 +207,22 @@ def _make_exec_env(tracer: Tracer) -> Dict[str, Any]:
205
  }
206
 
207
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
  def _run_user_code(code: str, tracer: Tracer) -> None:
209
  # 1. Transform code to auto-capture variable names
210
  transformed_code = transform_code(code)
@@ -212,14 +230,19 @@ def _run_user_code(code: str, tracer: Tracer) -> None:
212
  # 2. Setup Environment
213
  env = _make_exec_env(tracer)
214
 
215
- # 3. Instrument Tensor/Layer classes to talk to our tracer
 
 
 
 
216
  with Instrumentor(tracer):
217
  try:
218
- # 4. Execute transformed code
219
  exec(transformed_code, env, env)
220
  except Exception:
221
  tracer.error(traceback.format_exc())
222
  finally:
 
223
  tracer.done()
224
 
225
 
 
2
 
3
  import ast
4
  import asyncio
5
+ import io
6
  import os
7
  import queue
8
+ import sys
9
  import traceback
10
  from pathlib import Path
11
  from typing import Any, Dict
 
207
  }
208
 
209
 
210
+ class PrintCapture(io.StringIO):
211
+ """Captures print output and sends it to the tracer."""
212
+ def __init__(self, tracer: Tracer):
213
+ super().__init__()
214
+ self.tracer = tracer
215
+
216
+ def write(self, text: str) -> int:
217
+ # Send non-empty text to tracer
218
+ if text and text.strip():
219
+ self.tracer.print(text.rstrip('\n'))
220
+ return len(text)
221
+
222
+ def flush(self):
223
+ pass
224
+
225
+
226
  def _run_user_code(code: str, tracer: Tracer) -> None:
227
  # 1. Transform code to auto-capture variable names
228
  transformed_code = transform_code(code)
 
230
  # 2. Setup Environment
231
  env = _make_exec_env(tracer)
232
 
233
+ # 3. Redirect stdout to capture print statements
234
+ old_stdout = sys.stdout
235
+ sys.stdout = PrintCapture(tracer)
236
+
237
+ # 4. Instrument Tensor/Layer classes to talk to our tracer
238
  with Instrumentor(tracer):
239
  try:
240
+ # 5. Execute transformed code
241
  exec(transformed_code, env, env)
242
  except Exception:
243
  tracer.error(traceback.format_exc())
244
  finally:
245
+ sys.stdout = old_stdout
246
  tracer.done()
247
 
248
 
instrumentation.py CHANGED
@@ -102,48 +102,62 @@ class Instrumentor:
102
  setattr(Tensor, method_name, wrapped)
103
 
104
  def _wrap_layer_forward(self):
105
- """Wraps Layer.__call__ to emit trace events for layer operations."""
106
- if not hasattr(Layer, '__call__'):
107
- return
108
-
109
- original = Layer.__call__
110
- self.original_methods["Layer.__call__"] = original
111
  instrumentor = self # Capture reference for closure
112
-
113
- def wrapped(instance, x, *args, **kwargs):
114
- # Set flag to suppress internal tensor op tracing
115
- was_inside = instrumentor._inside_layer
116
- instrumentor._inside_layer = True
117
- try:
118
- # Execute original logic
119
- result = original(instance, x, *args, **kwargs)
120
- finally:
121
- instrumentor._inside_layer = was_inside
122
-
123
- # Get the layer name
124
- layer_name = instance.__class__.__name__
125
-
126
- # Build inputs list - for Linear, include weight and bias
127
- inputs = [x]
128
- meta = {'layer_type': layer_name}
129
-
130
- # For Linear layers, include weight and bias for visualization
131
- if layer_name == 'Linear':
132
- if hasattr(instance, 'weight'):
133
- #inputs.append(instance.weight)
134
- inputs.insert(0, instance.weight)
135
-
136
- meta['has_weight'] = True
137
- if hasattr(instance, 'bias') and instance.bias is not None:
138
- inputs.append(instance.bias)
139
- meta['has_bias'] = True
140
-
141
- # Emit op event for the layer
142
- instrumentor.tracer.op(layer_name.lower(), inputs, result, meta)
143
-
144
- return result
145
-
146
- setattr(Layer, "__call__", wrapped)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
 
148
  def _wrap_activation(self, activation_cls):
149
  """Wraps an activation class's __call__ and forward methods."""
 
102
  setattr(Tensor, method_name, wrapped)
103
 
104
  def _wrap_layer_forward(self):
105
+ """Wraps Layer.__call__ and concrete layer forward methods."""
 
 
 
 
 
106
  instrumentor = self # Capture reference for closure
107
+
108
+ # Import Linear here to wrap its forward method specifically
109
+ from tinytorch.core.layers import Linear, Dropout
110
+
111
+ def make_layer_wrapper(layer_cls):
112
+ def make_wrapped(orig):
113
+ def wrapped(instance, x, *args, **kwargs):
114
+ # Set flag to suppress internal tensor op tracing
115
+ was_inside = instrumentor._inside_layer
116
+ instrumentor._inside_layer = True
117
+ try:
118
+ # Execute original logic
119
+ result = orig(instance, x, *args, **kwargs)
120
+ finally:
121
+ instrumentor._inside_layer = was_inside
122
+
123
+ # Only emit event if this is the outermost call (prevent double-tracing)
124
+ if not was_inside:
125
+ # Get the layer name
126
+ layer_name = instance.__class__.__name__
127
+
128
+ # Build inputs list - for Linear, include weight and bias
129
+ inputs = [x]
130
+ meta = {'layer_type': layer_name}
131
+
132
+ # For Linear layers, include weight and bias for visualization
133
+ if layer_name == 'Linear':
134
+ if hasattr(instance, 'weight'):
135
+ inputs.insert(0, instance.weight)
136
+ meta['has_weight'] = True
137
+ if hasattr(instance, 'bias') and instance.bias is not None:
138
+ inputs.append(instance.bias)
139
+ meta['has_bias'] = True
140
+
141
+ # Emit op event for the layer
142
+ instrumentor.tracer.op(layer_name.lower(), inputs, result, meta)
143
+
144
+ return result
145
+ return wrapped
146
+ return make_wrapped
147
+
148
+ # Wrap Layer.__call__
149
+ if hasattr(Layer, '__call__'):
150
+ original_call = Layer.__call__
151
+ self.original_methods["Layer.__call__"] = original_call
152
+ setattr(Layer, "__call__", make_layer_wrapper(Layer)(original_call))
153
+
154
+ # Wrap concrete layer forward methods (Linear, Dropout, etc.)
155
+ for layer_cls in [Linear, Dropout]:
156
+ if hasattr(layer_cls, 'forward'):
157
+ cls_name = layer_cls.__name__
158
+ original_forward = layer_cls.forward
159
+ self.original_methods[f"{cls_name}.forward"] = original_forward
160
+ setattr(layer_cls, "forward", make_layer_wrapper(layer_cls)(original_forward))
161
 
162
  def _wrap_activation(self, activation_cls):
163
  """Wraps an activation class's __call__ and forward methods."""
static/index.html CHANGED
@@ -107,6 +107,90 @@
107
  #run:hover { background: #2563eb; }
108
  #run:disabled { background: #475569; cursor: not-allowed; }
109
  #error { color: #ef4444; font-size: 12px; min-height: 20px; font-family: monospace; white-space: pre-wrap; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
  #pdfSidebar {
112
  flex: 1 1 340px;
@@ -533,6 +617,32 @@
533
  color: #ff4444;
534
  text-shadow: 0 0 4px rgba(255, 68, 68, 0.5);
535
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
536
 
537
  /* Status dot glow */
538
  body.retro-mode .status-dot {
@@ -612,6 +722,14 @@ box("Layer 3: Linear + Softmax", [a2, z3, y_pred], "4")
612
  box("Loss Computation", [y_pred, target_probs, loss], "6")
613
  </textarea>
614
  <button id="run">Run</button>
 
 
 
 
 
 
 
 
615
  <div id="error"></div>
616
  </div>
617
  <div class="v-resizer" data-left="editor"></div>
@@ -724,6 +842,7 @@ box("Loss Computation", [y_pred, target_probs, loss], "6")
724
  nextGroupId = 1;
725
  nextBoxId = 1;
726
  document.getElementById('error').textContent = '';
 
727
  break;
728
  case 'tensor':
729
  tensors[msg.id] = { id: msg.id, shape: msg.shape, data: msg.data, name: msg.name };
@@ -745,12 +864,30 @@ box("Loss Computation", [y_pred, target_probs, loss], "6")
745
  break;
746
  case 'error':
747
  document.getElementById('error').textContent = msg.message;
 
 
 
 
748
  break;
749
  case 'done':
750
  layoutAndRender();
751
  break;
752
  }
753
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
754
 
755
  // ==================== Layout ====================
756
  function layoutAndRender() {
@@ -1495,6 +1632,44 @@ box("Loss Computation", [y_pred, target_probs, loss], "6")
1495
  });
1496
  }
1497
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1498
  // ==================== Matmul Hover Highlighting ====================
1499
  function setupMatmulHover(wrapper, leftCard, topCard, resultCard) {
1500
  const resultTable = resultCard.querySelector('.matrix-table');
@@ -1703,7 +1878,7 @@ box("Loss Computation", [y_pred, target_probs, loss], "6")
1703
 
1704
  if (isMatmul) {
1705
  // Matrix multiplication: keep grid layout for proper alignment
1706
- const grid = document.createElement('div');
1707
  grid.className = 'layout-grid layout-binary';
1708
 
1709
  const left = createMatrixCard(inputs[0].name || inputs[0].id, inputs[0], 'auto', 'left');
@@ -1748,7 +1923,7 @@ box("Loss Computation", [y_pred, target_probs, loss], "6")
1748
  }
1749
 
1750
  const inputCards = [];
1751
- if (inputs[0]) {
1752
  const inputLabel = getInputDisplayName(inputs[0]);
1753
  const inputCard = createMatrixCard(inputLabel, inputs[0], 'auto', 'input');
1754
  grid.appendChild(inputCard);
@@ -1757,10 +1932,16 @@ box("Loss Computation", [y_pred, target_probs, loss], "6")
1757
  const outputCard = createMatrixCard(output.name || type, output, outputOrientation, 'output');
1758
  grid.appendChild(outputCard);
1759
 
1760
- // Add element-wise hover highlighting for activations and element-wise ops
1761
- const isElementwiseUnary = ['relu', 'sigmoid', 'tanh', 'gelu', 'softmax', 'logsoftmax'].includes(type);
1762
- if (isElementwiseUnary && inputCards.length > 0) {
1763
- setupElementwiseHover(inputCards, outputCard);
 
 
 
 
 
 
1764
  }
1765
 
1766
  container.appendChild(grid);
@@ -2172,6 +2353,39 @@ box("Loss Computation", [y_pred, target_probs, loss], "6")
2172
  document.getElementById('code').onkeydown = e => {
2173
  if ((e.ctrlKey || e.metaKey) && e.key === 'Enter') { e.preventDefault(); runCode(); }
2174
  };
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2175
 
2176
  // ==================== Init ====================
2177
  // Restore retro mode preference
 
107
  #run:hover { background: #2563eb; }
108
  #run:disabled { background: #475569; cursor: not-allowed; }
109
  #error { color: #ef4444; font-size: 12px; min-height: 20px; font-family: monospace; white-space: pre-wrap; }
110
+
111
+ /* Console output area */
112
+ #console-container {
113
+ margin-top: 12px;
114
+ border: 1px solid #334155;
115
+ border-radius: 6px;
116
+ background: #0a0f1a;
117
+ overflow: hidden;
118
+ display: flex;
119
+ flex-direction: column;
120
+ min-height: 60px;
121
+ height: 150px;
122
+ flex-shrink: 0;
123
+ }
124
+ #console-resizer {
125
+ height: 6px;
126
+ background: transparent;
127
+ cursor: ns-resize;
128
+ position: relative;
129
+ flex-shrink: 0;
130
+ }
131
+ #console-resizer::after {
132
+ content: '';
133
+ position: absolute;
134
+ left: 50%;
135
+ top: 50%;
136
+ transform: translate(-50%, -50%);
137
+ width: 40px;
138
+ height: 3px;
139
+ background: #475569;
140
+ border-radius: 2px;
141
+ transition: background 0.2s;
142
+ }
143
+ #console-resizer:hover::after {
144
+ background: #64748b;
145
+ }
146
+ #console-header {
147
+ display: flex;
148
+ justify-content: space-between;
149
+ align-items: center;
150
+ padding: 6px 10px;
151
+ background: #1e293b;
152
+ border-bottom: 1px solid #334155;
153
+ font-size: 11px;
154
+ font-weight: 600;
155
+ color: #94a3b8;
156
+ letter-spacing: 0.5px;
157
+ }
158
+ #console-clear {
159
+ background: none;
160
+ border: none;
161
+ color: #64748b;
162
+ cursor: pointer;
163
+ font-size: 12px;
164
+ padding: 2px 6px;
165
+ border-radius: 3px;
166
+ }
167
+ #console-clear:hover {
168
+ background: #334155;
169
+ color: #f1f5f9;
170
+ }
171
+ #console-output {
172
+ flex: 1;
173
+ overflow-y: auto;
174
+ padding: 8px 10px;
175
+ font-family: 'Fira Code', 'JetBrains Mono', 'Consolas', monospace;
176
+ font-size: 12px;
177
+ line-height: 1.5;
178
+ color: #e2e8f0;
179
+ white-space: pre-wrap;
180
+ word-break: break-word;
181
+ }
182
+ #console-output .console-line {
183
+ margin: 2px 0;
184
+ }
185
+ #console-output .console-line.error {
186
+ color: #f87171;
187
+ }
188
+ #console-output .console-line.success {
189
+ color: #4ade80;
190
+ }
191
+ #console-output .console-line.info {
192
+ color: #60a5fa;
193
+ }
194
 
195
  #pdfSidebar {
196
  flex: 1 1 340px;
 
617
  color: #ff4444;
618
  text-shadow: 0 0 4px rgba(255, 68, 68, 0.5);
619
  }
620
+
621
+ /* Console in retro mode */
622
+ body.retro-mode #console-resizer::after {
623
+ background: rgba(0, 255, 65, 0.3);
624
+ }
625
+ body.retro-mode #console-resizer:hover::after {
626
+ background: rgba(0, 255, 65, 0.6);
627
+ }
628
+ body.retro-mode #console-container {
629
+ border-color: rgba(0, 255, 65, 0.3);
630
+ background: rgba(0, 10, 5, 0.8);
631
+ }
632
+ body.retro-mode #console-header {
633
+ background: rgba(0, 40, 20, 0.6);
634
+ border-bottom-color: rgba(0, 255, 65, 0.2);
635
+ color: #00ff41;
636
+ text-shadow: 0 0 4px rgba(0, 255, 65, 0.5);
637
+ }
638
+ body.retro-mode #console-output {
639
+ color: #00ff41;
640
+ text-shadow: 0 0 2px rgba(0, 255, 65, 0.3);
641
+ }
642
+ body.retro-mode #console-output .console-line.error {
643
+ color: #ff4444;
644
+ text-shadow: 0 0 4px rgba(255, 68, 68, 0.5);
645
+ }
646
 
647
  /* Status dot glow */
648
  body.retro-mode .status-dot {
 
722
  box("Loss Computation", [y_pred, target_probs, loss], "6")
723
  </textarea>
724
  <button id="run">Run</button>
725
+ <div id="console-resizer"></div>
726
+ <div id="console-container">
727
+ <div id="console-header">
728
+ <span>CONSOLE OUTPUT</span>
729
+ <button id="console-clear" title="Clear console">✕</button>
730
+ </div>
731
+ <div id="console-output"></div>
732
+ </div>
733
  <div id="error"></div>
734
  </div>
735
  <div class="v-resizer" data-left="editor"></div>
 
842
  nextGroupId = 1;
843
  nextBoxId = 1;
844
  document.getElementById('error').textContent = '';
845
+ clearConsole();
846
  break;
847
  case 'tensor':
848
  tensors[msg.id] = { id: msg.id, shape: msg.shape, data: msg.data, name: msg.name };
 
864
  break;
865
  case 'error':
866
  document.getElementById('error').textContent = msg.message;
867
+ appendToConsole(msg.message, 'error');
868
+ break;
869
+ case 'print':
870
+ appendToConsole(msg.text, msg.type || 'info');
871
  break;
872
  case 'done':
873
  layoutAndRender();
874
  break;
875
  }
876
  }
877
+
878
+ // ==================== Console Output ====================
879
+ function appendToConsole(text, type = 'info') {
880
+ const consoleOutput = document.getElementById('console-output');
881
+ const line = document.createElement('div');
882
+ line.className = 'console-line ' + type;
883
+ line.textContent = text;
884
+ consoleOutput.appendChild(line);
885
+ consoleOutput.scrollTop = consoleOutput.scrollHeight;
886
+ }
887
+
888
+ function clearConsole() {
889
+ document.getElementById('console-output').innerHTML = '';
890
+ }
891
 
892
  // ==================== Layout ====================
893
  function layoutAndRender() {
 
1632
  });
1633
  }
1634
 
1635
+ // ==================== Transpose Hover Highlighting ====================
1636
+ // For transpose, output[i,j] corresponds to input[j,i]
1637
+ function setupTransposeHover(inputCard, outputCard) {
1638
+ const outputTable = outputCard.querySelector('.matrix-table');
1639
+ const inputTable = inputCard.querySelector('.matrix-table');
1640
+
1641
+ if (!outputTable || !inputTable) return;
1642
+
1643
+ const outputCells = outputTable.querySelectorAll('td[data-row][data-col]');
1644
+
1645
+ outputCells.forEach(cell => {
1646
+ cell.style.cursor = 'pointer';
1647
+
1648
+ cell.addEventListener('mouseenter', () => {
1649
+ const row = cell.dataset.row;
1650
+ const col = cell.dataset.col;
1651
+
1652
+ // Highlight the output cell
1653
+ cell.classList.add('highlight-cell');
1654
+
1655
+ // For transpose: input[col, row] -> output[row, col]
1656
+ // So to find the source, swap row and col
1657
+ const inputCell = inputTable.querySelector(`td[data-row="${col}"][data-col="${row}"]`);
1658
+ if (inputCell) {
1659
+ inputCell.classList.add('highlight-cell');
1660
+ }
1661
+ });
1662
+
1663
+ cell.addEventListener('mouseleave', () => {
1664
+ // Remove all highlights
1665
+ cell.classList.remove('highlight-cell');
1666
+ inputTable.querySelectorAll('td.highlight-cell').forEach(td => {
1667
+ td.classList.remove('highlight-cell');
1668
+ });
1669
+ });
1670
+ });
1671
+ }
1672
+
1673
  // ==================== Matmul Hover Highlighting ====================
1674
  function setupMatmulHover(wrapper, leftCard, topCard, resultCard) {
1675
  const resultTable = resultCard.querySelector('.matrix-table');
 
1878
 
1879
  if (isMatmul) {
1880
  // Matrix multiplication: keep grid layout for proper alignment
1881
+ const grid = document.createElement('div');
1882
  grid.className = 'layout-grid layout-binary';
1883
 
1884
  const left = createMatrixCard(inputs[0].name || inputs[0].id, inputs[0], 'auto', 'left');
 
1923
  }
1924
 
1925
  const inputCards = [];
1926
+ if (inputs[0]) {
1927
  const inputLabel = getInputDisplayName(inputs[0]);
1928
  const inputCard = createMatrixCard(inputLabel, inputs[0], 'auto', 'input');
1929
  grid.appendChild(inputCard);
 
1932
  const outputCard = createMatrixCard(output.name || type, output, outputOrientation, 'output');
1933
  grid.appendChild(outputCard);
1934
 
1935
+ // Add hover highlighting based on operation type
1936
+ if (type === 'transpose' && inputCards.length > 0) {
1937
+ // Transpose: output[i,j] corresponds to input[j,i]
1938
+ setupTransposeHover(inputCards[0], outputCard);
1939
+ } else {
1940
+ // Element-wise activations: output[i,j] corresponds to input[i,j]
1941
+ const isElementwiseUnary = ['relu', 'sigmoid', 'tanh', 'gelu', 'softmax', 'logsoftmax'].includes(type);
1942
+ if (isElementwiseUnary && inputCards.length > 0) {
1943
+ setupElementwiseHover(inputCards, outputCard);
1944
+ }
1945
  }
1946
 
1947
  container.appendChild(grid);
 
2353
  document.getElementById('code').onkeydown = e => {
2354
  if ((e.ctrlKey || e.metaKey) && e.key === 'Enter') { e.preventDefault(); runCode(); }
2355
  };
2356
+ document.getElementById('console-clear').onclick = clearConsole;
2357
+
2358
+ // ==================== Console Resizer ====================
2359
+ (function() {
2360
+ const resizer = document.getElementById('console-resizer');
2361
+ const consoleContainer = document.getElementById('console-container');
2362
+ let startY, startHeight;
2363
+
2364
+ resizer.addEventListener('mousedown', (e) => {
2365
+ e.preventDefault();
2366
+ startY = e.clientY;
2367
+ startHeight = consoleContainer.offsetHeight;
2368
+
2369
+ document.addEventListener('mousemove', onMouseMove);
2370
+ document.addEventListener('mouseup', onMouseUp);
2371
+ document.body.style.cursor = 'ns-resize';
2372
+ document.body.style.userSelect = 'none';
2373
+ });
2374
+
2375
+ function onMouseMove(e) {
2376
+ // Dragging up increases height (startY - e.clientY is positive when moving up)
2377
+ const deltaY = startY - e.clientY;
2378
+ const newHeight = Math.max(60, Math.min(400, startHeight + deltaY));
2379
+ consoleContainer.style.height = newHeight + 'px';
2380
+ }
2381
+
2382
+ function onMouseUp() {
2383
+ document.removeEventListener('mousemove', onMouseMove);
2384
+ document.removeEventListener('mouseup', onMouseUp);
2385
+ document.body.style.cursor = '';
2386
+ document.body.style.userSelect = '';
2387
+ }
2388
+ })();
2389
 
2390
  // ==================== Init ====================
2391
  // Restore retro mode preference
tracer.py CHANGED
@@ -176,5 +176,9 @@ class Tracer:
176
  def error(self, message: str) -> None:
177
  self.sink.emit(TraceEvent("error", {"message": str(message)}).asdict())
178
 
 
 
 
 
179
  def done(self) -> None:
180
  self.sink.emit(TraceEvent("done", {}).asdict())
 
176
  def error(self, message: str) -> None:
177
  self.sink.emit(TraceEvent("error", {"message": str(message)}).asdict())
178
 
179
+ def print(self, text: str, msg_type: str = "info") -> None:
180
+ """Emit a print event for console output."""
181
+ self.sink.emit(TraceEvent("print", {"text": str(text), "type": msg_type}).asdict())
182
+
183
  def done(self) -> None:
184
  self.sink.emit(TraceEvent("done", {}).asdict())