// Experimental import * as base from './base.js'; import * as flatbuffers from './flatbuffers.js'; import * as python from './python.js'; const pytorch = {}; const nnapi = {}; const numpy = {}; pytorch.ModelFactory = class { async match(context) { const reader = await pytorch.Reader.open(context); if (reader) { return context.set(reader.type, reader); } return null; } filter(context, match) { if (context.type === 'pytorch.export' && match.type === 'pytorch.zip') { return false; } if (context.type === 'pytorch.index' && match.type === 'pytorch.zip') { return false; } if (context.type === 'pytorch.model.json' && match.type === 'pytorch.data.pkl') { return false; } if (context.type === 'pytorch.model.json' && match.type === 'pickle') { return false; } return true; } async open(context) { const metadata = await pytorch.Metadata.open(context); const target = context.value; target.on('resolve', (sender, name) => { context.error(new pytorch.Error(`Unknown type name '${name}'.`), false); }); await target.read(metadata); if (!target.format || (!target.modules && !target.module)) { throw new pytorch.Error("Reader not implemented."); } return new pytorch.Model(metadata, target); } }; pytorch.Model = class { constructor(metadata, target) { this.format = target.format; this.producer = target.producer || ''; this.modules = []; if (target.module) { const graph = new pytorch.Graph(target.execution, metadata, null, '', target.module); this.modules.push(graph); delete target.execution; } else if (target.modules) { for (const [name, value] of target.modules) { const graph = new pytorch.Graph(target.execution, metadata, null, name, value); this.modules.push(graph); delete target.execution; } } } }; pytorch.Graph = class { constructor(execution, metadata, type, name = '', module = null) { this.nodes = []; this.inputs = []; this.outputs = []; this.name = name; this.type = type; const context = new pytorch.Context(execution, metadata); context.values.map = (name, type, tensor) => { if 
(tensor) { return new pytorch.Value(name, type, null, tensor); } if (!context.values.has(name)) { context.values.set(name, new pytorch.Value(name, type, null, tensor)); } else if (type || tensor) { throw new pytorch.Error(`Duplicate value '${name}'.`); } return context.values.get(name); }; const torch = execution ? execution.torch : null; if (torch && module instanceof torch.jit._script.RecursiveScriptModule && module._c._has_method('forward')) { const initializers = new Map(); const graph = module.graph; const constants = module.code_with_constants[1].const_mapping; if (constants) { for (const [key, value] of constants) { const name = `CONSTANTS.${key}`; if (pytorch.Utility.isTensor(value)) { initializers.set(value, new pytorch.Tensor(context, name, value)); } else if (pytorch.Utility.isObject(value)) { initializers.set(value, value); } else { // throw new pytorch.Error('Unsupported constant.'); } } } const deleted = new Set(); const param_node = graph.param_node(); const self = param_node && param_node.outputs().length > 0 && param_node.outputs()[0].type() === module._c._type() ? param_node.outputs()[0] : null; if (self) { const getattr = (value) => { if (value.value === undefined) { const node = value.node(); if (node.kind() === 'prim::GetAttr') { const [input] = node.inputs(); getattr(input); if (input.value !== undefined) { const name = node.s('name'); value.value = input.value.__getattr__(name); value.identifier = input.identifier ? 
`${input.identifier}.${name}` : name; } } if (node === param_node && value === param_node.outputs()[0]) { value.value = module; value.identifier = ''; } } }; for (const node of graph.nodes()) { for (const input of node.inputs()) { getattr(input, node); } } const delattr = (value) => { for (const use of Array.from(value.uses())) { const node = use.user; if (node.kind() === 'prim::GetAttr') { for (const output of node.outputs()) { delattr(output); } // deleted.add(node); node.destroy(); } } }; delattr(param_node.outputs()[0], ''); } for (const node of graph.nodes()) { if (node.kind() === 'prim::Constant' && node.outputs().length === 1) { const output = node.output(); output.identifier = output.debugName(); if (node.hasAttribute('value')) { const kind = node.kindOf('value'); output.value = node[kind]('value'); } else if (node.output().type() instanceof torch.NoneType) { output.value = null; } // deleted.add(node); node.destroy(); } } for (const node of graph.nodes()) { if (node.kind() === 'prim::TupleUnpack') { const value = node.inputs()[0].value; if (Array.isArray(value) && value.length === node.outputs().length && value.every((value) => typeof value === 'number' || typeof value === 'string' || typeof value === 'boolean')) { for (let i = 0; i < node.outputs().length; i++) { const output = node.outputs()[i]; output.value = value[i]; } // deleted.add(node); node.destroy(); } } } for (const node of graph.nodes()) { if (node.kind() === 'prim::ListConstruct' && node.inputs().every((value) => typeof value.value === 'number' || typeof value.value === 'string' || typeof value.value === 'boolean') && node.outputs().every((value) => value.uses().every((use) => use.user.kind() !== 'prim::CallMethod'))) { node.outputs()[0].value = node.inputs().map((value) => value.value); // deleted.add(node); node.destroy(); } } for (const v of graph.inputs()) { if (self.uses().length === 0 && v === self) { continue; } const identifier = pytorch.Utility.unique(v); const name = v.debugName() 
|| identifier; const value = context.values.map(identifier); this.inputs.push(new pytorch.Argument(name, [value])); } for (const value of graph.outputs()) { const identifier = pytorch.Utility.unique(value); this.outputs.push(new pytorch.Argument(identifier, [context.values.map(identifier)])); } for (const node of graph.nodes()) { if (deleted.has(node)) { continue; } if (node === graph.param_node() || node === graph.return_node()) { continue; } if (node.kind() === 'prim::ListConstruct') { if (node.outputs().length === 1 && node.outputs().every((output) => output.uses().length === 1) && node.inputs().every((input) => pytorch.Utility.isTensor(input.value) || input instanceof torch.Value)) { continue; } } this.nodes.push(new pytorch.Node(execution, metadata, null, null, node, initializers, context)); } } else if (torch && module instanceof torch.export.exported_program.ExportedProgram && module.graph) { const exported_program = module; const graph = exported_program.graph; const graph_module = exported_program.graph_module; const inputs_to_parameters = exported_program.graph_signature.inputs_to_parameters; const inputs_to_buffers = exported_program.graph_signature.inputs_to_buffers; const inputs_to_lifted_tensor_constants = exported_program.graph_signature.inputs_to_lifted_tensor_constants; const nodes = new Map(graph.nodes.map((node) => [node.name, node])); for (const obj of graph.nodes) { if (obj.op === 'placeholder') { if (inputs_to_parameters.has(obj.name)) { const key = inputs_to_parameters.get(obj.name); const parameter = exported_program.state_dict.get(key); const tensor = parameter ? 
(parameter.data || parameter) : obj.meta.get('val'); const initializer = new pytorch.Tensor(context, key, tensor); const value = new pytorch.Value(key, null, null, initializer); context.values.set(obj, value); } else if (inputs_to_buffers.has(obj.name)) { const key = inputs_to_buffers.get(obj.name); const buffer = exported_program.state_dict.get(key); const tensor = buffer || obj.meta.get('val'); const initializer = new pytorch.Tensor(context, key, tensor); const value = new pytorch.Value(key, null, null, initializer); context.values.set(obj, value); } else if (inputs_to_lifted_tensor_constants.has(obj.name)) { const key = inputs_to_lifted_tensor_constants.get(obj.name); const constant = exported_program.constants.get(key); const tensor = constant && constant.data ? constant.data : obj.meta.get('val'); const initializer = new pytorch.Tensor(context, key, tensor); const value = new pytorch.Value(key, null, null, initializer); context.values.set(obj, value); } if (obj.users.size > 1 && context.values.has(obj)) { const node = new pytorch.Node(execution, metadata, obj.name, null, obj, null, context); this.nodes.push(node); context.values.set(obj, node.outputs[0].value[0]); } } } context.graph(this, graph_module, false); for (const input_spec of exported_program.graph_signature.user_inputs) { if (nodes.has(input_spec)) { const node = nodes.get(input_spec); const value = context.value(node); const argument = new pytorch.Argument(input_spec, [value]); this.inputs.push(argument); } } } else if (torch && module instanceof torch.fx.GraphModule && module.graph) { context.graph(this, module, true); } else if (pytorch.Utility.isTensor(module)) { const node = new pytorch.Node(execution, metadata, null, type, { value: module }, null, context); this.nodes.push(node); } else { const weights = this.type === 'weights' ? module : pytorch.Utility.weights(module); if (weights) { this.name = !this.name && typeof module.__name__ === 'string' ? 
module.__name__ : this.name;
                // One 'Weights' node per [name, entry] pair of the weights collection.
                for (const [name, module] of weights) {
                    const node = new pytorch.Node(execution, metadata, name, 'Weights', module, null, context);
                    this.nodes.push(node);
                }
            } else {
                // An array whose items are all non-tensor objects exposing `_modules`
                // or `__class__` is used as a list of modules; any other input is
                // wrapped as a single module.
                const modules = Array.isArray(module) && module.every((module) => module && !pytorch.Utility.isTensor(module) && (module._modules !== undefined || module.__class__)) ? module : [module];
                for (const module of modules) {
                    const type = this.type === 'weights' ? 'Weights' : null;
                    const node = new pytorch.Node(execution, metadata, null, type, module, null, context);
                    this.nodes.push(node);
                }
            }
        }
    }
};

/**
 * Named argument of a node or graph: an input, output, attribute or metadata entry.
 */
pytorch.Argument = class {

    /**
     * @param {string} name - Argument name.
     * @param {*} value - Argument payload, stored as given.
     * @param {?string} [type=null] - Optional type label.
     * @param {boolean} [visible=true] - Visibility flag, stored as given.
     */
    constructor(name, value, type = null, visible = true) {
        this.name = name;
        this.value = value;
        this.type = type;
        this.visible = visible;
    }
};

/**
 * Named value connecting arguments, optionally backed by an initializer.
 */
pytorch.Value = class Value {

    /**
     * @param {string} name - Value identifier; must be a string.
     * @param {*} type - Fallback type used when the initializer has no `type`.
     * @param {*} quantization - Stored as given.
     * @param {*} [initializer=null] - Optional initializer; its `type` takes precedence.
     * @throws {pytorch.Error} When `name` is not a string.
     */
    constructor(name, type, quantization, initializer = null) {
        if (typeof name !== 'string') {
            throw new pytorch.Error(`Invalid value identifier '${JSON.stringify(name)}'.`);
        }
        this.name = name;
        this.type = initializer && initializer.type ? initializer.type : type || null;
        this.quantization = quantization;
        this.initializer = initializer;
    }
};

/**
 * Graph node built from a TorchScript node, an FX node or a generic object.
 */
