File size: 9,466 Bytes
af4939b
 
 
 
e6d3534
21c9528
af4939b
e6d3534
af4939b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0ffa4e9
108a2f3
c683f70
af4939b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e707779
af4939b
 
 
 
 
 
 
 
 
 
 
 
 
 
e707779
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
af4939b
 
 
 
 
 
 
 
 
 
0ffa4e9
af4939b
 
108a2f3
c683f70
af4939b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e707779
 
 
 
af4939b
e6d3534
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
af4939b
e707779
 
 
af4939b
 
 
e6d3534
 
 
 
 
af4939b
 
e6d3534
e707779
af4939b
 
 
e6d3534
af4939b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21c9528
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
from __future__ import annotations

import ast
import asyncio
import io
import os
import queue
import sys
import traceback
from pathlib import Path
from typing import Any, Dict

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles

# Import Tracing components
from tracer import QueueSink, Tracer
from instrumentation import Instrumentor

# Import TinyTorch components
import numpy as np
from tinytorch.core.tensor import Tensor
from tinytorch.core.layers import Linear, Dropout, Layer, Sequential
from tinytorch.core.activations import ReLU, Sigmoid, Tanh, GELU, Softmax, LogSoftmax
from tinytorch.core.losses import MSELoss, CrossEntropyLoss, log_softmax
from tinytorch.core.norms import RMSNorm

# Import additional modules
from tinytorch.core.autograd import Function, enable_autograd
from tinytorch.core.optimizers import Optimizer, SGD, Adam, AdamW
from tinytorch.core.tokenization import Tokenizer, CharTokenizer, BPETokenizer, create_tokenizer, tokenize_dataset
from tinytorch.core.training import CosineSchedule, clip_grad_norm, Trainer
from tinytorch.core.embeddings import Embedding, PositionalEncoding, EmbeddingLayer, create_sinusoidal_embeddings

# Resolve paths relative to this file so the app works regardless of CWD.
BASE_DIR = Path(__file__).resolve().parent
STATIC_DIR = BASE_DIR / "static"

# Single FastAPI application instance; routes are registered below.
app = FastAPI()


@app.get("/")
async def root():
    """Serve the single-page UI shell from the static directory."""
    index_page = STATIC_DIR / "index.html"
    return FileResponse(index_page)


# Serve JS/CSS assets from ./static under the /static URL prefix.
app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")


class AutoNameTransformer(ast.NodeTransformer):
    """
    Rewrites simple assignments so the assigned variable name is captured
    at runtime via a call to __auto_name__(name, value).

    Example rewrite:
        x = Tensor([1,2,3])   -->   x = __auto_name__("x", Tensor([1,2,3]))

    Only single-target name assignments are touched; tuple unpacking,
    attribute/subscript targets, and underscore-prefixed names pass through
    unchanged.
    """

    def visit_Assign(self, node: ast.Assign) -> ast.AST:
        targets = node.targets
        # Bail out on anything other than a plain `name = value` statement.
        if len(targets) != 1 or not isinstance(targets[0], ast.Name):
            return node

        var_name = targets[0].id
        if var_name.startswith('_'):
            # Private/dunder names are deliberately left un-instrumented.
            return node

        # Build __auto_name__("<var_name>", <original value expression>).
        wrapper = ast.Call(
            func=ast.Name(id='__auto_name__', ctx=ast.Load()),
            args=[ast.Constant(value=var_name), node.value],
            keywords=[],
        )

        replacement = ast.Assign(targets=targets, value=wrapper)
        ast.copy_location(replacement, node)
        ast.fix_missing_locations(replacement)
        return replacement


def transform_code(code: str) -> str:
    """
    Rewrite user code so Tensor assignments report their variable names.

    Unparseable code is returned unchanged so that exec() surfaces the
    syntax error to the user at run time instead of here.
    """
    try:
        parsed = ast.parse(code)
    except SyntaxError:
        return code
    rewritten = AutoNameTransformer().visit(parsed)
    ast.fix_missing_locations(rewritten)
    return ast.unparse(rewritten)


def _make_exec_env(tracer: Tracer) -> Dict[str, Any]:
    """
    Build the globals dict used to exec user-authored Python snippets.

    The environment exposes the TinyTorch public API, numpy/math, a curated
    set of builtins, and the tracer hooks (``box`` and ``__auto_name__``)
    that the AST transformer and manual boxing rely on.
    """
    import builtins

    def manual_box(label, tensors, scheme="1", parent=None):
        # Let users box a single tensor or a collection with one call.
        tensor_list = tensors if isinstance(tensors, (list, tuple)) else [tensors]
        tracer.box(label=label, tensors=tensor_list, scheme=str(scheme), parent_box=parent)

    def auto_name(name: str, value: Any) -> Any:
        """Record the variable name of Tensor values at assignment time."""
        if isinstance(value, Tensor):
            tracer.name(value, name)
        return value

    # Base namespace: module identity plus full builtins access so that
    # class definitions and normal name resolution work inside exec().
    env: Dict[str, Any] = {
        '__builtins__': builtins.__dict__,
        '__build_class__': builtins.__build_class__,
        '__name__': '__main__',
        '__doc__': None,
    }

    # Also surface common builtins as direct globals.
    common_builtins = (
        'print', 'len', 'range', 'int', 'float', 'str', 'list', 'dict', 'tuple',
        'set', 'bool', 'type', 'isinstance', 'issubclass', 'super', 'object',
        'Exception', 'ValueError', 'TypeError', 'AttributeError', 'KeyError',
        'zip', 'enumerate', 'map', 'filter', 'sorted', 'reversed', 'abs',
        'min', 'max', 'sum', 'round', 'pow', 'divmod', 'hash', 'id',
    )
    env.update({name: getattr(builtins, name) for name in common_builtins})

    # Numeric helpers.
    env['math'] = __import__('math')
    env['np'] = np
    env['numpy'] = np

    # TinyTorch public API plus tracer hooks.
    env.update({
        "Tensor": Tensor,
        "Linear": Linear,
        "Dropout": Dropout,
        "Sequential": Sequential,
        "Layer": Layer,
        "ReLU": ReLU,
        "Sigmoid": Sigmoid,
        "Tanh": Tanh,
        "GELU": GELU,
        "Softmax": Softmax,
        "LogSoftmax": LogSoftmax,
        "MSELoss": MSELoss,
        "CrossEntropyLoss": CrossEntropyLoss,
        "log_softmax": log_softmax,
        "RMSNorm": RMSNorm,
        "Function": Function,
        "enable_autograd": enable_autograd,
        "Optimizer": Optimizer,
        "SGD": SGD,
        "Adam": Adam,
        "AdamW": AdamW,
        "Tokenizer": Tokenizer,
        "CharTokenizer": CharTokenizer,
        "BPETokenizer": BPETokenizer,
        "create_tokenizer": create_tokenizer,
        "tokenize_dataset": tokenize_dataset,
        "CosineSchedule": CosineSchedule,
        "clip_grad_norm": clip_grad_norm,
        "Trainer": Trainer,
        "Embedding": Embedding,
        "PositionalEncoding": PositionalEncoding,
        "EmbeddingLayer": EmbeddingLayer,
        "create_sinusoidal_embeddings": create_sinusoidal_embeddings,
        "tracer": tracer,
        "box": manual_box,
        "__auto_name__": auto_name,
    })

    return env