pytorch.Node = class {

    /**
     * @param {object} execution - Execution environment providing `torch` and
     *     `builtins`; may be null/undefined.
     * @param {object} metadata - Metadata used to look up types and attributes.
     * @param {?string} name - Node name; falsy values become ''.
     * @param {*} type - Type hint for the generic-object path.
     * @param {*} obj - Source object the node is built from.
     * @param {?Map} initializers - Known initializers keyed by source value.
     * @param {pytorch.Context} context - Shared value context.
     * @param {Set} [stack] - Objects currently being expanded; created on demand.
     */
    constructor(execution, metadata, name, type, obj, initializers, context, stack) {
        const torch = execution ? execution.torch : null;
        const builtins = execution ? execution.builtins : null;
        this.name = name || '';
        this.nodes = [];
        this.attributes = [];
        this.inputs = [];
        this.outputs = [];
        this.blocks = [];
        this.metadata = [];
        // TorchScript node path.
        if (torch && obj instanceof torch.Node) {
            const node = obj;
            const kind = node.kind();
            const schema = node.schema();
            const inputs = node.inputs();
            const outputs = node.outputs();
            // Display name: the part after the last '::' and before the first '.'.
            this.type = {
                name: kind.indexOf('::') === -1 ?
kind : kind.split('::').pop().split('.')[0], identifier: kind }; if (schema && schema.category) { this.type.category = schema.category; } const getAttribute = (node, name) => { const kind = node.kindOf(name); let value = null; let type = null; switch (kind) { case 's': value = node.s(name); type = 'string'; break; case 'i': value = node.i(name); type = 'int64'; break; case 'f': value = node.f(name); type = 'float32'; break; case 't': value = node.t(name); type = 'tensor'; break; case 'ss': value = node.ss(name); type = 'string[]'; break; case 'tys': value = node.tys(name).map((ty) => pytorch.Utility.toType(ty)); type = 'type[]'; break; case 'ival': value = node.ival(name); break; default: throw new pytorch.Error(`Unsupported attribute kind '${kind}'.`); } return [type, value]; }; for (const name of node.attributeNames()) { const [type, value] = getAttribute(node, name); const attribute = new pytorch.Argument(name, value, type); this.attributes.push(attribute); } const mapTensor = (value) => { if (value.identifier && pytorch.Utility.isTensor(value.value)) { const identifier = value.identifier; if (!context.values.has(identifier)) { const tensor = new pytorch.Tensor(context, identifier, value.value); context.values.set(identifier, new pytorch.Value(identifier, null, null, tensor)); } return context.values.map(identifier); } let initializer = null; let identifier = value.hasDebugName() ? `%${value.debugName().toString()}` : `%${value.unique().toString()}`; if (value.value) { const obj = value.value; const hide = obj.__parent__ ? obj.__parent__.__hide__ : true; initializer = hide ? initializers.get(obj) : null; identifier = initializer ? initializer.name : identifier; } if (initializer) { return new pytorch.Value(identifier, null, null, initializer); } return context.values.map(identifier); }; for (let i = 0; i < inputs.length; i++) { const input = inputs[i]; const arg = schema && schema.arguments && i < schema.arguments.length ? 
schema.arguments[i] : null; const name = arg && arg.name ? arg.name : i.toString(); let type = arg ? arg.real_type : null; let array = false; if (type instanceof torch.ListType) { array = true; type = type.getElementType(); } let argument = null; if (type && type instanceof torch.ClassType) { const obj = input.value; if (!array && initializers.has(obj)) { const node = new pytorch.Node(execution, metadata, name, type.qualified_name(), obj, initializers, context); argument = new pytorch.Argument(name, node, 'object'); } else if (array && Array.isArray(obj) && obj.every((obj) => initializers.has(obj))) { const node = obj.map((obj) => new pytorch.Node(execution, metadata, name, type.qualified_name(), obj, initializers, context)); argument = new pytorch.Argument(name, node, 'object[]'); } else if (array && input.node().kind() === 'prim::ListConstruct' && input.uses().length === 1 && input.node().inputs().every((input) => input.value)) { const node = input.node().inputs().map((input) => new pytorch.Node(execution, metadata, name, null, input.value, initializers, context)); argument = new pytorch.Argument(name, node, 'object[]'); } else if (input.value === undefined) { const identifier = pytorch.Utility.unique(input); const value = context.values.map(identifier); argument = new pytorch.Argument(name, [value]); } else { const node = new pytorch.Node(execution, metadata, null, null, input.value, initializers, context); argument = new pytorch.Argument(name, node, 'object'); } } else if ((input.type() instanceof torch.TensorType || (input.type() instanceof torch.OptionalType && input.type().getElementType() instanceof torch.TensorType)) && pytorch.Utility.isTensor(input.value)) { const value = mapTensor(input); argument = new pytorch.Argument(name, [value]); } else if (input instanceof torch.Value && !pytorch.Utility.isTensor(input.value)) { if (input.value !== undefined) { if (Array.isArray(input.value) && input.value.every((value) => pytorch.Utility.isTensor(value))) { 
continue; } const type = input.type() ? pytorch.Utility.toType(input.type()) : null; let value = input.value; if (value && value instanceof torch._C.IValue) { value = pytorch.Utility.toString(value); } if (value && value instanceof builtins.complex) { value = new base.Complex(value.real, value.imag); } argument = new pytorch.Argument(name, value, type || 'attribute'); } else if (input.type() instanceof torch.ListType) { if (input.node() && input.node().kind() === 'prim::ListConstruct' && input.uses().length === 1 && input.node().inputs().every((value) => value instanceof torch.Value || value.type() instanceof torch.IntType || value.type() instanceof torch.FloatType || value.type() instanceof torch.StringType || value.type() instanceof torch.ComplexType || value.type() instanceof torch.TensorType)) { const list = input.node().inputs(); const args = list.map((value) => { if (pytorch.Utility.isTensor(value.value)) { return mapTensor(value); } if (value.value !== undefined) { return value.value; } const identifier = pytorch.Utility.unique(value); return context.values.map(identifier); }); const type = list.every((value) => (pytorch.Utility.isTensor(value.value)) || value.value === null) ? 
null : pytorch.Utility.toType(input.type()); argument = new pytorch.Argument(name, args, type); } else { const identifier = pytorch.Utility.unique(input); argument = new pytorch.Argument(name, [context.values.map(identifier)]); } } else if (input.type() instanceof torch.StringType && typeof input.value === 'string') { argument = new pytorch.Argument(name, input.value, 'string'); } else if (input.type() instanceof torch.BoolType && (typeof input.value === 'boolean' || input.value === 0 || input.value === 1)) { argument = new pytorch.Argument(name, Boolean(input.value), 'boolean'); } else if (input.type() instanceof torch.IntType && typeof input.value === 'number') { argument = new pytorch.Argument(name, input.value, 'int64'); } else if (input.type() instanceof torch.FloatType && typeof input.value === 'number') { argument = new pytorch.Argument(name, input.value, 'float32'); } else if (input.type() instanceof torch.NoneType && input.value === null) { argument = new pytorch.Argument(name, null, 'attribute'); } else { const identifier = pytorch.Utility.unique(input); const value = context.values.map(identifier); argument = new pytorch.Argument(name, [value]); } } else if (pytorch.Utility.isTensor(input.value) || input.value === undefined || input.value === null) { let list = [input]; if (input.node() && node !== input.node() && input.node().kind() === 'prim::ListConstruct' && input.uses().length === 1 && input.node().inputs().every((input) => pytorch.Utility.isTensor(input.value))) { list = input.node().inputs(); } const args = list.map((input) => { let initializer = null; let identifier = pytorch.Utility.unique(input); if (input.value) { const value = input.value; const hide = value.__parent__ ? value.__parent__.__hide__ : true; initializer = hide ? initializers.get(value) : null; identifier = initializer ? 
initializer.name : identifier; } if (initializer) { return new pytorch.Value(identifier, null, null, initializer); } return context.values.map(identifier); }); argument = new pytorch.Argument(name, args); } else if (Array.isArray(input.value) && input.value.some((value) => value instanceof torch.Value)) { const args = input.value.map((value) => { if (value instanceof torch.Value) { const identifier = pytorch.Utility.unique(value); return context.values.map(identifier); } return value; }); argument = new pytorch.Argument(name, args, pytorch.Utility.toType(type)); } else { throw new pytorch.Error('Unsupported input value'); } this.inputs.push(argument); } for (let i = 0; i < outputs.length; i++) { const output = outputs[i]; const ret = schema && schema.returns && i < schema.returns.length ? schema.returns[i] : null; if (ret && ret.name) { name = ret.name; } else { name = i === 0 && outputs.length === 1 ? 'output' : `${i}`; } let list = [output]; if (output.uses().length === 1 && output.uses()[0].user && output.uses()[0].user.kind() === 'prim::ListUnpack' && output.uses()[0].user.outputs().every((output) => pytorch.Utility.isTensor(output.value))) { list = output.uses()[0].user.outputs(); } const args = list.map((output) => context.values.map(pytorch.Utility.unique(output))); const argument = new pytorch.Argument(name, args); this.outputs.push(argument); } const blocks = node.blocks(); for (let i = 0; i < blocks.length; i++) { const block = blocks[i]; const nodes = Array.from(block.nodes()); if (nodes.length > 0) { const name = `block${i}`; const graph = { name, nodes: [], inputs: [], outputs: [] }; for (const v of block.inputs()) { const identifier = pytorch.Utility.unique(v); const value = context.values.map(identifier); graph.inputs.push(new pytorch.Argument(v.debugName() || identifier, [value])); } for (const v of block.outputs()) { const identifier = pytorch.Utility.unique(v); graph.outputs.push(new pytorch.Argument(identifier, [context.values.map(identifier)])); 
} for (const n of nodes) { if (n === block.param_node() || n === block.return_node()) { continue; } graph.nodes.push(new pytorch.Node(execution, metadata, null, null, n, initializers, context)); } const argument = new pytorch.Argument(name, graph, 'graph'); this.blocks.push(argument); } } const sourceRange = node.sourceRange(); if (sourceRange) { this.metadata.push(new pytorch.Argument('source', sourceRange.toString().replace(/^at\s/, '').replace(/\.$/, ''), 'attribute')); if (sourceRange.source()) { const orig = sourceRange.source().findSourceRangeThatGenerated(sourceRange); if (orig) { this.metadata.push(new pytorch.Argument('generated', orig.toString(), 'attribute')); } } } } else if (torch && obj instanceof torch.fx.node.Node) { if (obj.op === 'call_function') { let name = null; const target = obj.target; if (target instanceof torch._ops.OpOverload) { name = target.name(); } else if (target instanceof torch._ops.HigherOrderOperator) { name = `${target.namespace}::${target.name}`; } else if (builtins.isinstance(target, builtins.function)) { name = target.__name__; } else if (typeof target === 'string') { // Handle unresolved operators const match = target.match(/^torch\.ops\.([^.]+)\.(.+)$/); if (!match) { throw new pytorch.Error(`Unsupported target '${target}'.`); } const [, namespace, opname] = match; name = `${namespace}::${opname}`; } else { throw new pytorch.Error(`Unsupported target '${target}'.`); } this.type = { identifier: name, name: name.indexOf('::') === -1 ? name : name.split('::').pop().split('.')[0] }; const schema = obj.target._schema; if (schema && schema.category) { this.type.category = schema.category; } let args = obj.args.map((arg, index) => { if (!schema) { return ['', arg]; } if (Array.isArray(schema.arguments) && index < schema.arguments.length) { return [schema.arguments[index].name, arg]; } if (schema.is_vararg) { return ['', arg]; } throw new pytorch.Error('Unsupported schema argument.'); }); const inputs = new Map((schema ? 
schema.arguments : []).map((arg) => [arg.name, arg])); args = args.concat(Array.from(obj.kwargs)); for (const [name, arg] of args) { let type = inputs.has(name) ? pytorch.Utility.toType(inputs.get(name).real_type) : null; if (arg instanceof torch.fx.node.Node) { let argument = null; if (arg.op === 'get_attr' && arg.users.size === 1) { const subgraph = context.function(arg); if (subgraph) { argument = new pytorch.Argument(name, subgraph, 'function'); } } if (!argument) { const value = context.value(arg); argument = new pytorch.Argument(name, [value]); } this.inputs.push(argument); } else if (Array.isArray(arg) && arg.every((arg) => arg instanceof torch.fx.node.Node || arg === null)) { const list = arg.map((arg) => arg === null ? null : context.value(arg)); const argument = new pytorch.Argument(name, list); this.inputs.push(argument); } else if (Array.isArray(arg)) { const list = arg.map((arg) => arg instanceof torch.fx.node.Node ? context.value(arg) : arg); const argument = new pytorch.Argument(name, list, type || 'attribute'); this.inputs.push(argument); } else if (arg instanceof torch.dtype || arg instanceof torch.device || arg instanceof torch.layout || arg instanceof torch.memory_format) { const argument = new pytorch.Argument(name, arg.toString(), type || 'attribute'); this.inputs.push(argument); } else { const primitive = typeof arg === 'number' || typeof arg === 'boolean' || typeof arg === 'string' || arg === null; type = type === 'tensor' && primitive ? 
null : type; const argument = new pytorch.Argument(name, arg, type || 'attribute'); this.inputs.push(argument); } } let outputs = [obj]; if (obj.users.size > 1) { const users = Array.from(obj.users.keys()); if (users.every((user) => user.op === 'call_function' && user.target.__module__ === 'operator' && user.target.__name__ === 'getitem')) { outputs = new Array(obj.users.size); for (const user of users) { const [, index] = user.args; outputs[index] = user; } } } for (let i = 0; i < outputs.length; i++) { const node = outputs[i]; const value = context.value(node); const name = schema && schema.returns && schema.returns[i] ? schema.returns[i].name || 'output' : 'output'; const argument = new pytorch.Argument(name, [value]); this.outputs.push(argument); } for (const [name, value] of obj.meta) { if (name === 'val' || name === 'torch_fn' || (Array.isArray(value) && value.length === 0) || (value instanceof Map && value.size === 0)) { continue; } if (typeof value === 'string') { const argument = new pytorch.Argument(name, value, 'string'); this.metadata.push(argument); } else if (Array.isArray(value) && value.every((item) => typeof item === 'string')) { const argument = new pytorch.Argument(name, value, 'string[]'); this.metadata.push(argument); } else if (value instanceof Map && value.size > 0) { // const argument = new pytorch.Argument(name, Object.fromEntries(Array.from(value))); // this.metadata.push(argument); } else { // const argument = new pytorch.Argument(name, value); // this.metadata.push(argument); } } } else if (obj.op === 'placeholder') { this.type = { name: obj.op }; { const value = context.value(obj); const argument = new pytorch.Argument('value', [value]); this.inputs.push(argument); } { const node = new torch.fx.node.Node(null, obj.name); node.meta = obj.meta; const value = context.value(node); const argument = new pytorch.Argument('value', [value]); this.outputs.push(argument); } } else if (obj.op === 'get_attr') { this.type = { name: obj.op }; const 
subgraph = context.function(obj); if (subgraph) { this.inputs.push(new pytorch.Argument('name', subgraph, 'function')); } else { this.inputs.push(new pytorch.Argument('name', obj.target, 'string')); } const value = context.value(obj); this.outputs.push(new pytorch.Argument('value', [value])); } else if (obj.op === 'root') { this.type = { name: obj.op }; } else { throw new pytorch.Error(`Unsupported node operation '${obj.op}'.`); } } else { if (torch && obj instanceof torch.ScriptObject) { type = obj._type().qualified_name(); obj = obj._ivalue; } else if (torch && obj instanceof torch.jit._script.RecursiveScriptModule && obj._c && obj._c.qualified_name) { type = obj._c._type(); const target = { _modules: obj._modules, _parameters: obj._parameters, _buffers: obj._buffers, }; for (let i = 0; i < type.numAttributes(); i++) { if (!type.is_parameter(i) && !type.is_buffer(i) && !type.getAttribute(i).is_module()) { const k = type.getAttributeName(i); target[k] = obj.__getattr__(k); } } type = obj._c.qualified_name; obj = target; } if (!type) { if (torch && obj instanceof torch.jit._script.RecursiveScriptModule && obj._c && obj._c.qualified_name) { type = obj._c.qualified_name; } else if (pytorch.Utility.isInstance(obj, 'builtins.function')) { type = `${obj.__module__}.${obj.__name__}`; obj = {}; } else if (obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__) { type = `${obj.__class__.__module__}.${obj.__class__.__name__}`; } else { type = 'builtins.object'; } } if (type instanceof nnapi.Graph) { this.type = type; } else { const key = type.startsWith('__torch__.') ? type.substring(10) : type; const value = metadata.type(key); this.type = value ? 
{ ...value } : { name: type }; this.type.identifier = type; } stack = stack || new Set(); const weights = pytorch.Utility.weights(obj); if (weights) { const type = this.type.name; this.type = new pytorch.Graph(execution, metadata, 'weights', '', weights); this.type.name = type; } else if (obj && pytorch.Utility.isInstance(obj, 'fastai.data.core.DataLoaders')) { // continue } else if (obj && pytorch.Utility.isInstance(obj, '__torch__.torch.classes._nnapi.Compilation')) { // continue } else if (obj && type === 'builtins.bytearray') { const argument = new pytorch.Argument('value', Array.from(obj), 'byte[]'); this.inputs.push(argument); } else if (obj) { const inputs = new Map(Array.isArray(this.type.inputs) ? this.type.inputs.map((input) => [input.name, input]) : []); const list = obj instanceof Map ? Array.from(obj) : Object.entries(obj); for (const [name, value] of list) { if (name === '__class__' || name === '__name__') { continue; } else if (pytorch.Utility.isInstance(value, 'collections.OrderedDict') && value instanceof Map && value.size === 0) { continue; } else if (pytorch.Utility.isInstance(value, 'builtins.set') && value instanceof Set && value.size === 0) { continue; } else if (pytorch.Utility.isInstance(value, 'builtins.list') && Array.isArray(value) && value.length === 0) { continue; } else if (pytorch.Utility.isInstance(value, 'torch.Size') && Array.isArray(value) && value.length === 0) { continue; } let parameters = null; if ((name === '_parameters' || name === '_buffers') && value instanceof Map) { parameters = value; } else if (pytorch.Utility.isTensor(value) || (Array.isArray(value) && value.every((tensor) => pytorch.Utility.isTensor(tensor)))) { parameters = new Map([[name, value]]); } if (parameters) { for (const [name, value] of parameters) { const list = Array.isArray(value) ? value.map((item) => pytorch.Utility.toTensor(item)) : [pytorch.Utility.toTensor(value)]; const visible = inputs.has(name) ? 
inputs.get(name).visible || true : true; const args = list.filter((value) => value !== null && !value.__origin__).map((value) => { const name = value && value.name ? value.name : ''; const identifier = list.length === 1 && value && value.__name__ ? value.__name__ : name; let tensor = null; if (initializers && initializers.has(value)) { tensor = initializers.get(value); } else { value = value.__source__ ? value.__source__ : value; tensor = value ? new pytorch.Tensor(context, identifier, value) : null; } return new pytorch.Value(identifier, null, null, tensor); }); const argument = new pytorch.Argument(name, args, null, visible); this.inputs.push(argument); if (value && value.__variable__) { const argument = new pytorch.Argument(name, [context.values.map(value.__variable__)]); this.outputs.push(argument); } } continue; } if (pytorch.Utility.isTensor(value)) { const tensor = new pytorch.Tensor(context, '', value); const argument = new pytorch.Argument(name, tensor, 'tensor'); this.inputs.push(argument); } else if (value && pytorch.Utility.isInstance(value, 'torch.dtype')) { const node = new pytorch.Node(execution, metadata, null, value.toString(), {}, null, context); const argument = new pytorch.Argument(name, node, 'object'); this.inputs.push(argument); } else if (Array.isArray(value) && value.some((value) => pytorch.Utility.isTensor(value)) && value.every((value) => pytorch.Utility.isTensor(value) || value === null)) { const tensors = value.map((value) => value === null ? 
value : new pytorch.Tensor(context, '', value)); const argument = new pytorch.Argument(name, tensors, 'tensor[]'); this.inputs.push(argument); } else if (pytorch.Utility.isInstance(value, 'numpy.ndarray') || pytorch.Utility.isInstance(value, 'numpy.matrix')) { const tensor = new numpy.Tensor(value); const argument = new pytorch.Argument(name, tensor, 'tensor'); this.inputs.push(argument); } else if (Array.isArray(value) && value.every((value) => typeof value === 'string')) { const argument = new pytorch.Argument(name, value, 'string[]'); this.inputs.push(argument); } else if (Array.isArray(value) && value.every((value) => typeof value === 'number')) { const argument = new pytorch.Argument(name, value, 'attribute'); this.inputs.push(argument); } else if (name === '_modules' && pytorch.Utility.isInstance(value, 'collections.OrderedDict') && value instanceof Map && Array.from(value).every(([, value]) => value === null || value.__class__)) { const list = Array.from(value).filter(([, value]) => !stack.has(value)).map(([name, obj]) => { stack.add(value); const type = obj === null ? 'builtins.NoneType' : `${obj.__class__.__module__}.${obj.__class__.__name__}`; const node = new pytorch.Node(execution, metadata, this.name ? 
`${this.name}.${name}` : name, type, obj, initializers, context, stack); stack.delete(value); return node; }); const argument = new pytorch.Argument(name, list, 'object[]'); this.inputs.push(argument); } else if (value && Array.isArray(value) && value.length > 0 && value.every((obj) => Array.isArray(obj) && obj.every((item) => typeof item === 'string' || typeof item === 'number'))) { const argument = new pytorch.Argument(name, value, 'attribute'); this.inputs.push(argument); } else if (value && Array.isArray(value) && value.length > 0 && value.every((obj) => obj && (obj.__class__ || obj === Object(obj)))) { const list = value.filter((value) => !stack.has(value)); const nodes = list.map((value) => { stack.add(value); const node = new pytorch.Node(execution, metadata, null, null, value, initializers, context, stack); stack.delete(value); return node; }); const argument = new pytorch.Argument(name, nodes, 'object[]'); this.inputs.push(argument); } else if (value && (value.__class__ || typeof value === 'object') && !stack.has(value)) { stack.add(value); const node = new pytorch.Node(execution, metadata, null, null, value, initializers, context, stack); stack.delete(value); const visible = name !== '_metadata' || !pytorch.Utility.isMetadataObject(value); const argument = new pytorch.Argument(name, node, 'object', visible); this.inputs.push(argument); } else { let schema = metadata.attribute(this.type.identifier, name); schema = name === 'training' ? { type: 'boolean', visible: false } : schema; let visible = true; let obj = value; const type = schema && schema.type ? 
schema.type : 'attribute'; if (schema) { if (schema.visible === false) { visible = false; } else if (schema.default !== undefined) { if (Array.isArray(obj)) { if (Array.isArray(schema.default)) { visible = obj.length !== schema.default || !obj.every((item, index) => item === schema.default[index]); } else { visible = !obj.every((item) => item === schema.default); } } else { visible = obj !== schema.default; } } } if (Array.isArray(obj) && obj.length > 0 && obj.every((obj) => obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__module__.startsWith('torch.nn'))) { obj = '?'; } const argument = new pytorch.Argument(name, obj, type, visible); this.inputs.push(argument); } } } } } }; pytorch.Tensor = class { constructor(context, name, tensor) { this.name = name || ''; this.attributes = []; tensor = tensor.data ? tensor.data : tensor; const storage = tensor.storage(); this.type = context.type(tensor); const layout = this.type.layout; const size = this.type.shape.dimensions || []; if (layout) { this.indices = new pytorch.Tensor(context, '', tensor.indices); this._values = new pytorch.Tensor(context, '', tensor.values); } else { this.encoding = '<'; this.indices = null; this.stride = tensor.stride(); const stride = this.stride; const offset = tensor.storage_offset(); if (storage) { this._data = storage.data; let length = 0; if (!Array.isArray(stride)) { length = storage.size(); } else if (size.every((v) => v !== 0)) { length = size.reduce((a, v, i) => a + stride[i] * (v - 1), 1); } if (storage && typeof storage.size === 'function') { if (offset !== 0 || length !== storage.size()) { const itemsize = storage.dtype.itemsize(); this._offset = itemsize * offset; this._length = itemsize * length; } } } } const type = tensor.__class__ || {}; if (type.tensor_attribute_names) { for (const name of type.tensor_attribute_names) { let value = tensor[name]; if (value !== undefined) { if (value && typeof value.__reduce__ === 'function') { value = value.__reduce__(); } 
const attribute = new pytorch.Argument(name, value, 'attribute'); this.attributes.push(attribute); } } } if (type.tensor_data_names) { for (const name of type.tensor_data_names) { const value = tensor[name]; if (value !== undefined && pytorch.Utility.isTensor(value)) { const attribute = new pytorch.Argument(name, new pytorch.Tensor(context, name, value), 'tensor'); this.attributes.push(attribute); } } } } get values() { const type = this.type.layout; if (type && type.startsWith('sparse.')) { return this._values; } if (this._data instanceof Uint8Array) { return this._data; } if (this._data && this._offset !== undefined) { const stream = this._data; const position = stream.position; stream.seek(this._offset); const values = stream.peek(this._length); stream.seek(position); return values; } if (this._data) { return this._data.peek(); } return null; } }; pytorch.TensorType = class { constructor(dataType, shape, layout) { this.dataType = dataType; this.shape = shape; this.layout = layout; } toString() { return this.dataType + this.shape.toString(); } }; pytorch.TensorShape = class { constructor(dimensions = []) { this.dimensions = dimensions; } toString() { if (this.dimensions && this.dimensions.length > 0) { return `[${this.dimensions.map((dimension) => dimension.toString()).join(',')}]`; } return ''; } }; pytorch.Context = class { constructor(execution, metadata) { this.execution = execution; this.torch = execution ? 
execution.__import__('torch') : null; this.metadata = metadata; this.values = new Map(); this.modules = new Map(); } type(tensor) { let dataType = tensor.dtype.__reduce__(); switch (dataType) { case 'float8_e5m2': dataType = 'float8e5m2'; break; case 'float8_e5m2fnuz': dataType = 'float8e5m2fnuz'; break; case 'float8_e4m3fn': dataType = 'float8e4m3fn'; break; case 'float8_e4m3fnuz': dataType = 'float8e4m3fnuz'; break; case 'float8_e8m0fnu': dataType = 'float8e8m0fnu'; break; case 'float4_e2m1fn_x2': dataType = 'float4e2m1fnx2'; break; default: break; } const size = tensor.size ? tensor.size() : tensor.shape; const shape = new pytorch.TensorShape(size || []); const layout = tensor.layout ? tensor.layout.__str__() : null; if (layout && layout.startsWith('torch.sparse_')) { return new pytorch.TensorType(dataType, shape, layout.split('.').pop().replace('_', '.')); } return new pytorch.TensorType(dataType, shape); } value(obj) { const torch = this.torch; if (obj instanceof torch.fx.node.Node) { if (!this.values.has(obj)) { let type = null; const val = obj.meta ? 
obj.meta.get('val') : null; if (val && val.dtype) { type = this.type(val); } const value = new pytorch.Value(obj.name, type); this.values.set(obj, value); } return this.values.get(obj); } return null; } function(obj) { const torch = this.torch; if (obj instanceof torch.fx.node.Node) { let subgraph = this.modules.get(obj); if (subgraph) { if (subgraph instanceof pytorch.Graph === false) { subgraph = new pytorch.Graph(this.execution, this.metadata, 'function', obj.target, subgraph); this.modules.set(obj, subgraph); } return subgraph; } } return null; } graph(target, module, inputs) { const graph = module.graph; if (module.named_modules) { const modules = module.named_modules(); for (const obj of graph.nodes) { if (obj.op === 'get_attr') { const submodule = modules.get(obj.target); if (submodule && submodule.graph) { this.modules.set(obj, submodule); } } } } let controlDependency = null; for (const obj of graph.nodes) { if (obj.op === 'placeholder') { if (inputs) { const value = this.value(obj); const argument = new pytorch.Argument(obj.name, [value]); target.inputs.push(argument); } continue; } if (obj.op === 'call_function') { if (obj.target.__module__ === 'operator' && obj.target.__name__ === 'getitem') { continue; } } if (obj.op === 'get_attr') { if (this.modules.has(obj) && obj.users.size === 1) { continue; } } if (obj.op === 'output') { for (const output of obj.args) { if (output === null || output === undefined) { continue; } if (output.op === 'call_function' && output.target.__module__ === 'operator' && output.target.__name__ === 'getitem') { continue; } const value = this.value(output); const argument = new pytorch.Argument(output.name, [value]); target.outputs.push(argument); } continue; } const node = new pytorch.Node(this.execution, this.metadata, obj.name, null, obj, null, this); target.nodes.push(node); if (controlDependency) { node.controlDependencies = node.controlDependencies || []; node.controlDependencies.push(controlDependency); controlDependency = 
null; } if (obj.op === 'call_function' && obj.users.size === 0) { controlDependency = node.outputs[0].value[0]; } } } }; pytorch.Reader = class { static async open(context) { const types = [ pytorch.Reader.Zip, pytorch.Reader.Pickle, pytorch.Reader.Tar, pytorch.Reader.data_pkl, pytorch.Reader.torch_utils, pytorch.Reader.Mobile, pytorch.Reader.ModelJson, pytorch.Reader.IR, pytorch.Reader.Index, pytorch.Reader.ExportedProgram ]; for (const type of types) { // eslint-disable-next-line no-await-in-loop const reader = await type.open(context); if (reader) { return reader; } } return null; } constructor() { this._events = []; } async read() { } on(event, callback) { this._events.push([event, callback]); } }; pytorch.Reader.Tar = class extends pytorch.Reader { static async open(context) { const entries = await context.peek('tar'); if (entries instanceof Map && entries.has('pickle')) { return new pytorch.Reader.Tar(entries); } return null; } constructor(entries) { super(); this.type = 'pytorch.tar'; this.entries = entries; } async read() { this.format = 'PyTorch v0.1.1'; const execution = new python.Execution(); for (const event of this._events) { execution.on(event[0], event[1]); } const torch = execution.__import__('torch'); this.module = torch.load(this.entries); delete this.entries; } }; pytorch.Reader.Pickle = class extends pytorch.Reader { static async open(context) { const stream = context.stream; const signature = [0x80, undefined, 0x8a, 0x0a, 0x6c, 0xfc, 0x9c, 0x46, 0xf9, 0x20, 0x6a, 0xa8, 0x50, 0x19]; if (stream && signature.length <= stream.length && stream.peek(signature.length).every((value, index) => signature[index] === undefined || signature[index] === value)) { return new pytorch.Reader.Pickle(stream); } return null; } constructor(stream) { super(); this.type = 'pytorch.pickle'; this.stream = stream; } async read() { this.format = 'PyTorch v0.1.10'; const data = this.stream.length < 0x7ffff000 ? 
this.stream.peek() : this.stream; delete this.stream; const execution = new python.Execution(); for (const event of this._events) { execution.on(event[0], event[1]); } const torch = execution.__import__('torch'); this.module = torch.load(data); } }; pytorch.Reader.data_pkl = class extends pytorch.Reader { static async open(context) { const obj = await context.peek('pkl'); if (obj) { if (obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__) { const name = `${obj.__class__.__module__}.${obj.__class__.__name__}`; if (name.startsWith('__torch__.')) { return new pytorch.Reader.data_pkl('', obj); } } if (pytorch.Utility.isTensor(obj)) { return new pytorch.Reader.data_pkl('tensor', obj); } if (Array.isArray(obj) && obj.length > 0 && obj.every((tensor) => pytorch.Utility.isTensor(tensor))) { return new pytorch.Reader.data_pkl('tensor', obj); } if (obj instanceof Map) { const entries = Array.from(obj).filter(([, value]) => pytorch.Utility.isTensor(value)); if (entries.length > 0) { return new pytorch.Reader.data_pkl('tensor', obj); } } else if (!Array.isArray(obj)) { const entries = Object.entries(obj).filter(([, value]) => pytorch.Utility.isTensor(value)); if (entries.length > 0) { return new pytorch.Reader.data_pkl('tensor', obj); } } for (const key of ['', 'model', 'net']) { const module = key === '' ? 
obj : obj[key]; if (module && module._modules && pytorch.Utility.isInstance(module._modules, 'collections.OrderedDict')) { return new pytorch.Reader.data_pkl('module', module); } } } return null; } constructor(type, module) { super(); this.type = 'pytorch.data.pkl'; this.format = 'PyTorch Pickle'; this.module = module; } async read() { } }; pytorch.Reader.torch_utils = class extends pytorch.Reader { static async open(context) { const stream = context.stream; if (stream && stream.length > 1) { const buffer = stream.peek(Math.min(1024, stream.length)); if (buffer[0] === 0x80) { const content = String.fromCharCode.apply(null, buffer); if (content.indexOf('torch_utils') !== -1) { const obj = await context.peek('pkl'); if (obj && Object.entries(obj).some(([, value]) => pytorch.Utility.isInstance(value, 'torch.nn.modules.module.Module'))) { return new pytorch.Reader.torch_utils(obj); } } } } return null; } constructor(obj) { super(); this.type = 'pytorch.torch_utils'; this.obj = obj; } async read() { this.format = 'PyTorch torch_utils'; this.module = this.obj; delete this.obj; } }; pytorch.Reader.Mobile = class extends pytorch.Reader { static async open(context) { const reader = await context.peek('flatbuffers.binary'); if (reader && reader.identifier === 'PTMF') { return new pytorch.Reader.Mobile(context); } return null; } constructor(context) { super(); this.type = 'pytorch.mobile'; this.context = context; } async read(metadata) { const execution = new pytorch.Execution(null, metadata); for (const event of this._events) { execution.on(event[0], event[1]); } const stream = this.context.stream; const torch = execution.__import__('torch'); torch.mobile = await this.context.require('./pytorch-schema'); torch.mobile = torch.mobile.torch.jit.mobile; this.module = torch.jit.jit_module_from_flatbuffer(stream); const version = this.module._c._bytecode_version.toString(); this.format = pytorch.Utility.format('PyTorch Mobile', version); delete this.context; } }; 
pytorch.Reader.Zip = class extends pytorch.Reader { static async open(context) { const entries = await context.peek('zip'); if (entries instanceof Map && entries.size > 0) { let prefix = 0; const paths = Array.from(entries.keys()).map((path) => path.replace(/\\/g, '/').split('/').reverse()); for (let set = new Set(); set && paths.length > 0;) { set = new Set(paths.map((path) => path.length > 1 ? path.pop() : null)); set = set.size > 1 || set.keys().next().value === null ? null : set; prefix += set ? set.keys().next().value.length + 1 : 0; } const records = new Map(Array.from(entries).map(([name, value]) => [name.substring(prefix), value])); if (records.has('model.json')) { return null; } if (records.has('data.pkl')) { return new pytorch.Reader.Zip(entries); } if (records.has('.data/version') && !records.has('archive_format')) { return new pytorch.Reader.Package(entries); } } return null; } constructor(entries) { super(); this.type = 'pytorch.zip'; // https://github.com/pytorch/pytorch/blob/main/torch/csrc/jit/OVERVIEW.md // https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/docs/serialization.md this._entries = entries; } async read(metadata) { this.execution = new pytorch.Execution(null, metadata); for (const event of this._events) { this.execution.on(event[0], event[1]); } const torch = this.execution.__import__('torch'); const reader = new torch.PyTorchFileReader(this._entries); let torchscript = reader.has_record('constants.pkl'); const version = reader.version(); if (torchscript) { metadata.register(this.execution); this.module = torch.jit.load(reader); torchscript = this.module._c._has_method('forward'); if (torchscript) { // console.log(this.module.graph.toString()); torch._C._jit_pass_inline(this.module.graph); // console.log(this.module.graph.toString()); } } else { const records = reader.get_all_records().map((key) => [key, reader.get_record(key)]); const entries = new Map(records); this.module = torch.load(entries); } const name = torchscript 
? 'TorchScript' : 'PyTorch'; this.format = pytorch.Utility.format(name, version); delete this._model; delete this._entries; } }; pytorch.Reader.ModelJson = class extends pytorch.Reader { static async open(context) { const identifier = context.identifier; if (identifier === 'model.json') { const model = await context.peek('json'); if (model && model.mainModule) { const entries = new Map(); entries.set('model.json', context.stream); return new pytorch.Reader.ModelJson(context, entries, model); } } return null; } constructor(context, entries, model) { super(); this.type = 'pytorch.model.json'; this._context = context; this._entries = entries; this._model = model; } async read(metadata) { pytorch.proto = await this._context.require('./pytorch-proto'); const keys = [ 'attributes.pkl', 'version', ...this._model.tensors.filter((tensor) => tensor && tensor.data && tensor.data.key).map((tensor) => tensor.data.key) ]; const walk = (module) => { if (module.torchscriptArena && module.torchscriptArena.key) { keys.push(module.torchscriptArena.key); } for (const submodule of module.submodules || []) { walk(submodule); } }; walk(this._model.mainModule); const values = await Promise.all(keys.map((name) => this._context.fetch(name).then((context) => context.stream).catch(() => null))); for (let i = 0; i < keys.length; i++) { if (values[i]) { this._entries.set(keys[i], values[i]); } } this.execution = new pytorch.Execution(null, metadata); this.execution.proto = pytorch.proto; for (const event of this._events) { this.execution.on(event[0], event[1]); } const torch = this.execution.__import__('torch'); const reader = new torch.PyTorchFileReader(this._entries); if (this._model && this._model.producerName) { this.producer = this._model.producerName + (this._model.producerVersion ? ` v${this._model.producerVersion}` : ''); } this.format = reader.has_record('attributes.pkl') ? 
'TorchScript v1.1' : 'TorchScript v1.0'; metadata.register(this.execution); this.module = torch.jit.load(reader); if (this.module._c._has_method('forward')) { // console.log(this.module.graph.toString()); torch._C._jit_pass_inline(this.module.graph); // console.log(this.module.graph.toString()); } delete this._context; delete this._model; delete this._entries; } }; pytorch.Reader.IR = class extends pytorch.Reader { static async open(context) { const reader = await context.read('text', 0x100); if (reader && reader.length > 0) { const line = reader.read('\n'); if (line.startsWith('graph(')) { return new pytorch.Reader.IR(context); } } return null; } constructor(context) { super(); this.type = 'pytorch.ir'; this.context = context; } async read(metadata) { this.format = 'TorchScript IR'; this.execution = new pytorch.Execution(null, metadata); for (const event of this._events) { this.execution.on(event[0], event[1]); } // this.execution.graph; // context reader = await context.read('text', 0x100); throw new pytorch.Error('TorchScript IR parser not implemented.'); } }; pytorch.Reader.Index = class extends pytorch.Reader { static async open(context) { const obj = await context.peek('json'); if (obj && obj.weight_map) { const entries = Object.entries(obj.weight_map); if (entries.length > 0 && entries.every(([, value]) => typeof value === 'string' && value.endsWith('.bin'))) { return new pytorch.Reader.Index(context, entries); } } return null; } constructor(context, entries) { super(); this.type = 'pytorch.index'; this.context = context; this._entries = entries; } async read(metadata) { this.format = 'PyTorch'; const weight_map = new Map(this._entries); const keys = new Set(weight_map.keys()); const files = Array.from(new Set(weight_map.values())); const contexts = await Promise.all(files.map((name) => this.context.fetch(name))); this.execution = new pytorch.Execution(null, metadata); for (const event of this._events) { this.execution.on(event[0], event[1]); } const torch = 
this.execution.__import__('torch'); const archives = await Promise.all(contexts.map((context) => context.peek('zip'))); const formats = new Set(archives.map((entries) => { const reader = new torch.PyTorchFileReader(entries); const version = reader.version(); return pytorch.Utility.format('PyTorch', version); })); if (formats.size === 1) { this.format = formats.values().next().value; } const shards = archives.map((entries) => { return torch.load(entries); }); const entries = new Map(); for (const shard of shards) { for (const [key, value] of Array.from(shard)) { if (keys.has(key)) { entries.set(key, value); } } } this.module = entries; delete this.context; delete this._entries; } }; pytorch.Reader.ExportedProgram = class extends pytorch.Reader { static async open(context) { const program = await context.peek('json'); if (program && program.schema_version && program.graph_module) { return new pytorch.Reader.ExportedProgram(context, program); } if (context.identifier === 'archive_format' && context.stream && context.stream.length < 10) { const buffer = context.stream.peek(); const archive_format = String.fromCharCode.apply(null, buffer); if (archive_format === 'pt2') { return new pytorch.Reader.ExportedProgram(context, null, context); } } return null; } constructor(context, exported_program, archive_format) { super(); this.type = 'pytorch.export'; this.context = context; this.archive_format = archive_format; this.exported_program = exported_program; } async read(metadata) { this.format = 'PyTorch Export'; const f = new Map(); const exported_programs = new Map(); if (this.archive_format) { for (const name of this.context.container.entries.keys()) { const match = name.match(/^models\/([^/]+)\.json$/); if (match) { const [, model_name] = match; /* eslint-disable no-await-in-loop */ const model = await this.context.fetch(`models/${model_name}.json`); const exported_program = await model.read('json'); exported_programs.set(model_name, exported_program); 
f.set(`models/${model_name}.json`, exported_program); const sample_inputs = await this._fetch(`data/sample_inputs/${model_name}.pt`, 'zip'); f.set(`data/sample_inputs/${model_name}.pt`, sample_inputs); const weights_config = await this._fetch(`data/weights/${model_name}_weights_config.json`, 'json'); if (weights_config) { f.set(`data/weights/${model_name}_weights_config.json`, weights_config); for (const payload_meta of Object.values(weights_config.config)) { const type = payload_meta.use_pickle ? 'zip' : 'binary'; const weight_data = await this._fetch(`data/weights/${payload_meta.path_name}`, type); if (weight_data) { f.set(`data/weights/${payload_meta.path_name}`, weight_data); } } } else { const weights = await this._fetch(`data/weights/${model_name}.pt`, 'zip'); f.set(`data/weights/${model_name}.pt`, weights); } const constants_config = await this._fetch(`data/constants/${model_name}_constants_config.json`, 'json'); if (constants_config) { f.set(`data/constants/${model_name}_constants_config.json`, constants_config); for (const payload_meta of Object.values(constants_config.config)) { // eslint-enable no-await-in-loop const type = payload_meta.use_pickle ? 
'zip' : 'binary'; const constant_data = await this._fetch(`data/constants/${payload_meta.path_name}`, type); if (constant_data) { f.set(`data/constants/${payload_meta.path_name}`, constant_data); } } } else { const constants = await this._fetch(`data/constants/${model_name}.pt`); f.set(`data/constants/${model_name}.pt`, constants); } /* eslint-enable no-await-in-loop */ } } const byteorder = await this._fetch('byteorder', 'text') || 'little'; f.set('byteorder', byteorder); } else { this.version = await this._fetch('version', 'text') || ''; this.version = this.version.split('\n').shift().trim(); const weights = await this._fetch('serialized_state_dict.pt', 'zip') || await this._fetch('serialized_state_dict.json', 'zip'); const constants = await this._fetch('serialized_constants.pt', 'zip') || await this._fetch('serialized_constants.json', 'zip'); const sample_inputs = await this._fetch('serialized_example_inputs.pt', 'zip'); f.set('models/model.json', this.exported_program); f.set('data/weights/model.pt', weights); f.set('data/constants/model.pt', constants); f.set('data/sample_inputs/model.pt', sample_inputs); exported_programs.set('', this.exported_program); } if (!this.version) { const versions = new Set(); for (const exported_program of exported_programs.values()) { const schema_version = exported_program.schema_version; if (schema_version && schema_version.major && schema_version.minor) { versions.add(`${schema_version.major}.${schema_version.minor}`); } } if (versions.size === 1) { this.version = versions.values().next().value; } } this.format = this.version ? 
`${this.format} v${this.version}` : this.format; this.execution = new python.Execution(); for (const event of this._events) { this.execution.on(event[0], event[1]); } metadata.register(this.execution); const torch = this.execution.__import__('torch'); for (const exported_program of exported_programs.values()) { if (exported_program.graph_module.graph.constants) { // eslint-disable-next-line no-await-in-loop const zip = await import('./zip.js'); const constants = exported_program.graph_module.graph.constants; for (const key of Object.keys(constants)) { const value = constants[key]; const str = atob(value); const buffer = new Uint8Array(str.length); for (let i = 0; i < str.length; i++) { buffer[i] = str.charCodeAt(i); } const archive = zip.Archive.open(buffer); constants[key] = archive.entries; } } } delete this.exported_program; delete this.context; const pt2_contents = torch.export.pt2_archive._package.load_pt2(f); this.modules = pt2_contents.exported_programs; } async _fetch(name, type) { try { const context = await this.context.fetch(name); if (context) { switch (type) { case 'zip': return await context.peek('zip'); case 'json': return await context.read('json'); case 'text': { const reader = await context.read('text'); if (reader) { return reader.read(); } break; } case 'binary': { if (context && context.stream) { return context.stream.peek(); } break; } default: { throw new pytorch.Error(`Unsupported context type '${type}.`); } } } } catch { // continue regardless of error } return null; } }; pytorch.Execution = class extends python.Execution { constructor(sources, metadata) { super(sources); this._metadata = metadata; // eslint-disable-next-line consistent-this const execution = this; const torch = this.torch; this.registerFunction('torch.jit.jit_module_from_flatbuffer', (f) => { const cu = new torch.jit.CompilationUnit(); cu.execution = execution; const stream = f; const reader = flatbuffers.BinaryReader.open(stream); const module = 
torch.mobile.serialization.Module.create(reader); const loader = new torch._C.FlatBuffersLoader(cu); const cpp_module = loader.parseModule(module); // parse_and_initialize_jit_module // const mobilem = parse_and_initialize_mobile_module_for_jit(data, jit_files, jit_constants); // const m = jitModuleFromSourceAndConstants(mobilem._ivalue(), jit_files, jit_constants, mobilem.bytecode_version()); // throw new pytorch.Error('torch.jit.mobile.serialization.Module not supported.'); return torch.jit._script.wrap_cpp_module(cpp_module); }); this.registerType('__torch__.torch.classes._nnapi.Compilation', class { constructor() { this.__hide__ = true; } __init__() { } init(serialized_model_tensor, parameter_buffers) { this.serialized_model_tensor = serialized_model_tensor; this.parameter_buffers = parameter_buffers; const buffers = parameter_buffers.map((buffer) => buffer.__source__.storage()); /* let buffers = []; if (!pytorch.Utility.isInstance(parameter_buffers, 'torch.Value')) { buffers = parameter_buffers.map((buffer) => buffer.__source__.storage()); } */ const serialized_model = serialized_model_tensor.storage().data; this.serialized_model = new nnapi.SerializedModel(serialized_model, buffers); } run(inputs, outputs) { execution.variable(this.serialized_model_tensor); this.serialized_model_tensor.__count__ = (this.serialized_model_tensor.__count__ || 0) + 1; const type = new nnapi.Graph(this.serialized_model); const node = execution.graph.create(type, 0); execution.graph.insertNode(node); for (const tensor of inputs) { const value = execution.variable(tensor); node.addInput(value); } for (const tensor of outputs) { execution.variable(tensor, node); } } }); this.registerType('__torch__.torch.classes.quantized.Conv2dPackedParamsBase', class { __setstate__(state) { if (state[0] !== '2') { throw new pytorch.Error(`Unsupported pack version '${state[0]}'.`); } const [/* pack_version */, tensors, opt_tensors] = state; const packed_config = tensors[0].tolist(); this.weight = 
tensors[1]; this.bias = opt_tensors[0]; this.stride = [packed_config[1], packed_config[2]]; this.padding = [packed_config[3], packed_config[4]]; this.dilation = [packed_config[5], packed_config[6]]; this.output_padding = [packed_config[7], packed_config[8]]; this.groups = packed_config[9]; } }); this.registerType('__torch__.torch.classes.quantized.Conv3dPackedParamsBase', class { __setstate__(state) { if (state[0] !== '2') { throw new pytorch.Error(`Unsupported pack version '${state[0]}'.`); } const [/* pack_version */, tensors, opt_tensors] = state; const packed_config = tensors[0].tolist(); this.weight = tensors[1]; this.bias = opt_tensors[0]; this.stride = [packed_config[1], packed_config[2]]; this.padding = [packed_config[3], packed_config[4]]; this.dilation = [packed_config[5], packed_config[6]]; this.output_padding = [packed_config[7], packed_config[8]]; this.groups = packed_config[9]; } }); this.registerType('__torch__.torch.classes.quantized.LinearPackedParamsBase', class { __setstate__(state) { [this.weight, this.bias] = state; } }); this.registerType('__torch__.torch.classes.quantized.EmbeddingPackedParamsBase', class { __setstate__(state) { [this.version, this.tensors, this.doubles, this.longs] = state; } }); this.registerType('__torch__.torch.classes.rnn.CellParamsBase', class { __setstate__(state) { [this.type, this.tensors, this.doubles, this.longs, this.packed_params] = state; } }); this.registerType('__torch__.torch.classes.xnnpack.Conv2dOpContext', class { __setstate__(state) { [this.weight, this.bias, this.stride, this.padding, this.dilation, this.groups, this.output_min, this.output_max] = state; } }); this.registerType('__torch__.torch.classes.xnnpack.LinearOpContext', class { __setstate__(state) { [this.weight, this.bias, this.output_min, this.output_max] = state; } }); this.registerType('__torch__.torch.classes.xnnpack.TransposeConv2dOpContext', class { __setstate__(state) { [this.weight, this.bias, this.stride, this.padding, 
this.output_padding, this.dilation, this.groups, this.output_min, this.output_max] = state; } }); this.registerType('__torch__.torch.classes.tensorrt.Engine', class { __setstate__(state) { [this.abi_target, this.name, this.device, this.engine, this.input_binding_names, this.output_binding_names, this.hw_compatible, this.serialized_metadata, this.target_platform] = state; } }); const custom_classes = [ { name: '__torch__.torch.classes._nnapi.Compilation', methods: [ '__init__(__torch__.torch.classes._nnapi.Compilation self) -> NoneType', 'init(__torch__.torch.classes._nnapi.Compilation self, Tensor serialized_model_tensor, Tensor[] parameter_buffers) -> NoneType', 'init2(__torch__.torch.classes._nnapi.Compilation self, Tensor serialized_model_tensor, Tensor[] parameter_buffers, int compilation_preference, bool relax_f32_to_f16) -> NoneType', 'run(__torch__.torch.classes._nnapi.Compilation self, Tensor[] inputs, Tensor[] outputs) -> NoneType' ] }, { name: '__torch__.torch.classes.quantized.Conv2dPackedParamsBase', attributes: 'Tensor weight, Tensor bias, int[] stride, int[] padding, int[] dilation, int groups', methods: ['unpack(__torch__.torch.classes.quantized.Conv2dPackedParamsBase self) -> ((Tensor, Tensor?))'] }, { name: '__torch__.torch.classes.quantized.Conv3dPackedParamsBase', attributes: 'Tensor weight, Tensor bias, int[] stride, int[] padding, int[] dilation, int groups', methods: ['unpack(__torch__.torch.classes.quantized.Conv3dPackedParamsBase self) -> ((Tensor, Tensor?))'] }, { name: '__torch__.torch.classes.quantized.LinearPackedParamsBase', attributes: 'Tensor weight, Tensor? 
bias' }, { name: '__torch__.torch.classes.quantized.EmbeddingPackedParamsBase', attributes: 'int version, Tensor[] tensors, float[] doubles, int[] longs', methods: [] }, { name: '__torch__.torch.classes.rnn.CellParamsBase', attributes: 'str type, Tensor[] tensors, float[] doubles, int[] longs, __torch__.torch.classes.quantized.LinearPackedParamsBase[] packed_params' }, { name: '__torch__.torch.classes.xnnpack.Conv2dOpContext', attributes: 'Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, int[] output_min, int[] output_max' }, { name: '__torch__.torch.classes.xnnpack.LinearOpContext', attributes: 'Tensor weight, Tensor bias, int[] output_min, int[] output_max' }, { name: '__torch__.torch.classes.xnnpack.TransposeConv2dOpContext', attributes: 'Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int[] dilation, int groups, int[] output_min, int[] output_max' }, { name: '__torch__.torch.classes.tensorrt.Engine' } ]; for (const known_type of custom_classes) { const prefix = new torch._C.QualifiedName(known_type.name); const type = torch.ClassType.create(known_type.name, this._compilation_unit, false); for (const known_method of known_type.methods || []) { const schema = new torch.FunctionSchema(known_method); const name = new torch._C.QualifiedName(prefix, schema.name); const fn = new torch._C.BuiltinOpFunction(name, schema); type.addMethod(fn); } if (known_type.attributes) { const schema = new torch.FunctionSchema(`(${known_type.attributes}) -> ()`); for (const arg of schema.arguments) { type.addAttribute(arg.name, arg.real_type); } } torch._C.registerCustomClass(type); } } call(target, name, args, keywords, context) { const ast = this.ast; const torch = this.torch; if (target instanceof ast.Name && target.id === 'torch') { const fn = torch.ops.aten.__getattr__(name); if (fn) { const evalArgs = args.map((arg) => this.expression(arg, context)); return fn.__call__(...evalArgs); } } if (target instanceof 
ast.Attribute && target.value instanceof ast.Name && target.value.id === 'ops') { const module = torch.ops[target.attr]; if (!module) { throw new pytorch.Error(`Unknown torch.ops module '${target.attr}'.`); } const fn = module.__getattr__(name); if (fn) { const evalArgs = args.map((arg) => this.expression(arg, context)); return fn.__call__(...evalArgs); } } return super.call(target, name, args, keywords, context); } invoke(target, args) { if (target && Array.isArray(target.__bases__) && target.__bases__.length > 0 && target.__bases__[0] === this.enum.Enum) { const instance = new target(); instance.value = args; return instance; } return super.invoke(target, args); } base(expr, context) { const ast = this.ast; if (expr instanceof ast.Name) { switch (expr.id) { case 'Enum': return this.enum.Enum; default: break; } } return this.expression(expr, context); } }; pytorch.Reader.Package = class extends pytorch.Reader { constructor(entries) { super(); this.type = 'pytorch.package'; this.entries = entries; } async read(metadata) { this.execution = new pytorch.Execution(null, metadata); for (const event of this._events) { this.execution.on(event[0], event[1]); } const torch = this.execution.__import__('torch'); const reader = new torch.PyTorchFileReader(this.entries); const version = reader.version(); this.format = pytorch.Utility.format('PyTorch Package', version); this.modules = new Map(); const records = reader.get_all_records().filter((name) => { if (!name.startsWith('.data/') && !name.endsWith('.py')) { const stream = reader.get_record(name); if (stream && stream.length > 2) { const signature = stream.peek(2); if (signature[0] === 0x80 && signature[1] < 7) { return true; } } } return false; }); const entries = records.map((name) => { const parts = name.split('/'); const resource = parts.pop(); const module = parts.join('.'); return [module, resource]; }); if (entries.length > 0) { for (const name of reader.get_all_records()) { if (!name.startsWith('.data/') && 
name.endsWith('.py')) { const stream = reader.get_record(name); const buffer = stream.peek(); this.execution.add(name, buffer); } } metadata.register(this.execution); const importer = new torch.package.PackageImporter(reader); for (const entry of entries) { const module = importer.load_pickle(entry[0], entry[1]); const key = `${entry[0].replace(/\./, '/')}/${entry[1]}`; this.modules.set(key, module); } } delete this.entries; } }; pytorch.MemoryFormat = { Contiguous: 0, Preserve: 1, ChannelsLast: 2, ChannelsLast3d: 3 }; pytorch.Layout = { Strided: 0, Sparse: 1, Mkldnn: 2 }; pytorch.Utility = class { static isTensor(obj) { const name = obj && obj.__class__ ? obj.__class__.__module__ : null; switch (name) { case 'torch': case 'torch.cuda': return obj.__class__.__name__.endsWith('Tensor'); case 'torch.nn.parameter': return obj.__class__.__name__ === 'Parameter'; default: return false; } } static toTensor(obj) { const name = obj && obj.__class__ ? obj.__class__.__module__ : null; switch (name) { case 'torch': case 'torch.cuda': return obj.__class__.__name__.endsWith('Tensor') ? 
obj : null; case 'torch.nn.parameter': if (obj.__class__.__name__ === 'Parameter') { const data = obj.data; if (typeof obj.__name__ === 'string') { data.__name__ = obj.__name__; } return data; } return null; default: return null; } } static toType(type) { switch (type.kind()) { case 'OptionalType': return `${pytorch.Utility.toType(type.getElementType())}?`; case 'ListType': return `${pytorch.Utility.toType(type.getElementType())}[]`; case 'BoolType': return 'boolean'; case 'IntType': return 'int64'; case 'FloatType': return 'float32'; case 'StringType': return 'string'; case 'ComplexType': return 'complex'; case 'NumberType': return 'scalar'; case 'TensorType': return 'tensor'; case 'TupleType': return `tuple<${type.elements().map((type) => pytorch.Utility.toType(type)).join(', ')}>`; case 'DictType': return `map<${pytorch.Utility.toType(type.getKeyType())}, ${pytorch.Utility.toType(type.getValueType())}>`; case 'DeviceObjType': return 'device'; case 'SymIntType': return 'SymInt'; case 'ScalarTypeType': return 'ScalarType'; case 'MemoryFormat': return 'MemoryFormat'; case 'Layout': return 'Layout'; case 'VarType': return type.annotation_str; case 'NoneType': return 'None'; case 'AnyType': return 'object'; case 'AnyListType': return 'list'; case 'AnyTupleType': return 'tuple'; case 'ClassType': return type.annotation_str; case 'EnumType': return type.annotation_str; default: throw new pytorch.Error(`Unsupported type '${type.kind()}'.`); } } static toString(ivalue) { if (ivalue.isInt()) { return ivalue.toInt(); } if (ivalue.isDouble()) { return ivalue.toDouble(); } if (ivalue.isEnum()) { return ivalue.toEnumHolder().name(); } if (ivalue.isList()) { return ivalue.toList().map((item) => pytorch.Utility.toString(item)); } throw new pytorch.Error(`Unsupported IValue '${ivalue.tag}.`); } static constant(node, name) { const kind = node.kindOf(name); switch (kind) { case 's': return node.s(name); case 'i': return node.i(name); case 'f': return node.f(name); case 'ss': 
return node.ss(name); case 'ival': return node.ival(name); default: throw new pytorch.Error(`Unsupported attribute kind '${kind}'.`); } } static unique(value) { return value.hasDebugName() ? `%${value.debugName().toString()}` : `%${value.unique().toString()}`; } static isObject(obj) { const type = obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__ ? `${obj.__class__.__module__}.${obj.__class__.__name__}` : null; switch (type) { case '__torch__.torch.classes.xnnpack.LinearOpContext': case '__torch__.torch.classes.xnnpack.Conv2dOpContext': case '__torch__.torch.classes.xnnpack.TransposeConv2dOpContext': case '__torch__.torch.classes.rnn.CellParamsBase': case '__torch__.torch.classes.rnn.CellParamsBase[]': case '__torch__.torch.classes.quantized.LinearPackedParamsBase': case '__torch__.torch.classes.quantized.Conv2dPackedParamsBase': case '__torch__.torch.classes.quantized.Conv3dPackedParamsBase': case '__torch__.torch.classes.quantized.EmbeddingPackedParamsBase': return true; default: return false; } } static isSubclass(value, name) { if (value && value.__module__ && value.__name__) { return name === `${value.__module__}.${value.__name__}`; } else if (value && value.__bases__) { return value.__bases__.some((obj) => pytorch.Utility.isSubclass(obj, name)); } return false; } static isInstance(value, name) { return value && value.__class__ ? 
pytorch.Utility.isSubclass(value.__class__, name) : false; } static format(name, value) { // https://github.com/pytorch/pytorch/blob/master/caffe2/serialize/inline_container.h // kProducedFileFormatVersion const versions = new Map([ ['1', 'v1.3'], ['2', 'v1.5'], // 7a2889b014ce36fcc333b2c6de6f29f976652f84 (#28122) ['3', 'v1.6'], // 2ec6a30722b0ef85632a2f3e7ce6f80da403008a (#36085) ['4', 'v1.6'], // 95489b590f00801bdee7f41783f30874883cf6bb (#38620) ['5', 'v1.7'], // cb26661fe4faf26386703180a9045e6ac6d157df (#40364) ['6', 'v1.9'], // 3ee7637ffa50df0d9b231c7b40778ac1c390bf4a (#59714) ['7', 'v1.10'], // 880098a7e34a20628f960daa8eab0eb1ad566c39 (#63651) ['8', 'v1.11'], // b28e696516a7f0c7a6ead6da967590ce6c1d6698 (#71486) ['9', 'v1.11'], // 8757e21c6a4fc00e83539aa7f9c28eb11eff53c1 (#72051) ['10', 'v1.12'] // 4f8b986e28736b59bc46cd0873a0f36fdaa6f5b8 (#61439) ]); value = value.toString(); if (!versions.has(value)) { throw new pytorch.Error(`Unsupported '${name}' version '${value}'.`); } return `${name} ${versions.get(value)}`; } static weights(obj) { let type = obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__ ? 
`${obj.__class__.__module__}.${obj.__class__.__name__}` : null; if (type === 'torch.jit._script.RecursiveScriptModule') { type = obj._c._type(); const target = {}; for (let i = 0; i < type.numAttributes(); i++) { const k = type.getAttributeName(i); target[k] = obj.__getattr__(k); } type = obj._c.qualified_name; obj = target; } else if (type && type !== 'builtins.dict' && type !== 'builtins.object' && type !== 'collections.OrderedDict' && type !== 'torch.nn.modules.module.Module' && type !== '__torch__.Module') { return null; } if (pytorch.Utility.isTensor(obj)) { return null; } if (obj instanceof Map === false && obj && !Array.isArray(obj) && Object(obj) === obj) { const entries = Object.entries(obj); const named = entries.filter(([name, value]) => (typeof name === 'string' && (name.indexOf('.') !== -1 || name.indexOf('|') !== -1)) && pytorch.Utility.isTensor(value)); if (named.length > 0 && (named.length / entries.length) >= 0.8) { obj = new Map(entries); } } if (obj instanceof Map) { const entries = Array.from(obj).filter(([name]) => name !== '_metadata'); const names = entries.filter(([name]) => typeof name === 'string' && (name.indexOf('.') !== -1 || name.indexOf('|') !== -1)); if (names.length > 1 && (names.length / entries.length) >= 0.8 && (entries.every(([, value]) => !pytorch.Utility.isInstance(value, 'builtins.dict') || Array.from(value.values()).every((value) => !pytorch.Utility.isTensor(value)))) && (!entries.every(([, value]) => Array.isArray(value)))) { const modules = new Map(); for (const [name, value] of entries) { const separator = name.indexOf('.') === -1 && name.indexOf('|') !== -1 ? 
'|' : '.'; const path = name.split(separator); let property = path.pop(); if (path.length > 1 && path[path.length - 1] === '_packed_params') { property = `${path.pop()}.${property}`; } const key = path.join(separator); if (!modules.has(key)) { modules.set(key, {}); } const module = modules.get(key); if (pytorch.Utility.isTensor(value)) { value.__name__ = name; } module[property] = value; } return modules; } } if (obj && !Array.isArray(obj) && Object(obj) === obj) { const modules = new Map(); const entries = obj instanceof Map ? Array.from(obj) : Object.entries(obj); if (entries.length > 0 && entries) { for (const [key, value] of entries) { const name = key.toString(); if (!value || Object(value) !== value || pytorch.Utility.isTensor(value) || ArrayBuffer.isView(value) || value._modules instanceof Map) { return null; } if (!modules.has(name)) { modules.set(name, {}); } const module = modules.get(name); let tensor = false; const entries = value instanceof Map ? value : new Map(Object.entries(value)); for (const [name, value] of entries) { if (typeof name !== 'string') { return null; } if (name.indexOf('.') !== -1) { return null; } if (name === '_metadata') { continue; } if (typeof value === 'string' || typeof value === 'number') { module[name] = value; continue; } if (pytorch.Utility.isTensor(value)) { value.__name__ = name; module[name] = value; tensor = true; } } if (!tensor) { return null; } } return modules; } } return null; } static isMetadataObject(obj) { if (pytorch.Utility.isInstance(obj, 'collections.OrderedDict')) { for (const value of obj.values()) { if (pytorch.Utility.isInstance(value, 'builtins.dict')) { const entries = Array.from(value); if (entries.length !== 1 && entries[0] !== 'version' && entries[1] !== 1) { return false; } } } return true; } return false; } }; nnapi.SerializedModel = class { constructor(serialized_model, buffers) { const reader = base.BinaryReader.open(serialized_model); this.version = reader.int32(); if (this.version !== 1) { 
throw new pytorch.Error('Invalid NNAPI serialized model version.'); } const operands = new Array(reader.int32()); const values = new Array(reader.int32()); this.operations = new Array(reader.int32()); this.inputs = new Array(reader.int32()); this.outputs = new Array(reader.int32()); const data_types = new Map([ [0, 'float32'], [1, 'int32'], [2, 'uint32'], [3, 'float32[]'], [4, 'int32[]'], [5, 'quant8_asymm[]'], [6, 'boolean'], [7, 'quant16_symm[]'], [8, 'float16[]'], [9, 'boolean[]'], [10, 'float16'], [11, 'quant8_symm_per_channel[]'], [12, 'quant16_asymm[]'], [13, 'quant8_symm[]'], [14, 'quant8_asymm_signed[]'], [16, 'model'] ]); for (let i = 0; i < operands.length; i++) { const data_type = reader.int32(); operands[i] = { index: i, data_type: data_types.has(data_type) ? data_types.get(data_type) : data_type, dimensions: new Array(reader.uint32()), scale: reader.float32(), zero_point: reader.int32() }; } for (let i = 0; i < values.length; i++) { values[i] = { index: reader.int32(), source_type: reader.int32(), source_length: reader.uint32() }; } for (let i = 0; i < this.operations.length; i++) { this.operations[i] = { index: reader.int32(), identifier: i, inputs: new Array(reader.uint32()), outputs: new Array(reader.uint32()) }; } for (const operand of operands) { for (let i = 0; i < operand.dimensions.length; i++) { operand.dimensions[i] = reader.uint32(); } } for (const value of values) { const index = value.index; const operand = operands[index]; switch (value.source_type) { case 0: { // immediate switch (operand.data_type) { case 'boolean': operand.value = reader.byte() ? 
true : false; reader.skip(3); break; case 'int32': operand.value = reader.int32(); break; case 'float32': operand.value = reader.float32(); break; case 'int32[]': operand.data = reader.read(value.source_length); break; case 'float32[]': operand.data = reader.read(value.source_length); break; default: throw new pytorch.Error(`Unsupported NNAPI operand type '${operand.data_type}'.`); } break; } case 2: { // numbered buffer if (value.source_length !== 12) { throw new pytorch.Error('Invalid NNAPI numbered buffer source length.'); } const number = reader.uint32(); const offset = reader.uint32(); const operand_length = reader.uint32(); if (number < buffers.length && buffers[number].data) { const storage = buffers[number]; const data = storage.data && storage.data.peek ? storage.data.peek() : storage.data; operand.data = data.slice(offset, operand_length); } break; } case 3: { // numbered memory throw new pytorch.Error('NNAPI numbered memory buffer not implemented.'); } default: { throw new pytorch.Error('Unsupported NNAPI value source type.'); } } } for (const operation of this.operations) { for (let i = 0; i < operation.inputs.length; i++) { const index = reader.uint32(); operation.inputs[i] = operands[index]; } for (let i = 0; i < operation.outputs.length; i++) { const index = reader.uint32(); operation.outputs[i] = operands[index]; } } for (let i = 0; i < this.inputs.length; i++) { const index = reader.uint32(); this.inputs[i] = operands[index]; } for (let i = 0; i < this.outputs.length; i++) { const index = reader.uint32(); this.outputs[i] = operands[index]; } if (reader.position !== reader.length) { throw new pytorch.Error('Invalid NNAPI serialized model length.'); } } }; nnapi.Graph = class { constructor(model) { this.name = 'torch.classes._nnapi.Compilation'; this.nodes = []; this.inputs = []; this.outputs = []; const values = new Map(); values.map = (operand) => { if (!values.has(operand.index)) { const name = operand.index.toString(); const dimensions = 
operand.dimensions; const shape = new pytorch.TensorShape(dimensions); let dataType = operand.data_type.replace('[]', ''); let quantization = null; switch (dataType) { case 'quant8_asymm': case 'quant8_symm_per_channel': case 'quant8_symm': case 'quant8_asymm_signed[]': case 'quant16_asymm': case 'quant16_symm': quantization = dataType; dataType = dataType.indexOf('16') === -1 ? 'uint8' : 'uint16'; break; default: break; } const type = new pytorch.TensorType(dataType, shape); let initializer = null; if (operand.data) { const size = dimensions.reduce((a, b) => a * b, 1); const tensor = { dtype: { __reduce__: () => dataType }, size: () => dimensions, stride: () => null, storage_offset: () => 0, storage: () => ({ dtype: { __reduce__: () => type.dataType }, data: operand.data, size: () => size }) }; const context = new pytorch.Context(); initializer = new pytorch.Tensor(context, null, tensor); } if (quantization || (operand.scale !== undefined && operand.scale !== 0) || (operand.zero_point !== undefined && operand.zero_point !== 0)) { quantization = { type: quantization || 'linear', scale: [operand.scale], offset: [operand.zero_point] }; } const value = new pytorch.Value(name, type, quantization, initializer); values.set(operand.index, value); } return values.get(operand.index); }; const metadata = new nnapi.Metadata(); for (const operation of model.operations) { const node = new nnapi.Node(metadata, operation, values); this.nodes.push(node); } for (let i = 0; i < model.inputs.length; i++) { const name = i.toString(); const operand = model.inputs[i]; const argument = new pytorch.Argument(name, [values.map(operand)]); this.inputs.push(argument); } for (let i = 0; i < model.outputs.length; i++) { const name = i.toString(); const operand = model.outputs[i]; const argument = new pytorch.Argument(name, [values.map(operand)]); this.outputs.push(argument); } } }; nnapi.Node = class { constructor(metadata, operation, values) { const signature = (operation.inputs || 
[]).map((input) => input.data_type); this.name = ''; this.type = metadata.type(operation.index, signature); this.inputs = []; this.outputs = []; this.attributes = []; this.chain = []; if (operation.identifier !== undefined) { this.identifier = operation.identifier.toString(); } if (Array.isArray(operation.inputs)) { const inputs = this.type.inputs; for (let i = 0; i < operation.inputs.length; i++) { const name = i < inputs.length ? inputs[i].name : i.toString(); const operand = operation.inputs[i]; if (operand.dimensions.length > 0) { const value = values.map(operand); const argument = new pytorch.Argument(name, [value]); this.inputs.push(argument); } else if (name === 'activation') { const activation = new Map([[1, 19], [2, 20], [3, 21]]).get(operand.value) || 0; if (activation !== 0) { this.chain.push(new nnapi.Node(metadata, { index: activation })); } } else { const attribute = new pytorch.Argument(name, operand.value, operand.data_type, false); this.inputs.push(attribute); } } } if (Array.isArray(operation.outputs)) { const outputs = this.type.outputs; for (let i = 0; i < operation.outputs.length; i++) { const name = i < outputs.length ? 
outputs[i].name : i.toString(); const operand = operation.outputs[i]; const value = values.map(operand); const argument = new pytorch.Argument(name, [value]); this.outputs.push(argument); } } } }; nnapi.Metadata = class { constructor() { this._types = new Map(); // https://developer.android.com/ndk/reference/group/neural-networks // https://github.com/pytorch/pytorch/commits/master/torch/backends/_nnapi/serializer.py this.register(0, 'ADD', '', ['A', 'B'], [['activation', 'int32']], ['C']); this.register(1, 'AVERAGE_POOL_2D', 'Pool', ['input'], [['padding_left', 'int32'], ['padding_right', 'int32'], ['padding_top', 'int32'], ['padding_bottom', 'int32'], ['stride_x', 'int32'], ['stride_y', 'int32'], ['filter_x', 'int32'], ['filter_y', 'int32'], ['activation', 'int32'], ['nchw', 'boolean']], ['output']); this.register(1, 'AVERAGE_POOL_2D', 'Pool', ['input'], [['padding_scheme', 'int32'], ['stride_x', 'int32'], ['stride_y', 'int32'], ['filter_x', 'int32'], ['filter_y', 'int32'], ['activation', 'int32'], ['nchw', 'boolean']], ['output']); this.register(2, 'CONCATENATION'); this.register(3, 'CONV_2D', 'Layer', ['input', 'weights', 'bias'], [['padding_left', 'int32'], ['padding_right', 'int32'], ['padding_top', 'int32'], ['padding_bottom', 'int32'], ['stride_x', 'int32'], ['stride_y', 'int32'], ['activation', 'int32'], ['nchw', 'boolean'], ['dilation_width', 'int32'], ['dilation_height', 'int32']], ['output']); this.register(3, 'CONV_2D', 'Layer', ['input', 'weights', 'bias'], [['padding_scheme', 'int32'], ['stride_x', 'int32'], ['stride_y', 'int32'], ['activation', 'int32'], ['nchw', 'boolean'], ['dilation_width', 'int32'], ['dilation_height', 'int32']], ['output']); this.register(4, 'DEPTHWISE_CONV_2D', 'Layer', ['input', 'weights', 'bias'], [['padding_left', 'int32'], ['padding_right', 'int32'], ['padding_top', 'int32'], ['padding_bottom', 'int32'], ['stride_x', 'int32'], ['stride_y', 'int32'], ['activation', 'int32'], ['nchw', 'boolean'], ['dilation_width', 'int32'], 
['dilation_height', 'int32']], ['output']); this.register(4, 'DEPTHWISE_CONV_2D', 'Layer', ['input', 'weights', 'bias'], [['padding_scheme', 'int32'], ['stride_x', 'int32'], ['stride_y', 'int32'], ['activation', 'int32'], ['nchw', 'boolean'], ['dilation_width', 'int32'], ['dilation_height', 'int32']], ['output']); this.register(5, 'DEPTH_TO_SPACE'); this.register(6, 'DEQUANTIZE'); this.register(7, 'EMBEDDING_LOOKUP'); this.register(8, 'FLOOR'); this.register(9, 'FULLY_CONNECTED', 'Layer', ['input', 'weights', 'bias'], [['activation', 'int32']], ['output']); this.register(10, 'HASHTABLE_LOOKUP'); this.register(11, 'L2_NORMALIZATION'); this.register(12, 'L2_POOL_2D', 'Pool'); this.register(13, 'LOCAL_RESPONSE_NORMALIZATION'); this.register(14, 'LOGISTIC'); this.register(15, 'LSH_PROJECTION'); this.register(16, 'LSTM', 'Layer'); this.register(17, 'MAX_POOL_2D', 'Pool'); this.register(18, 'MUL'); this.register(19, 'RELU', 'Activation', ['input'], [], ['output']); this.register(20, 'RELU1', 'Activation'); this.register(21, 'RELU6', 'Activation'); this.register(22, 'RESHAPE', 'Shape', ['input', 'shape'], [], ['output']); this.register(23, 'RESIZE_BILINEAR'); this.register(24, 'RNN', 'Layer'); this.register(25, 'SOFTMAX', 'Activation'); this.register(26, 'SPACE_TO_DEPTH'); this.register(27, 'SVDF'); this.register(28, 'TANH'); this.register(29, 'BATCH_TO_SPACE_ND'); this.register(30, 'DIV'); this.register(31, 'MEAN'); this.register(32, 'PAD'); this.register(33, 'SPACE_TO_BATCH_ND'); this.register(34, 'SQUEEZE'); this.register(35, 'STRIDED_SLICE'); this.register(36, 'SUB'); this.register(37, 'TRANSPOSE'); this.register(38, 'ABS'); this.register(39, 'ARGMAX'); this.register(40, 'ARGMIN'); this.register(41, 'AXIS_ALIGNED_BBOX_TRANSFORM'); this.register(42, 'BIDIRECTIONAL_SEQUENCE_LSTM'); this.register(43, 'BIDIRECTIONAL_SEQUENCE_RNN'); this.register(44, 'BOX_WITH_NMS_LIMIT'); this.register(45, 'CAST'); this.register(46, 'CHANNEL_SHUFFLE'); this.register(47, 
'DETECTION_POSTPROCESSING'); this.register(48, 'EQUAL'); this.register(49, 'EXP'); this.register(50, 'EXPAND_DIMS'); this.register(51, 'GATHER'); this.register(52, 'GENERATE_PROPOSALS'); this.register(53, 'GREATER'); this.register(54, 'GREATER_EQUAL'); this.register(55, 'GROUPED_CONV_2D'); this.register(56, 'HEATMAP_MAX_KEYPOINT'); this.register(57, 'INSTANCE_NORMALIZATION'); this.register(58, 'LESS'); this.register(59, 'LESS_EQUAL'); this.register(60, 'LOG'); this.register(61, 'LOGICAL_AND'); this.register(62, 'LOGICAL_NOT'); this.register(63, 'LOGICAL_OR'); this.register(64, 'LOG_SOFTMAX'); this.register(65, 'MAXIMUM'); this.register(66, 'MINIMUM'); this.register(67, 'NEG'); this.register(68, 'NOT_EQUAL'); this.register(69, 'PAD_V2'); this.register(70, 'POW'); this.register(71, 'PRELU'); this.register(72, 'QUANTIZE'); this.register(73, 'QUANTIZED_16BIT_LSTM'); this.register(74, 'RANDOM_MULTINOMIAL'); this.register(75, 'REDUCE_ALL'); this.register(76, 'REDUCE_ANY'); this.register(77, 'REDUCE_MAX'); this.register(78, 'REDUCE_MIN'); this.register(79, 'REDUCE_PROD'); this.register(80, 'REDUCE_SUM'); this.register(81, 'ROI_ALIGN'); this.register(82, 'ROI_POOLING'); this.register(83, 'RSQRT'); this.register(84, 'SELECT'); this.register(85, 'SIN'); this.register(86, 'SLICE'); this.register(87, 'SPLIT'); this.register(88, 'SQRT'); this.register(89, 'TILE'); this.register(90, 'TOPK_V2'); this.register(91, 'TRANSPOSE_CONV_2D', 'Layer'); this.register(92, 'UNIDIRECTIONAL_SEQUENCE_LSTM', 'Layer'); this.register(93, 'UNIDIRECTIONAL_SEQUENCE_RNN', 'Layer'); this.register(94, 'RESIZE_NEAREST_NEIGHBOR'); this.register(95, 'QUANTIZED_LSTM', 'Layer'); this.register(96, 'IF'); this.register(97, 'WHILE'); this.register(98, 'ELU', 'Activation'); this.register(99, 'HARD_SWISH', 'Activation'); this.register(100, 'FILL'); this.register(101, 'RANK'); } register(index, name, category, inputs, attributes, outputs) { inputs = inputs || []; outputs = outputs || []; attributes = attributes || 
[]; const type = {}; type.name = name; type.inputs = inputs.map((name) => ({ name, type: 'Tensor' })); type.inputs = type.inputs.concat(attributes.map(([name, type]) => ({ name, type }))); type.outputs = outputs.map((name) => ({ name, type: 'Tensor' })); if (category) { type.category = category; } if (!this._types.has(index)) { this._types.set(index, []); } this._types.get(index).push(type); } type(index, signature) { if (!this._types.has(index)) { this._types.set(index, { name: index.toString(), inputs: [], outputs: [], attributes: [] }); } const types = this._types.get(index); for (const type of types) { const inputs = type.inputs; if (signature.length < inputs.length) { if (inputs.every((input, i) => input.type === undefined || input.type === 'Tensor' || input.type === signature[i])) { return type; } } } return types[0]; } }; pytorch.Metadata = class { static async open(context) { if (!pytorch.Metadata._metadata) { let data = null; try { data = await context.request('pytorch-metadata.json'); } catch { // continue regardless of error } pytorch.Metadata._metadata = new pytorch.Metadata(data); } return pytorch.Metadata._metadata; } constructor(data) { this._types = new Map(); this._attributes = new Map(); this._index = new Map(); if (data) { const items = JSON.parse(data); for (const item of items) { const index = item.name.indexOf('('); const key = index === -1 ? 
item.name : item.name.substring(0, index); this._types.set(key, item); } } } add(name, value) { this._types.set(name, value); } type(name) { return this._types.get(name); } attribute(type, name) { const key = `${type}:${name}`; if (!this._attributes.has(key)) { this._attributes.set(key, null); const metadata = this.type(type); if (metadata) { if (metadata.inputs) { for (const input of metadata.inputs) { this._attributes.set(`${type}:${input.name}`, input); } } if (metadata.attributes) { for (const attribute of metadata.attributes) { this._attributes.set(`${type}:${attribute.name}`, attribute); } } } } return this._attributes.get(key); } register(execution) { const torch = execution.register('torch'); const registry = torch._C.getRegistry(); const modules = new Set(); for (const [name, type] of this._types) { if (name.indexOf('::') !== -1) { const schema = torch.FunctionSchema.parse(type.name); if (type.category) { schema.category = type.category; } schema.setAliasAnalysis('FROM_SCHEMA'); const op = new torch._C.Operator(schema); registry.registerOperator(op); modules.add(type.name.split('::')[0]); } } for (const module of modules) { const namespace = new torch._ops._OpNamespace(module); execution.register(`torch.ops.${module}`, namespace); } } }; numpy.Tensor = class { constructor(array) { this.type = new numpy.TensorType(array.dtype.__name__, new numpy.TensorShape(array.shape)); this.stride = array.strides.map((stride) => stride / array.itemsize); this.encoding = this.type.dataType === 'string' || this.type.dataType === 'object' ? '|' : array.dtype.byteorder; this.values = this.type.dataType === 'string' || this.type.dataType === 'object' || this.type.dataType === 'void' ? 
array.flatten().tolist() : array.tobytes(); } }; numpy.TensorType = class { constructor(dataType, shape) { this.dataType = dataType || '?'; this.shape = shape; } toString() { return this.dataType + this.shape.toString(); } }; numpy.TensorShape = class { constructor(dimensions) { this.dimensions = dimensions; } toString() { return this.dimensions && this.dimensions.length > 0 ? `[${this.dimensions.join(',')}]` : ''; } }; pytorch.Error = class extends Error { constructor(message) { super(message); this.name = 'Error loading PyTorch model.'; } }; export const Metadata = pytorch.Metadata; export const ModelFactory = pytorch.ModelFactory;