class PrintCapture(io.StringIO):
    """stdout stand-in that forwards printed text to the tracer event stream."""

    def __init__(self, tracer: Tracer):
        super().__init__()
        self.tracer = tracer

    def write(self, text: str) -> int:
        # Forward anything that isn't pure whitespace; strip the trailing
        # newline so each print() call becomes one clean trace event.
        if text.strip():
            self.tracer.print(text.rstrip('\n'))
        return len(text)

    def flush(self):
        # Writes are forwarded immediately; nothing is buffered locally.
        pass


def _run_user_code(code: str, tracer: Tracer) -> None:
    """
    Execute a user code snippet under tracing.

    The code is AST-transformed for variable-name capture, run in a sandboxed
    environment with TinyTorch bindings, with stdout redirected to the tracer.
    Exceptions from the user code are reported via tracer.error() rather than
    propagated; tracer.done() always fires so the stream terminates.
    """
    # Rewrite assignments so Tensor variable names are captured.
    transformed_code = transform_code(code)

    # Sandboxed globals with TinyTorch + tracer hooks.
    env = _make_exec_env(tracer)

    # Redirect stdout so print() output flows into the trace stream.
    old_stdout = sys.stdout
    sys.stdout = PrintCapture(tracer)

    try:
        # Instrument Tensor/Layer classes to talk to our tracer.
        with Instrumentor(tracer):
            try:
                exec(transformed_code, env)
            except Exception:
                # User code errors become trace events, not server crashes.
                tracer.error(traceback.format_exc())
            finally:
                # Signal end-of-run while still inside the instrumented context.
                tracer.done()
    finally:
        # Restore stdout even if Instrumentor enter/exit itself raises,
        # so a failure here can never leak the redirected stdout.
        sys.stdout = old_stdout


async def _stream_queue_to_ws(ws: WebSocket, q: "queue.Queue[dict | None]") -> None:
    while True:
        item = await asyncio.to_thread(q.get)
        if item is None:
            return
        await ws.send_json(item)


@app.websocket("/ws")
async def ws_endpoint(ws: WebSocket):
    """Websocket protocol: receive {"action": "run", "code": ...} messages,
    execute the code in a worker thread, and stream trace events back."""
    await ws.accept()

    try:
        while True:
            msg = await ws.receive_json()
            # Ignore malformed (non-object) frames instead of erroring out.
            if not isinstance(msg, dict):
                continue

            action = msg.get("action")
            if action != "run":
                await ws.send_json({"event": "error", "message": "Unsupported action"})
                continue

            code = msg.get("code", "")
            # Thread-safe queue bridges the worker thread (producer) and the
            # async sender task (consumer); None is the end-of-stream sentinel.
            q: "queue.Queue[dict | None]" = queue.Queue()
            tracer = Tracer(QueueSink(q))

            # Reset frontend state
            await ws.send_json({"event": "reset"})

            # Start streaming queued events to the client concurrently.
            sender = asyncio.create_task(_stream_queue_to_ws(ws, q))

            # Run code in thread to avoid blocking async loop
            await asyncio.to_thread(_run_user_code, code, tracer)

            q.put(None)  # Signal end of stream
            # Wait for the sender to drain remaining events before the next run.
            await sender

    except WebSocketDisconnect:
        return


# Entry point for running the app
if __name__ == "__main__":
    import uvicorn

    # Hugging Face Spaces exposes SPACE_ID and expects the app on port 7860;
    # local development defaults to loopback on port 8000.
    on_spaces = "SPACE_ID" in os.environ

    # Explicit HOST/PORT environment variables always win over the defaults.
    host = os.environ.get("HOST", "0.0.0.0" if on_spaces else "127.0.0.1")
    port = int(os.environ.get("PORT", "7860" if on_spaces else "8000"))

    print(f"Starting TinyTorch Visualizer on http://{host}:{port}")
    if on_spaces:
        print("Running in Hugging Face Spaces mode")
    else:
        print("Running in local development mode")
        print(f"Open http://localhost:{port} in your browser")

    uvicorn.run(app, host=host, port=port)