// Source: netron / pytorch.js
// Uploaded by shethjenil ("Upload 30 files"), commit d0d9416 (verified).
// Experimental
import * as base from './base.js';
import * as flatbuffers from './flatbuffers.js';
import * as python from './python.js';
// Namespace objects; the class definitions that follow are attached to them
// (e.g. pytorch.ModelFactory, nnapi.Graph, numpy.Tensor).
const numpy = {};
const nnapi = {};
const pytorch = {};
pytorch.ModelFactory = class {

    // Probes `context` with pytorch.Reader. When a reader is returned, it is
    // registered on the context under the reader's type; otherwise null.
    async match(context) {
        const reader = await pytorch.Reader.open(context);
        return reader ? context.set(reader.type, reader) : null;
    }

    // Returns false for the listed (context.type, match.type) combinations so
    // those matches are rejected; every other combination is accepted.
    filter(context, match) {
        const rejected = [
            ['pytorch.export', 'pytorch.zip'],
            ['pytorch.index', 'pytorch.zip'],
            ['pytorch.model.json', 'pytorch.data.pkl'],
            ['pytorch.model.json', 'pickle']
        ];
        const isRejected = rejected.some(([contextType, matchType]) => context.type === contextType && match.type === matchType);
        return !isRejected;
    }

    // Loads metadata, lets the reader stored in `context.value` read the model and
    // wraps the result in a pytorch.Model.
    // Throws pytorch.Error when the reader produced no format or no module(s).
    async open(context) {
        const metadata = await pytorch.Metadata.open(context);
        const reader = context.value;
        // Unresolved type names are forwarded to context.error with a second
        // argument of false (presumably non-fatal — confirm against context.error).
        reader.on('resolve', (sender, name) => {
            context.error(new pytorch.Error(`Unknown type name '${name}'.`), false);
        });
        await reader.read(metadata);
        const hasModules = reader.modules || reader.module;
        if (reader.format && hasModules) {
            return new pytorch.Model(metadata, reader);
        }
        throw new pytorch.Error("Reader not implemented.");
    }
};
pytorch.Model = class {

    // Wraps the reader output (`target`) as a list of pytorch.Graph modules.
    // metadata - operator/type metadata handed to every graph.
    // target   - reader exposing `format`, optional `producer`, either a single
    //            `module` or an iterable of [name, module] pairs in `modules`, and
    //            the `execution` environment passed to each graph.
    constructor(metadata, target) {
        this.format = target.format;
        this.producer = target.producer || '';
        this.modules = [];
        const execution = target.execution;
        if (target.module) {
            this.modules.push(new pytorch.Graph(execution, metadata, null, '', target.module));
            delete target.execution;
        } else if (target.modules) {
            // Every module gets the same execution environment. Deleting it inside
            // the loop (as before) left every module after the first with
            // `undefined`, silently disabling the torch-based graph decoding.
            for (const [name, value] of target.modules) {
                this.modules.push(new pytorch.Graph(execution, metadata, null, name, value));
            }
            delete target.execution;
        }
        // NOTE(review): `execution` is removed from `target` once the graphs are
        // built, presumably so the reader no longer holds on to it — confirm.
    }
};
pytorch.Graph = class {

    // Builds the displayable graph for one module. What is rendered depends on `module`:
    //  - a TorchScript RecursiveScriptModule that has a `forward` method;
    //  - a torch.export ExportedProgram;
    //  - a torch.fx GraphModule;
    //  - a single tensor;
    //  - a weights container (per pytorch.Utility.weights, or any module when type === 'weights');
    //  - otherwise one node per object (an array of module-like objects, or `module` itself).
    //
    // execution - Python execution environment (may be null/undefined); its `torch`
    //             module is used for the instanceof checks above.
    // metadata  - operator/type metadata passed through to nodes.
    // type      - graph type; 'weights' marks a graph built from a weights container.
    // name      - graph name ('' by default).
    // module    - the object to render.
    constructor(execution, metadata, type, name = '', module = null) {
        this.nodes = [];
        this.inputs = [];
        this.outputs = [];
        this.name = name;
        this.type = type;
        const context = new pytorch.Context(execution, metadata);
        // Look up or create the pytorch.Value for an identifier. Values that carry a
        // tensor initializer are created fresh and not cached; registering a type for
        // an identifier that already exists throws.
        context.values.map = (name, type, tensor) => {
            if (tensor) {
                return new pytorch.Value(name, type, null, tensor);
            }
            if (!context.values.has(name)) {
                context.values.set(name, new pytorch.Value(name, type, null, tensor));
            } else if (type || tensor) {
                throw new pytorch.Error(`Duplicate value '${name}'.`);
            }
            return context.values.get(name);
        };
        const torch = execution ? execution.torch : null;
        if (torch && module instanceof torch.jit._script.RecursiveScriptModule && module._c._has_method('forward')) {
            // TorchScript: tensor / object constants from code_with_constants become initializers.
            const initializers = new Map();
            const graph = module.graph;
            const constants = module.code_with_constants[1].const_mapping;
            if (constants) {
                for (const [key, value] of constants) {
                    const name = `CONSTANTS.${key}`;
                    if (pytorch.Utility.isTensor(value)) {
                        initializers.set(value, new pytorch.Tensor(context, name, value));
                    } else if (pytorch.Utility.isObject(value)) {
                        initializers.set(value, value);
                    }
                    // Other constant kinds are currently ignored.
                }
            }
            const param_node = graph.param_node();
            // `self` is the graph input carrying the module itself (the param node's first
            // output, when its type is the module's type); null when there is no such input.
            const self = param_node && param_node.outputs().length > 0 && param_node.outputs()[0].type() === module._c._type() ? param_node.outputs()[0] : null;
            if (self) {
                // Resolve chains of prim::GetAttr rooted at `self` to concrete attribute
                // values, recording a dotted identifier on each resolved value.
                const getattr = (value) => {
                    if (value.value === undefined) {
                        const node = value.node();
                        if (node.kind() === 'prim::GetAttr') {
                            const [input] = node.inputs();
                            getattr(input);
                            if (input.value !== undefined) {
                                const name = node.s('name');
                                value.value = input.value.__getattr__(name);
                                value.identifier = input.identifier ? `${input.identifier}.${name}` : name;
                            }
                        }
                        if (node === param_node && value === param_node.outputs()[0]) {
                            value.value = module;
                            value.identifier = '';
                        }
                    }
                };
                for (const node of graph.nodes()) {
                    for (const input of node.inputs()) {
                        getattr(input);
                    }
                }
                // Destroy the prim::GetAttr nodes reachable from `self`; their outputs
                // now carry the resolved values.
                const delattr = (value) => {
                    for (const use of Array.from(value.uses())) {
                        const node = use.user;
                        if (node.kind() === 'prim::GetAttr') {
                            for (const output of node.outputs()) {
                                delattr(output);
                            }
                            node.destroy();
                        }
                    }
                };
                delattr(param_node.outputs()[0]);
            }
            // Fold single-output prim::Constant nodes into the value they produce.
            for (const node of graph.nodes()) {
                if (node.kind() === 'prim::Constant' && node.outputs().length === 1) {
                    const output = node.output();
                    output.identifier = output.debugName();
                    if (node.hasAttribute('value')) {
                        const kind = node.kindOf('value');
                        output.value = node[kind]('value');
                    } else if (node.output().type() instanceof torch.NoneType) {
                        output.value = null;
                    }
                    node.destroy();
                }
            }
            // Fold prim::TupleUnpack of constant primitive tuples into per-output values.
            for (const node of graph.nodes()) {
                if (node.kind() === 'prim::TupleUnpack') {
                    const value = node.inputs()[0].value;
                    if (Array.isArray(value) && value.length === node.outputs().length && value.every((value) => typeof value === 'number' || typeof value === 'string' || typeof value === 'boolean')) {
                        for (let i = 0; i < node.outputs().length; i++) {
                            const output = node.outputs()[i];
                            output.value = value[i];
                        }
                        node.destroy();
                    }
                }
            }
            // Fold prim::ListConstruct of constant primitives into a list value, unless
            // the list feeds a prim::CallMethod.
            for (const node of graph.nodes()) {
                if (node.kind() === 'prim::ListConstruct' &&
                    node.inputs().every((value) => typeof value.value === 'number' || typeof value.value === 'string' || typeof value.value === 'boolean') &&
                    node.outputs().every((value) => value.uses().every((use) => use.user.kind() !== 'prim::CallMethod'))) {
                    node.outputs()[0].value = node.inputs().map((value) => value.value);
                    node.destroy();
                }
            }
            for (const v of graph.inputs()) {
                // Hide the `self` input when nothing uses it. `self` may be null, so test
                // identity first; dereferencing a null `self` used to throw a TypeError.
                if (v === self && self.uses().length === 0) {
                    continue;
                }
                const identifier = pytorch.Utility.unique(v);
                const name = v.debugName() || identifier;
                const value = context.values.map(identifier);
                this.inputs.push(new pytorch.Argument(name, [value]));
            }
            for (const value of graph.outputs()) {
                const identifier = pytorch.Utility.unique(value);
                this.outputs.push(new pytorch.Argument(identifier, [context.values.map(identifier)]));
            }
            for (const node of graph.nodes()) {
                if (node === graph.param_node() ||
                    node === graph.return_node()) {
                    continue;
                }
                // Skip single-output, single-use list constructions of tensors/values;
                // consumers may inline their elements (see pytorch.Node).
                if (node.kind() === 'prim::ListConstruct') {
                    if (node.outputs().length === 1 &&
                        node.outputs().every((output) => output.uses().length === 1) &&
                        node.inputs().every((input) => pytorch.Utility.isTensor(input.value) || input instanceof torch.Value)) {
                        continue;
                    }
                }
                this.nodes.push(new pytorch.Node(execution, metadata, null, null, node, initializers, context));
            }
        } else if (torch && module instanceof torch.export.exported_program.ExportedProgram && module.graph) {
            const exported_program = module;
            const graph = exported_program.graph;
            const graph_module = exported_program.graph_module;
            const inputs_to_parameters = exported_program.graph_signature.inputs_to_parameters;
            const inputs_to_buffers = exported_program.graph_signature.inputs_to_buffers;
            const inputs_to_lifted_tensor_constants = exported_program.graph_signature.inputs_to_lifted_tensor_constants;
            const nodes = new Map(graph.nodes.map((node) => [node.name, node]));
            // Placeholders that map to parameters, buffers or lifted constants become
            // initializer values keyed by the placeholder node.
            for (const obj of graph.nodes) {
                if (obj.op === 'placeholder') {
                    if (inputs_to_parameters.has(obj.name)) {
                        const key = inputs_to_parameters.get(obj.name);
                        const parameter = exported_program.state_dict.get(key);
                        const tensor = parameter ? (parameter.data || parameter) : obj.meta.get('val');
                        const initializer = new pytorch.Tensor(context, key, tensor);
                        const value = new pytorch.Value(key, null, null, initializer);
                        context.values.set(obj, value);
                    } else if (inputs_to_buffers.has(obj.name)) {
                        const key = inputs_to_buffers.get(obj.name);
                        const buffer = exported_program.state_dict.get(key);
                        const tensor = buffer || obj.meta.get('val');
                        const initializer = new pytorch.Tensor(context, key, tensor);
                        const value = new pytorch.Value(key, null, null, initializer);
                        context.values.set(obj, value);
                    } else if (inputs_to_lifted_tensor_constants.has(obj.name)) {
                        const key = inputs_to_lifted_tensor_constants.get(obj.name);
                        const constant = exported_program.constants.get(key);
                        const tensor = constant && constant.data ? constant.data : obj.meta.get('val');
                        const initializer = new pytorch.Tensor(context, key, tensor);
                        const value = new pytorch.Value(key, null, null, initializer);
                        context.values.set(obj, value);
                    }
                    // Initializer placeholders with several users are emitted as explicit
                    // nodes; their output value replaces the initializer for later lookups.
                    if (obj.users.size > 1 && context.values.has(obj)) {
                        const node = new pytorch.Node(execution, metadata, obj.name, null, obj, null, context);
                        this.nodes.push(node);
                        context.values.set(obj, node.outputs[0].value[0]);
                    }
                }
            }
            context.graph(this, graph_module, false);
            for (const input_spec of exported_program.graph_signature.user_inputs) {
                if (nodes.has(input_spec)) {
                    const node = nodes.get(input_spec);
                    const value = context.value(node);
                    const argument = new pytorch.Argument(input_spec, [value]);
                    this.inputs.push(argument);
                }
            }
        } else if (torch && module instanceof torch.fx.GraphModule && module.graph) {
            context.graph(this, module, true);
        } else if (pytorch.Utility.isTensor(module)) {
            const node = new pytorch.Node(execution, metadata, null, type, { value: module }, null, context);
            this.nodes.push(node);
        } else {
            const weights = this.type === 'weights' ? module : pytorch.Utility.weights(module);
            if (weights) {
                this.name = !this.name && typeof module.__name__ === 'string' ? module.__name__ : this.name;
                for (const [name, module] of weights) {
                    const node = new pytorch.Node(execution, metadata, name, 'Weights', module, null, context);
                    this.nodes.push(node);
                }
            } else {
                const modules = Array.isArray(module) && module.every((module) => module && !pytorch.Utility.isTensor(module) && (module._modules !== undefined || module.__class__)) ? module : [module];
                for (const module of modules) {
                    const type = this.type === 'weights' ? 'Weights' : null;
                    const node = new pytorch.Node(execution, metadata, null, type, module, null, context);
                    this.nodes.push(node);
                }
            }
        }
    }
};
pytorch.Argument = class {

    // A named input, output, attribute or metadata entry of a node or graph.
    // Callers in this file pass as `value` either a list of pytorch.Value objects,
    // a node/graph object, or a plain attribute payload; `type` names the payload
    // kind (null when not given) and `visible` defaults to true.
    constructor(name, value, type = null, visible = true) {
        Object.assign(this, { name, value, type, visible });
    }
};
pytorch.Value = class Value {

    // A named value (graph edge). When an initializer with a truthy `type` is
    // supplied, its type takes precedence over the `type` argument.
    // Throws pytorch.Error when `name` is not a string.
    constructor(name, type, quantization, initializer = null) {
        if (typeof name === 'string') {
            this.name = name;
        } else {
            throw new pytorch.Error(`Invalid value identifier '${JSON.stringify(name)}'.`);
        }
        let resolvedType = type || null;
        if (initializer && initializer.type) {
            resolvedType = initializer.type;
        }
        this.type = resolvedType;
        this.quantization = quantization;
        this.initializer = initializer;
    }
};
pytorch.Node = class {

    // Builds a display node from one of three kinds of `obj`:
    //  - a TorchScript IR node (torch.Node): kind, attributes, inputs/outputs matched
    //    against the operator schema, nested blocks and source-range metadata;
    //  - a torch.fx node (torch.fx.node.Node): call_function, placeholder, get_attr, root;
    //  - any other object (module, ScriptObject, dict-like), whose fields become tensor
    //    inputs, nested object nodes or attributes.
    //
    // name         - display name ('' when not given).
    // type         - type name override; derived from `obj` when empty.
    // initializers - Map from constant objects to their initializer (may be null).
    // context      - shared pytorch.Context used to resolve values.
    // stack        - Set of objects currently being expanded; guards against infinite
    //                recursion on cyclic object graphs.
    // Throws pytorch.Error for unsupported attribute kinds, inputs, targets or fx ops.
    constructor(execution, metadata, name, type, obj, initializers, context, stack) {
        const torch = execution ? execution.torch : null;
        const builtins = execution ? execution.builtins : null;
        this.name = name || '';
        this.nodes = [];
        this.attributes = [];
        this.inputs = [];
        this.outputs = [];
        this.blocks = [];
        this.metadata = [];
        if (torch && obj instanceof torch.Node) {
            const node = obj;
            const kind = node.kind();
            const schema = node.schema();
            const inputs = node.inputs();
            const outputs = node.outputs();
            // e.g. 'aten::conv2d.padding' -> name 'conv2d', identifier 'aten::conv2d.padding'.
            this.type = {
                name: kind.indexOf('::') === -1 ? kind : kind.split('::').pop().split('.')[0],
                identifier: kind
            };
            if (schema && schema.category) {
                this.type.category = schema.category;
            }
            // Read a node attribute and map its kind code to a display type.
            const getAttribute = (node, name) => {
                const kind = node.kindOf(name);
                let value = null;
                let type = null;
                switch (kind) {
                    case 's': value = node.s(name); type = 'string'; break;
                    case 'i': value = node.i(name); type = 'int64'; break;
                    case 'f': value = node.f(name); type = 'float32'; break;
                    case 't': value = node.t(name); type = 'tensor'; break;
                    case 'ss': value = node.ss(name); type = 'string[]'; break;
                    case 'tys': value = node.tys(name).map((ty) => pytorch.Utility.toType(ty)); type = 'type[]'; break;
                    case 'ival': value = node.ival(name); break;
                    default: throw new pytorch.Error(`Unsupported attribute kind '${kind}'.`);
                }
                return [type, value];
            };
            for (const name of node.attributeNames()) {
                const [type, value] = getAttribute(node, name);
                const attribute = new pytorch.Argument(name, value, type);
                this.attributes.push(attribute);
            }
            // Map an IR value to a pytorch.Value. Tensor values with a resolved identifier
            // become initializers registered in context.values; otherwise the value is keyed
            // by '%<debugName>' / '%<unique>', or by a matching hidden constant initializer.
            const mapTensor = (value) => {
                if (value.identifier && pytorch.Utility.isTensor(value.value)) {
                    const identifier = value.identifier;
                    if (!context.values.has(identifier)) {
                        const tensor = new pytorch.Tensor(context, identifier, value.value);
                        context.values.set(identifier, new pytorch.Value(identifier, null, null, tensor));
                    }
                    return context.values.map(identifier);
                }
                let initializer = null;
                let identifier = value.hasDebugName() ? `%${value.debugName().toString()}` : `%${value.unique().toString()}`;
                if (value.value) {
                    const obj = value.value;
                    const hide = obj.__parent__ ? obj.__parent__.__hide__ : true;
                    initializer = hide ? initializers.get(obj) : null;
                    identifier = initializer ? initializer.name : identifier;
                }
                if (initializer) {
                    return new pytorch.Value(identifier, null, null, initializer);
                }
                return context.values.map(identifier);
            };
            for (let i = 0; i < inputs.length; i++) {
                const input = inputs[i];
                const arg = schema && schema.arguments && i < schema.arguments.length ? schema.arguments[i] : null;
                const name = arg && arg.name ? arg.name : i.toString();
                let type = arg ? arg.real_type : null;
                let array = false;
                if (type instanceof torch.ListType) {
                    array = true;
                    type = type.getElementType();
                }
                let argument = null;
                if (type && type instanceof torch.ClassType) {
                    const obj = input.value;
                    if (!array && initializers.has(obj)) {
                        const node = new pytorch.Node(execution, metadata, name, type.qualified_name(), obj, initializers, context);
                        argument = new pytorch.Argument(name, node, 'object');
                    } else if (array && Array.isArray(obj) && obj.every((obj) => initializers.has(obj))) {
                        const node = obj.map((obj) => new pytorch.Node(execution, metadata, name, type.qualified_name(), obj, initializers, context));
                        argument = new pytorch.Argument(name, node, 'object[]');
                    } else if (array && input.node().kind() === 'prim::ListConstruct' && input.uses().length === 1 && input.node().inputs().every((input) => input.value)) {
                        const node = input.node().inputs().map((input) => new pytorch.Node(execution, metadata, name, null, input.value, initializers, context));
                        argument = new pytorch.Argument(name, node, 'object[]');
                    } else if (input.value === undefined) {
                        const identifier = pytorch.Utility.unique(input);
                        const value = context.values.map(identifier);
                        argument = new pytorch.Argument(name, [value]);
                    } else {
                        const node = new pytorch.Node(execution, metadata, null, null, input.value, initializers, context);
                        argument = new pytorch.Argument(name, node, 'object');
                    }
                } else if ((input.type() instanceof torch.TensorType || (input.type() instanceof torch.OptionalType && input.type().getElementType() instanceof torch.TensorType)) && pytorch.Utility.isTensor(input.value)) {
                    const value = mapTensor(input);
                    argument = new pytorch.Argument(name, [value]);
                } else if (input instanceof torch.Value && !pytorch.Utility.isTensor(input.value)) {
                    if (input.value !== undefined) {
                        // Constant lists made only of tensors are not shown as an input.
                        if (Array.isArray(input.value) && input.value.every((value) => pytorch.Utility.isTensor(value))) {
                            continue;
                        }
                        const type = input.type() ? pytorch.Utility.toType(input.type()) : null;
                        let value = input.value;
                        if (value && value instanceof torch._C.IValue) {
                            value = pytorch.Utility.toString(value);
                        }
                        if (value && value instanceof builtins.complex) {
                            value = new base.Complex(value.real, value.imag);
                        }
                        argument = new pytorch.Argument(name, value, type || 'attribute');
                    } else if (input.type() instanceof torch.ListType) {
                        if (input.node() && input.node().kind() === 'prim::ListConstruct' && input.uses().length === 1 &&
                            input.node().inputs().every((value) => value instanceof torch.Value || value.type() instanceof torch.IntType || value.type() instanceof torch.FloatType || value.type() instanceof torch.StringType || value.type() instanceof torch.ComplexType || value.type() instanceof torch.TensorType)) {
                            const list = input.node().inputs();
                            const args = list.map((value) => {
                                if (pytorch.Utility.isTensor(value.value)) {
                                    return mapTensor(value);
                                }
                                if (value.value !== undefined) {
                                    return value.value;
                                }
                                const identifier = pytorch.Utility.unique(value);
                                return context.values.map(identifier);
                            });
                            const type = list.every((value) => (pytorch.Utility.isTensor(value.value)) || value.value === null) ? null : pytorch.Utility.toType(input.type());
                            argument = new pytorch.Argument(name, args, type);
                        } else {
                            const identifier = pytorch.Utility.unique(input);
                            argument = new pytorch.Argument(name, [context.values.map(identifier)]);
                        }
                    } else if (input.type() instanceof torch.StringType && typeof input.value === 'string') {
                        argument = new pytorch.Argument(name, input.value, 'string');
                    } else if (input.type() instanceof torch.BoolType && (typeof input.value === 'boolean' || input.value === 0 || input.value === 1)) {
                        argument = new pytorch.Argument(name, Boolean(input.value), 'boolean');
                    } else if (input.type() instanceof torch.IntType && typeof input.value === 'number') {
                        argument = new pytorch.Argument(name, input.value, 'int64');
                    } else if (input.type() instanceof torch.FloatType && typeof input.value === 'number') {
                        argument = new pytorch.Argument(name, input.value, 'float32');
                    } else if (input.type() instanceof torch.NoneType && input.value === null) {
                        argument = new pytorch.Argument(name, null, 'attribute');
                    } else {
                        const identifier = pytorch.Utility.unique(input);
                        const value = context.values.map(identifier);
                        argument = new pytorch.Argument(name, [value]);
                    }
                } else if (pytorch.Utility.isTensor(input.value) || input.value === undefined || input.value === null) {
                    // A single-use list of tensors built by another node is inlined.
                    let list = [input];
                    if (input.node() && node !== input.node() &&
                        input.node().kind() === 'prim::ListConstruct' &&
                        input.uses().length === 1 &&
                        input.node().inputs().every((input) => pytorch.Utility.isTensor(input.value))) {
                        list = input.node().inputs();
                    }
                    const args = list.map((input) => {
                        let initializer = null;
                        let identifier = pytorch.Utility.unique(input);
                        if (input.value) {
                            const value = input.value;
                            const hide = value.__parent__ ? value.__parent__.__hide__ : true;
                            initializer = hide ? initializers.get(value) : null;
                            identifier = initializer ? initializer.name : identifier;
                        }
                        if (initializer) {
                            return new pytorch.Value(identifier, null, null, initializer);
                        }
                        return context.values.map(identifier);
                    });
                    argument = new pytorch.Argument(name, args);
                } else if (Array.isArray(input.value) && input.value.some((value) => value instanceof torch.Value)) {
                    const args = input.value.map((value) => {
                        if (value instanceof torch.Value) {
                            const identifier = pytorch.Utility.unique(value);
                            return context.values.map(identifier);
                        }
                        return value;
                    });
                    argument = new pytorch.Argument(name, args, pytorch.Utility.toType(type));
                } else {
                    throw new pytorch.Error('Unsupported input value');
                }
                this.inputs.push(argument);
            }
            for (let i = 0; i < outputs.length; i++) {
                const output = outputs[i];
                const ret = schema && schema.returns && i < schema.returns.length ? schema.returns[i] : null;
                // Schema return name when available; otherwise 'output' for a lone output or
                // the output index. Kept local so the constructor's `name` is not overwritten.
                const name = ret && ret.name ? ret.name : (i === 0 && outputs.length === 1 ? 'output' : `${i}`);
                // A single prim::ListUnpack consumer of tensors is expanded into its outputs.
                let list = [output];
                if (output.uses().length === 1 &&
                    output.uses()[0].user &&
                    output.uses()[0].user.kind() === 'prim::ListUnpack' &&
                    output.uses()[0].user.outputs().every((output) => pytorch.Utility.isTensor(output.value))) {
                    list = output.uses()[0].user.outputs();
                }
                const args = list.map((output) => context.values.map(pytorch.Utility.unique(output)));
                const argument = new pytorch.Argument(name, args);
                this.outputs.push(argument);
            }
            // Nested blocks (e.g. of control-flow nodes) become 'graph' arguments.
            const blocks = node.blocks();
            for (let i = 0; i < blocks.length; i++) {
                const block = blocks[i];
                const nodes = Array.from(block.nodes());
                if (nodes.length > 0) {
                    const name = `block${i}`;
                    const graph = { name, nodes: [], inputs: [], outputs: [] };
                    for (const v of block.inputs()) {
                        const identifier = pytorch.Utility.unique(v);
                        const value = context.values.map(identifier);
                        graph.inputs.push(new pytorch.Argument(v.debugName() || identifier, [value]));
                    }
                    for (const v of block.outputs()) {
                        const identifier = pytorch.Utility.unique(v);
                        graph.outputs.push(new pytorch.Argument(identifier, [context.values.map(identifier)]));
                    }
                    for (const n of nodes) {
                        if (n === block.param_node() || n === block.return_node()) {
                            continue;
                        }
                        graph.nodes.push(new pytorch.Node(execution, metadata, null, null, n, initializers, context));
                    }
                    const argument = new pytorch.Argument(name, graph, 'graph');
                    this.blocks.push(argument);
                }
            }
            const sourceRange = node.sourceRange();
            if (sourceRange) {
                this.metadata.push(new pytorch.Argument('source', sourceRange.toString().replace(/^at\s/, '').replace(/\.$/, ''), 'attribute'));
                if (sourceRange.source()) {
                    const orig = sourceRange.source().findSourceRangeThatGenerated(sourceRange);
                    if (orig) {
                        this.metadata.push(new pytorch.Argument('generated', orig.toString(), 'attribute'));
                    }
                }
            }
        } else if (torch && obj instanceof torch.fx.node.Node) {
            if (obj.op === 'call_function') {
                let name = null;
                const target = obj.target;
                if (target instanceof torch._ops.OpOverload) {
                    name = target.name();
                } else if (target instanceof torch._ops.HigherOrderOperator) {
                    name = `${target.namespace}::${target.name}`;
                } else if (builtins.isinstance(target, builtins.function)) {
                    name = target.__name__;
                } else if (typeof target === 'string') {
                    // Handle unresolved operators given as 'torch.ops.<namespace>.<op>'.
                    const match = target.match(/^torch\.ops\.([^.]+)\.(.+)$/);
                    if (!match) {
                        throw new pytorch.Error(`Unsupported target '${target}'.`);
                    }
                    const [, namespace, opname] = match;
                    name = `${namespace}::${opname}`;
                } else {
                    throw new pytorch.Error(`Unsupported target '${target}'.`);
                }
                this.type = {
                    identifier: name,
                    name: name.indexOf('::') === -1 ? name : name.split('::').pop().split('.')[0]
                };
                const schema = obj.target._schema;
                if (schema && schema.category) {
                    this.type.category = schema.category;
                }
                // Pair positional args with schema argument names; kwargs follow.
                let args = obj.args.map((arg, index) => {
                    if (!schema) {
                        return ['', arg];
                    }
                    if (Array.isArray(schema.arguments) && index < schema.arguments.length) {
                        return [schema.arguments[index].name, arg];
                    }
                    if (schema.is_vararg) {
                        return ['', arg];
                    }
                    throw new pytorch.Error('Unsupported schema argument.');
                });
                const inputs = new Map((schema ? schema.arguments : []).map((arg) => [arg.name, arg]));
                args = args.concat(Array.from(obj.kwargs));
                for (const [name, arg] of args) {
                    let type = inputs.has(name) ? pytorch.Utility.toType(inputs.get(name).real_type) : null;
                    if (arg instanceof torch.fx.node.Node) {
                        let argument = null;
                        if (arg.op === 'get_attr' && arg.users.size === 1) {
                            const subgraph = context.function(arg);
                            if (subgraph) {
                                argument = new pytorch.Argument(name, subgraph, 'function');
                            }
                        }
                        if (!argument) {
                            const value = context.value(arg);
                            argument = new pytorch.Argument(name, [value]);
                        }
                        this.inputs.push(argument);
                    } else if (Array.isArray(arg) && arg.every((arg) => arg instanceof torch.fx.node.Node || arg === null)) {
                        const list = arg.map((arg) => arg === null ? null : context.value(arg));
                        const argument = new pytorch.Argument(name, list);
                        this.inputs.push(argument);
                    } else if (Array.isArray(arg)) {
                        const list = arg.map((arg) => arg instanceof torch.fx.node.Node ? context.value(arg) : arg);
                        const argument = new pytorch.Argument(name, list, type || 'attribute');
                        this.inputs.push(argument);
                    } else if (arg instanceof torch.dtype || arg instanceof torch.device || arg instanceof torch.layout || arg instanceof torch.memory_format) {
                        const argument = new pytorch.Argument(name, arg.toString(), type || 'attribute');
                        this.inputs.push(argument);
                    } else {
                        const primitive = typeof arg === 'number' || typeof arg === 'boolean' || typeof arg === 'string' || arg === null;
                        type = type === 'tensor' && primitive ? null : type;
                        const argument = new pytorch.Argument(name, arg, type || 'attribute');
                        this.inputs.push(argument);
                    }
                }
                // When every user is operator.getitem, each getitem node is an output slot.
                let outputs = [obj];
                if (obj.users.size > 1) {
                    const users = Array.from(obj.users.keys());
                    if (users.every((user) => user.op === 'call_function' && user.target.__module__ === 'operator' && user.target.__name__ === 'getitem')) {
                        outputs = new Array(obj.users.size);
                        for (const user of users) {
                            const [, index] = user.args;
                            outputs[index] = user;
                        }
                    }
                }
                for (let i = 0; i < outputs.length; i++) {
                    const node = outputs[i];
                    const value = context.value(node);
                    const name = schema && schema.returns && schema.returns[i] ? schema.returns[i].name || 'output' : 'output';
                    const argument = new pytorch.Argument(name, [value]);
                    this.outputs.push(argument);
                }
                for (const [name, value] of obj.meta) {
                    if (name === 'val' || name === 'torch_fn' ||
                        (Array.isArray(value) && value.length === 0) ||
                        (value instanceof Map && value.size === 0)) {
                        continue;
                    }
                    if (typeof value === 'string') {
                        const argument = new pytorch.Argument(name, value, 'string');
                        this.metadata.push(argument);
                    } else if (Array.isArray(value) && value.every((item) => typeof item === 'string')) {
                        const argument = new pytorch.Argument(name, value, 'string[]');
                        this.metadata.push(argument);
                    }
                    // Non-empty Maps and other metadata values are not displayed.
                }
            } else if (obj.op === 'placeholder') {
                this.type = { name: obj.op };
                {
                    const value = context.value(obj);
                    const argument = new pytorch.Argument('value', [value]);
                    this.inputs.push(argument);
                }
                {
                    const node = new torch.fx.node.Node(null, obj.name);
                    node.meta = obj.meta;
                    const value = context.value(node);
                    const argument = new pytorch.Argument('value', [value]);
                    this.outputs.push(argument);
                }
            } else if (obj.op === 'get_attr') {
                this.type = { name: obj.op };
                const subgraph = context.function(obj);
                if (subgraph) {
                    this.inputs.push(new pytorch.Argument('name', subgraph, 'function'));
                } else {
                    this.inputs.push(new pytorch.Argument('name', obj.target, 'string'));
                }
                const value = context.value(obj);
                this.outputs.push(new pytorch.Argument('value', [value]));
            } else if (obj.op === 'root') {
                this.type = { name: obj.op };
            } else {
                throw new pytorch.Error(`Unsupported node operation '${obj.op}'.`);
            }
        } else {
            if (torch && obj instanceof torch.ScriptObject) {
                type = obj._type().qualified_name();
                obj = obj._ivalue;
            } else if (torch && obj instanceof torch.jit._script.RecursiveScriptModule && obj._c && obj._c.qualified_name) {
                // Flatten a script module into its submodules, parameters, buffers and
                // plain (non-parameter, non-buffer, non-module) attributes.
                type = obj._c._type();
                const target = {
                    _modules: obj._modules,
                    _parameters: obj._parameters,
                    _buffers: obj._buffers,
                };
                for (let i = 0; i < type.numAttributes(); i++) {
                    if (!type.is_parameter(i) && !type.is_buffer(i) && !type.getAttribute(i).is_module()) {
                        const k = type.getAttributeName(i);
                        target[k] = obj.__getattr__(k);
                    }
                }
                type = obj._c.qualified_name;
                obj = target;
            }
            if (!type) {
                if (torch && obj instanceof torch.jit._script.RecursiveScriptModule && obj._c && obj._c.qualified_name) {
                    type = obj._c.qualified_name;
                } else if (pytorch.Utility.isInstance(obj, 'builtins.function')) {
                    type = `${obj.__module__}.${obj.__name__}`;
                    obj = {};
                } else if (obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__) {
                    type = `${obj.__class__.__module__}.${obj.__class__.__name__}`;
                } else {
                    type = 'builtins.object';
                }
            }
            if (type instanceof nnapi.Graph) {
                this.type = type;
            } else {
                const key = type.startsWith('__torch__.') ? type.substring(10) : type;
                const value = metadata.type(key);
                this.type = value ? { ...value } : { name: type };
                this.type.identifier = type;
            }
            stack = stack || new Set();
            const weights = pytorch.Utility.weights(obj);
            if (weights) {
                const type = this.type.name;
                this.type = new pytorch.Graph(execution, metadata, 'weights', '', weights);
                this.type.name = type;
            } else if (obj && pytorch.Utility.isInstance(obj, 'fastai.data.core.DataLoaders')) {
                // Intentionally not expanded.
            } else if (obj && pytorch.Utility.isInstance(obj, '__torch__.torch.classes._nnapi.Compilation')) {
                // Intentionally not expanded.
            } else if (obj && type === 'builtins.bytearray') {
                const argument = new pytorch.Argument('value', Array.from(obj), 'byte[]');
                this.inputs.push(argument);
            } else if (obj) {
                const inputs = new Map(Array.isArray(this.type.inputs) ? this.type.inputs.map((input) => [input.name, input]) : []);
                const list = obj instanceof Map ? Array.from(obj) : Object.entries(obj);
                for (const [name, value] of list) {
                    // Skip class markers and empty containers.
                    if (name === '__class__' || name === '__name__') {
                        continue;
                    } else if (pytorch.Utility.isInstance(value, 'collections.OrderedDict') && value instanceof Map && value.size === 0) {
                        continue;
                    } else if (pytorch.Utility.isInstance(value, 'builtins.set') && value instanceof Set && value.size === 0) {
                        continue;
                    } else if (pytorch.Utility.isInstance(value, 'builtins.list') && Array.isArray(value) && value.length === 0) {
                        continue;
                    } else if (pytorch.Utility.isInstance(value, 'torch.Size') && Array.isArray(value) && value.length === 0) {
                        continue;
                    }
                    let parameters = null;
                    if ((name === '_parameters' || name === '_buffers') && value instanceof Map) {
                        parameters = value;
                    } else if (pytorch.Utility.isTensor(value) || (Array.isArray(value) && value.every((tensor) => pytorch.Utility.isTensor(tensor)))) {
                        parameters = new Map([[name, value]]);
                    }
                    if (parameters) {
                        for (const [name, value] of parameters) {
                            const list = Array.isArray(value) ? value.map((item) => pytorch.Utility.toTensor(item)) : [pytorch.Utility.toTensor(value)];
                            // Honour `visible: false` from the type's input metadata; the former
                            // `visible || true` evaluated to true for every input.
                            const visible = inputs.has(name) ? inputs.get(name).visible !== false : true;
                            const args = list.filter((value) => value !== null && !value.__origin__).map((value) => {
                                const name = value && value.name ? value.name : '';
                                const identifier = list.length === 1 && value && value.__name__ ? value.__name__ : name;
                                let tensor = null;
                                if (initializers && initializers.has(value)) {
                                    tensor = initializers.get(value);
                                } else {
                                    value = value.__source__ ? value.__source__ : value;
                                    tensor = value ? new pytorch.Tensor(context, identifier, value) : null;
                                }
                                return new pytorch.Value(identifier, null, null, tensor);
                            });
                            const argument = new pytorch.Argument(name, args, null, visible);
                            this.inputs.push(argument);
                            if (value && value.__variable__) {
                                const argument = new pytorch.Argument(name, [context.values.map(value.__variable__)]);
                                this.outputs.push(argument);
                            }
                        }
                        continue;
                    }
                    if (pytorch.Utility.isTensor(value)) {
                        const tensor = new pytorch.Tensor(context, '', value);
                        const argument = new pytorch.Argument(name, tensor, 'tensor');
                        this.inputs.push(argument);
                    } else if (value && pytorch.Utility.isInstance(value, 'torch.dtype')) {
                        const node = new pytorch.Node(execution, metadata, null, value.toString(), {}, null, context);
                        const argument = new pytorch.Argument(name, node, 'object');
                        this.inputs.push(argument);
                    } else if (Array.isArray(value) && value.some((value) => pytorch.Utility.isTensor(value)) && value.every((value) => pytorch.Utility.isTensor(value) || value === null)) {
                        const tensors = value.map((value) => value === null ? value : new pytorch.Tensor(context, '', value));
                        const argument = new pytorch.Argument(name, tensors, 'tensor[]');
                        this.inputs.push(argument);
                    } else if (pytorch.Utility.isInstance(value, 'numpy.ndarray') || pytorch.Utility.isInstance(value, 'numpy.matrix')) {
                        const tensor = new numpy.Tensor(value);
                        const argument = new pytorch.Argument(name, tensor, 'tensor');
                        this.inputs.push(argument);
                    } else if (Array.isArray(value) && value.every((value) => typeof value === 'string')) {
                        const argument = new pytorch.Argument(name, value, 'string[]');
                        this.inputs.push(argument);
                    } else if (Array.isArray(value) && value.every((value) => typeof value === 'number')) {
                        const argument = new pytorch.Argument(name, value, 'attribute');
                        this.inputs.push(argument);
                    } else if (name === '_modules' && pytorch.Utility.isInstance(value, 'collections.OrderedDict') &&
                        value instanceof Map && Array.from(value).every(([, value]) => value === null || value.__class__)) {
                        // Track each child module (not the containing dict) on the stack so
                        // cyclic module references are not expanded again.
                        const list = Array.from(value).filter(([, value]) => !stack.has(value)).map(([name, obj]) => {
                            stack.add(obj);
                            const type = obj === null ? 'builtins.NoneType' : `${obj.__class__.__module__}.${obj.__class__.__name__}`;
                            const node = new pytorch.Node(execution, metadata, this.name ? `${this.name}.${name}` : name, type, obj, initializers, context, stack);
                            stack.delete(obj);
                            return node;
                        });
                        const argument = new pytorch.Argument(name, list, 'object[]');
                        this.inputs.push(argument);
                    } else if (value && Array.isArray(value) && value.length > 0 && value.every((obj) => Array.isArray(obj) && obj.every((item) => typeof item === 'string' || typeof item === 'number'))) {
                        const argument = new pytorch.Argument(name, value, 'attribute');
                        this.inputs.push(argument);
                    } else if (value && Array.isArray(value) && value.length > 0 && value.every((obj) => obj && (obj.__class__ || obj === Object(obj)))) {
                        const list = value.filter((value) => !stack.has(value));
                        const nodes = list.map((value) => {
                            stack.add(value);
                            const node = new pytorch.Node(execution, metadata, null, null, value, initializers, context, stack);
                            stack.delete(value);
                            return node;
                        });
                        const argument = new pytorch.Argument(name, nodes, 'object[]');
                        this.inputs.push(argument);
                    } else if (value && (value.__class__ || typeof value === 'object') && !stack.has(value)) {
                        stack.add(value);
                        const node = new pytorch.Node(execution, metadata, null, null, value, initializers, context, stack);
                        stack.delete(value);
                        const visible = name !== '_metadata' || !pytorch.Utility.isMetadataObject(value);
                        const argument = new pytorch.Argument(name, node, 'object', visible);
                        this.inputs.push(argument);
                    } else {
                        let schema = metadata.attribute(this.type.identifier, name);
                        schema = name === 'training' ? { type: 'boolean', visible: false } : schema;
                        let visible = true;
                        let obj = value;
                        const type = schema && schema.type ? schema.type : 'attribute';
                        if (schema) {
                            if (schema.visible === false) {
                                visible = false;
                            } else if (schema.default !== undefined) {
                                // Hide attributes equal to their schema default.
                                if (Array.isArray(obj)) {
                                    if (Array.isArray(schema.default)) {
                                        // Compare lengths (previously the length was compared
                                        // to the default array itself, which never matched).
                                        visible = obj.length !== schema.default.length || !obj.every((item, index) => item === schema.default[index]);
                                    } else {
                                        visible = !obj.every((item) => item === schema.default);
                                    }
                                } else {
                                    visible = obj !== schema.default;
                                }
                            }
                        }
                        if (Array.isArray(obj) && obj.length > 0 && obj.every((obj) => obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__module__.startsWith('torch.nn'))) {
                            obj = '?';
                        }
                        const argument = new pytorch.Argument(name, obj, type, visible);
                        this.inputs.push(argument);
                    }
                }
            }
        }
    }
};
pytorch.Tensor = class {

    // Wraps a torch tensor (or an object whose `.data` is the tensor) for display.
    // Dense tensors keep a reference to their storage data and, when the tensor
    // covers only part of that storage, the byte window it occupies
    // (`_offset` / `_length`). Sparse tensors wrap their `indices` and `values`
    // tensors instead.
    constructor(context, name, tensor) {
        this.name = name || '';
        this.attributes = [];
        const source = tensor.data ? tensor.data : tensor;
        const storage = source.storage();
        this.type = context.type(source);
        const layout = this.type.layout;
        const size = this.type.shape.dimensions || [];
        if (layout) {
            this.indices = new pytorch.Tensor(context, '', source.indices);
            this._values = new pytorch.Tensor(context, '', source.values);
        } else {
            this.encoding = '<';
            this.indices = null;
            this.stride = source.stride();
            const stride = this.stride;
            const offset = source.storage_offset();
            if (storage) {
                this._data = storage.data;
                // Number of storage elements spanned by this (possibly strided) view.
                let length = 0;
                if (!Array.isArray(stride)) {
                    length = storage.size();
                } else if (size.every((v) => v !== 0)) {
                    length = size.reduce((a, v, i) => a + stride[i] * (v - 1), 1);
                }
                if (typeof storage.size === 'function') {
                    if (offset !== 0 || length !== storage.size()) {
                        const itemsize = storage.dtype.itemsize();
                        this._offset = itemsize * offset;
                        this._length = itemsize * length;
                    }
                }
            }
        }
        const type = source.__class__ || {};
        if (type.tensor_attribute_names) {
            for (const name of type.tensor_attribute_names) {
                let value = source[name];
                if (value !== undefined) {
                    if (value && typeof value.__reduce__ === 'function') {
                        value = value.__reduce__();
                    }
                    const attribute = new pytorch.Argument(name, value, 'attribute');
                    this.attributes.push(attribute);
                }
            }
        }
        if (type.tensor_data_names) {
            for (const name of type.tensor_data_names) {
                const value = source[name];
                if (value !== undefined && pytorch.Utility.isTensor(value)) {
                    const attribute = new pytorch.Argument(name, new pytorch.Tensor(context, name, value), 'tensor');
                    this.attributes.push(attribute);
                }
            }
        }
    }

    // Returns the wrapped values tensor for sparse layouts. For dense tensors it
    // returns the raw bytes of the tensor's window into its storage, or null
    // when no data is attached.
    get values() {
        const layout = this.type.layout;
        if (layout && layout.startsWith('sparse.')) {
            return this._values;
        }
        if (this._data instanceof Uint8Array) {
            // The storage may be shared with other views. Expose only this
            // tensor's byte window, like the stream-backed branch below does.
            if (this._offset !== undefined) {
                return this._data.subarray(this._offset, this._offset + this._length);
            }
            return this._data;
        }
        if (this._data && this._offset !== undefined) {
            const stream = this._data;
            const position = stream.position;
            stream.seek(this._offset);
            const values = stream.peek(this._length);
            stream.seek(position);
            return values;
        }
        if (this._data) {
            return this._data.peek();
        }
        return null;
    }
};
pytorch.TensorType = class {

    // Describes a tensor by its data type name, its pytorch.TensorShape-like
    // shape, and an optional layout string (set only for sparse tensors).
    constructor(dataType, shape, layout) {
        Object.assign(this, { dataType, shape, layout });
    }

    // Renders the type as the data type name followed by the shape text, e.g. 'float32[2,3]'.
    toString() {
        const shapeText = this.shape.toString();
        return `${this.dataType}${shapeText}`;
    }
};
pytorch.TensorShape = class {

    // `dimensions` is the list of dimension sizes; it defaults to an empty list (scalar).
    constructor(dimensions = []) {
        this.dimensions = dimensions;
    }

    // Formats the dimensions as '[d0,d1,...]', or '' when there are none.
    toString() {
        const { dimensions } = this;
        const empty = !dimensions || !(dimensions.length > 0);
        return empty ? '' : `[${dimensions.map((dimension) => dimension.toString()).join(',')}]`;
    }
};
pytorch.Context = class {

    // Shared state used while building graphs. It caches pytorch.Value objects
    // and subgraphs keyed by torch.fx nodes. `torch` is null when no execution
    // is supplied; fx-specific lookups then return null.
    constructor(execution, metadata) {
        this.execution = execution;
        this.torch = execution ? execution.__import__('torch') : null;
        this.metadata = metadata;
        this.values = new Map();
        this.modules = new Map();
    }

    // Builds a pytorch.TensorType from a tensor's dtype, size/shape and layout.
    // Float8/float4 dtype names are normalized (underscores removed), and
    // 'torch.sparse_xxx' layouts become 'sparse.xxx'.
    type(tensor) {
        let dataType = tensor.dtype.__reduce__();
        switch (dataType) {
            case 'float8_e5m2': dataType = 'float8e5m2'; break;
            case 'float8_e5m2fnuz': dataType = 'float8e5m2fnuz'; break;
            case 'float8_e4m3fn': dataType = 'float8e4m3fn'; break;
            case 'float8_e4m3fnuz': dataType = 'float8e4m3fnuz'; break;
            case 'float8_e8m0fnu': dataType = 'float8e8m0fnu'; break;
            case 'float4_e2m1fn_x2': dataType = 'float4e2m1fnx2'; break;
            default: break;
        }
        const size = tensor.size ? tensor.size() : tensor.shape;
        const shape = new pytorch.TensorShape(size || []);
        const layout = tensor.layout ? tensor.layout.__str__() : null;
        if (layout && layout.startsWith('torch.sparse_')) {
            return new pytorch.TensorType(dataType, shape, layout.split('.').pop().replace('_', '.'));
        }
        return new pytorch.TensorType(dataType, shape);
    }

    // Returns the cached pytorch.Value for a torch.fx node, creating it on first
    // use. It is typed from meta['val'] when that carries a dtype. Returns null
    // for anything that is not an fx node.
    value(obj) {
        const torch = this.torch;
        if (torch && obj instanceof torch.fx.node.Node) {
            if (!this.values.has(obj)) {
                let type = null;
                const val = obj.meta ? obj.meta.get('val') : null;
                if (val && val.dtype) {
                    type = this.type(val);
                }
                const value = new pytorch.Value(obj.name, type);
                this.values.set(obj, value);
            }
            return this.values.get(obj);
        }
        return null;
    }

    // Returns the subgraph registered for a get_attr fx node, converting it
    // lazily to a pytorch.Graph. Returns null when there is none.
    function(obj) {
        const torch = this.torch;
        if (torch && obj instanceof torch.fx.node.Node) {
            let subgraph = this.modules.get(obj);
            if (subgraph) {
                if (subgraph instanceof pytorch.Graph === false) {
                    subgraph = new pytorch.Graph(this.execution, this.metadata, 'function', obj.target, subgraph);
                    this.modules.set(obj, subgraph);
                }
                return subgraph;
            }
        }
        return null;
    }

    // Populates `target` (inputs/outputs/nodes) from an fx GraphModule.
    // Placeholders become inputs only when `inputs` is truthy. operator.getitem
    // calls and single-use subgraph attributes are folded away. A call_function
    // node with no users becomes a control dependency of the next emitted node.
    graph(target, module, inputs) {
        const graph = module.graph;
        if (module.named_modules) {
            const modules = module.named_modules();
            for (const obj of graph.nodes) {
                if (obj.op === 'get_attr') {
                    const submodule = modules.get(obj.target);
                    if (submodule && submodule.graph) {
                        this.modules.set(obj, submodule);
                    }
                }
            }
        }
        let controlDependency = null;
        for (const obj of graph.nodes) {
            if (obj.op === 'placeholder') {
                if (inputs) {
                    const value = this.value(obj);
                    const argument = new pytorch.Argument(obj.name, [value]);
                    target.inputs.push(argument);
                }
                continue;
            }
            if (obj.op === 'call_function') {
                if (obj.target.__module__ === 'operator' && obj.target.__name__ === 'getitem') {
                    continue;
                }
            }
            if (obj.op === 'get_attr') {
                if (this.modules.has(obj) && obj.users.size === 1) {
                    continue;
                }
            }
            if (obj.op === 'output') {
                for (const output of obj.args) {
                    if (output === null || output === undefined) {
                        continue;
                    }
                    if (output.op === 'call_function' && output.target.__module__ === 'operator' && output.target.__name__ === 'getitem') {
                        continue;
                    }
                    const value = this.value(output);
                    const argument = new pytorch.Argument(output.name, [value]);
                    target.outputs.push(argument);
                }
                continue;
            }
            const node = new pytorch.Node(this.execution, this.metadata, obj.name, null, obj, null, this);
            target.nodes.push(node);
            if (controlDependency) {
                node.controlDependencies = node.controlDependencies || [];
                node.controlDependencies.push(controlDependency);
                controlDependency = null;
            }
            // Only nodes that actually produced an output can act as a dependency.
            if (obj.op === 'call_function' && obj.users.size === 0 && node.outputs.length > 0) {
                controlDependency = node.outputs[0].value[0];
            }
        }
    }
};
pytorch.Reader = class {

    // Tries every known container reader in priority order. Returns the first
    // reader whose static open() accepts the context, or null if none does.
    static async open(context) {
        const candidates = [
            pytorch.Reader.Zip,
            pytorch.Reader.Pickle,
            pytorch.Reader.Tar,
            pytorch.Reader.data_pkl,
            pytorch.Reader.torch_utils,
            pytorch.Reader.Mobile,
            pytorch.Reader.ModelJson,
            pytorch.Reader.IR,
            pytorch.Reader.Index,
            pytorch.Reader.ExportedProgram
        ];
        let index = 0;
        while (index < candidates.length) {
            // Probes must run one at a time so earlier formats take precedence.
            // eslint-disable-next-line no-await-in-loop
            const reader = await candidates[index].open(context);
            if (reader) {
                return reader;
            }
            index++;
        }
        return null;
    }

    constructor() {
        // [event, callback] pairs forwarded to the execution created by read().
        this._events = [];
    }

    // Subclasses override this to load the model; the base reader does nothing.
    async read() {
    }

    // Records an event subscription to be attached when read() runs.
    on(event, callback) {
        const subscription = [event, callback];
        this._events.push(subscription);
    }
};
pytorch.Reader.Tar = class extends pytorch.Reader {

    // Accepts a tar archive that contains a 'pickle' member.
    static async open(context) {
        const entries = await context.peek('tar');
        const accepted = entries instanceof Map && entries.has('pickle');
        return accepted ? new pytorch.Reader.Tar(entries) : null;
    }

    constructor(entries) {
        super();
        this.type = 'pytorch.tar';
        this.entries = entries;
    }

    // Loads the archive through torch.load and releases the raw entries.
    async read() {
        this.format = 'PyTorch v0.1.1';
        const execution = new python.Execution();
        for (const [event, callback] of this._events) {
            execution.on(event, callback);
        }
        const torch = execution.__import__('torch');
        this.module = torch.load(this.entries);
        delete this.entries;
    }
};
pytorch.Reader.Pickle = class extends pytorch.Reader {

    // Accepts a legacy (non-zip) torch.save stream. The header is a pickle PROTO
    // opcode (0x80, any protocol), then a LONG1 opcode (0x8a) encoding the
    // 10-byte little-endian value 0x1950a86a20f9469cfc6c.
    static async open(context) {
        const stream = context.stream;
        const signature = [0x80, undefined, 0x8a, 0x0a, 0x6c, 0xfc, 0x9c, 0x46, 0xf9, 0x20, 0x6a, 0xa8, 0x50, 0x19];
        if (stream && signature.length <= stream.length) {
            const header = stream.peek(signature.length);
            const matches = header.every((byte, position) => {
                const expected = signature[position];
                return expected === undefined || expected === byte;
            });
            if (matches) {
                return new pytorch.Reader.Pickle(stream);
            }
        }
        return null;
    }

    constructor(stream) {
        super();
        this.type = 'pytorch.pickle';
        this.stream = stream;
    }

    // Buffers the stream (unless it is too large for one array) and unpickles it.
    async read() {
        this.format = 'PyTorch v0.1.10';
        const stream = this.stream;
        delete this.stream;
        const data = stream.length < 0x7ffff000 ? stream.peek() : stream;
        const execution = new python.Execution();
        for (const [event, callback] of this._events) {
            execution.on(event, callback);
        }
        const torch = execution.__import__('torch');
        this.module = torch.load(data);
    }
};
pytorch.Reader.data_pkl = class extends pytorch.Reader {

    // Accepts an already-unpickled object when it is one of the following:
    // - a TorchScript ('__torch__.*') object
    // - a tensor, or a non-empty list of tensors
    // - a Map or plain object holding at least one tensor
    // - a module with an OrderedDict `_modules`, at the root or under 'model' / 'net'
    static async open(context) {
        const obj = await context.peek('pkl');
        if (!obj) {
            return null;
        }
        const isTensor = (value) => pytorch.Utility.isTensor(value);
        const cls = obj.__class__;
        if (cls && cls.__module__ && cls.__name__ && `${cls.__module__}.${cls.__name__}`.startsWith('__torch__.')) {
            return new pytorch.Reader.data_pkl('', obj);
        }
        if (isTensor(obj) || (Array.isArray(obj) && obj.length > 0 && obj.every(isTensor))) {
            return new pytorch.Reader.data_pkl('tensor', obj);
        }
        if (obj instanceof Map) {
            if (Array.from(obj.values()).some(isTensor)) {
                return new pytorch.Reader.data_pkl('tensor', obj);
            }
        } else if (!Array.isArray(obj) && Object.values(obj).some(isTensor)) {
            return new pytorch.Reader.data_pkl('tensor', obj);
        }
        const isModule = (candidate) => candidate && candidate._modules && pytorch.Utility.isInstance(candidate._modules, 'collections.OrderedDict');
        const module = [obj, obj.model, obj.net].find(isModule);
        return module ? new pytorch.Reader.data_pkl('module', module) : null;
    }

    // `type` ('', 'tensor' or 'module') records why open() matched. It is not
    // stored: every instance reports the type 'pytorch.data.pkl'.
    constructor(type, module) {
        super();
        this.type = 'pytorch.data.pkl';
        this.format = 'PyTorch Pickle';
        this.module = module;
    }

    // The object was already unpickled during open(); nothing remains to read.
    async read() {
    }
};
pytorch.Reader.torch_utils = class extends pytorch.Reader {

    // Accepts a pickle that mentions 'torch_utils' in its first 1 KiB and whose
    // unpickled root has at least one torch.nn Module value.
    static async open(context) {
        const stream = context.stream;
        if (!(stream && stream.length > 1)) {
            return null;
        }
        const header = stream.peek(Math.min(1024, stream.length));
        if (header[0] !== 0x80) {
            return null;
        }
        const text = String.fromCharCode.apply(null, header);
        if (text.indexOf('torch_utils') === -1) {
            return null;
        }
        const obj = await context.peek('pkl');
        const hasModule = obj && Object.values(obj).some((value) => pytorch.Utility.isInstance(value, 'torch.nn.modules.module.Module'));
        return hasModule ? new pytorch.Reader.torch_utils(obj) : null;
    }

    constructor(obj) {
        super();
        this.type = 'pytorch.torch_utils';
        this.obj = obj;
    }

    // Exposes the unpickled root object as the module.
    async read() {
        this.format = 'PyTorch torch_utils';
        this.module = this.obj;
        delete this.obj;
    }
};
pytorch.Reader.Mobile = class extends pytorch.Reader {

    // Accepts a flatbuffer whose file identifier is 'PTMF'.
    static async open(context) {
        const reader = await context.peek('flatbuffers.binary');
        return reader && reader.identifier === 'PTMF' ? new pytorch.Reader.Mobile(context) : null;
    }

    constructor(context) {
        super();
        this.type = 'pytorch.mobile';
        this.context = context;
    }

    // Loads the mobile flatbuffer schema, then parses the stream into a JIT module.
    async read(metadata) {
        const execution = new pytorch.Execution(null, metadata);
        this._events.forEach(([event, callback]) => execution.on(event, callback));
        const stream = this.context.stream;
        const torch = execution.__import__('torch');
        const schema = await this.context.require('./pytorch-schema');
        torch.mobile = schema.torch.jit.mobile;
        const module = torch.jit.jit_module_from_flatbuffer(stream);
        this.module = module;
        this.format = pytorch.Utility.format('PyTorch Mobile', module._c._bytecode_version.toString());
        delete this.context;
    }
};
pytorch.Reader.Zip = class extends pytorch.Reader {

    // Strips any directory prefix shared by all entries, then inspects the
    // remaining names:
    // - 'model.json' is left to pytorch.Reader.ModelJson (returns null)
    // - 'data.pkl' selects this reader
    // - '.data/version' without 'archive_format' selects pytorch.Reader.Package
    static async open(context) {
        const entries = await context.peek('zip');
        if (!(entries instanceof Map) || entries.size === 0) {
            return null;
        }
        // Paths are split and reversed so pop() yields the leading directory.
        const paths = Array.from(entries.keys(), (path) => path.replace(/\\/g, '/').split('/').reverse());
        let prefix = 0;
        while (paths.length > 0) {
            const heads = new Set(paths.map((parts) => parts.length > 1 ? parts.pop() : null));
            const [head] = heads;
            if (heads.size > 1 || head === null) {
                break;
            }
            prefix += head.length + 1;
        }
        const records = new Map(Array.from(entries, ([name, value]) => [name.substring(prefix), value]));
        if (records.has('model.json')) {
            return null;
        }
        if (records.has('data.pkl')) {
            return new pytorch.Reader.Zip(entries);
        }
        if (records.has('.data/version') && !records.has('archive_format')) {
            return new pytorch.Reader.Package(entries);
        }
        return null;
    }

    constructor(entries) {
        super();
        this.type = 'pytorch.zip';
        // https://github.com/pytorch/pytorch/blob/main/torch/csrc/jit/OVERVIEW.md
        // https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/docs/serialization.md
        this._entries = entries;
    }

    // Loads the archive as TorchScript when it has 'constants.pkl' (inlining
    // forward() when present); otherwise it is loaded as a plain torch.save archive.
    async read(metadata) {
        const execution = new pytorch.Execution(null, metadata);
        this.execution = execution;
        for (const [event, callback] of this._events) {
            execution.on(event, callback);
        }
        const torch = execution.__import__('torch');
        const reader = new torch.PyTorchFileReader(this._entries);
        let torchscript = reader.has_record('constants.pkl');
        const version = reader.version();
        if (torchscript) {
            metadata.register(execution);
            this.module = torch.jit.load(reader);
            torchscript = this.module._c._has_method('forward');
            if (torchscript) {
                torch._C._jit_pass_inline(this.module.graph);
            }
        } else {
            const entries = new Map(reader.get_all_records().map((key) => [key, reader.get_record(key)]));
            this.module = torch.load(entries);
        }
        this.format = pytorch.Utility.format(torchscript ? 'TorchScript' : 'PyTorch', version);
        delete this._entries;
    }
};
pytorch.Reader.ModelJson = class extends pytorch.Reader {

    // Accepts a TorchScript v1.0/v1.1 'model.json' descriptor that has a mainModule.
    static async open(context) {
        const identifier = context.identifier;
        if (identifier === 'model.json') {
            const model = await context.peek('json');
            if (model && model.mainModule) {
                const entries = new Map();
                entries.set('model.json', context.stream);
                return new pytorch.Reader.ModelJson(context, entries, model);
            }
        }
        return null;
    }

    constructor(context, entries, model) {
        super();
        this.type = 'pytorch.model.json';
        this._context = context;
        this._entries = entries;
        this._model = model;
    }

    // Fetches the sibling files referenced by model.json: attributes, version,
    // tensor data and TorchScript arenas. Files that cannot be fetched are skipped.
    // The files are then loaded through torch.jit.load.
    async read(metadata) {
        pytorch.proto = await this._context.require('./pytorch-proto');
        // The JSON mapping of protobuf omits empty repeated fields, so a model
        // without parameters may have no 'tensors' array at all.
        const tensors = Array.isArray(this._model.tensors) ? this._model.tensors : [];
        const keys = [
            'attributes.pkl',
            'version',
            ...tensors.filter((tensor) => tensor && tensor.data && tensor.data.key).map((tensor) => tensor.data.key)
        ];
        const walk = (module) => {
            if (module.torchscriptArena && module.torchscriptArena.key) {
                keys.push(module.torchscriptArena.key);
            }
            for (const submodule of module.submodules || []) {
                walk(submodule);
            }
        };
        walk(this._model.mainModule);
        const values = await Promise.all(keys.map((name) => this._context.fetch(name).then((context) => context.stream).catch(() => null)));
        for (let i = 0; i < keys.length; i++) {
            if (values[i]) {
                this._entries.set(keys[i], values[i]);
            }
        }
        this.execution = new pytorch.Execution(null, metadata);
        this.execution.proto = pytorch.proto;
        for (const [event, callback] of this._events) {
            this.execution.on(event, callback);
        }
        const torch = this.execution.__import__('torch');
        const reader = new torch.PyTorchFileReader(this._entries);
        if (this._model && this._model.producerName) {
            this.producer = this._model.producerName + (this._model.producerVersion ? ` v${this._model.producerVersion}` : '');
        }
        this.format = reader.has_record('attributes.pkl') ? 'TorchScript v1.1' : 'TorchScript v1.0';
        metadata.register(this.execution);
        this.module = torch.jit.load(reader);
        if (this.module._c._has_method('forward')) {
            torch._C._jit_pass_inline(this.module.graph);
        }
        delete this._context;
        delete this._model;
        delete this._entries;
    }
};
pytorch.Reader.IR = class extends pytorch.Reader {

    // Accepts text whose first line starts with 'graph(' (a printed TorchScript graph).
    static async open(context) {
        const reader = await context.read('text', 0x100);
        if (reader && reader.length > 0) {
            const header = reader.read('\n');
            return header.startsWith('graph(') ? new pytorch.Reader.IR(context) : null;
        }
        return null;
    }

    constructor(context) {
        super();
        this.type = 'pytorch.ir';
        this.context = context;
    }

    // Sets up the execution, then always fails: parsing the IR text is not implemented yet.
    async read(metadata) {
        this.format = 'TorchScript IR';
        const execution = new pytorch.Execution(null, metadata);
        this.execution = execution;
        this._events.forEach(([event, callback]) => execution.on(event, callback));
        throw new pytorch.Error('TorchScript IR parser not implemented.');
    }
};
pytorch.Reader.Index = class extends pytorch.Reader {

    // Accepts a sharded-checkpoint index JSON whose non-empty 'weight_map' maps
    // every parameter name to a '.bin' shard file name.
    static async open(context) {
        const obj = await context.peek('json');
        if (!obj || !obj.weight_map) {
            return null;
        }
        const entries = Object.entries(obj.weight_map);
        const binaryShards = entries.every(([, file]) => typeof file === 'string' && file.endsWith('.bin'));
        return entries.length > 0 && binaryShards ? new pytorch.Reader.Index(context, entries) : null;
    }

    constructor(context, entries) {
        super();
        this.type = 'pytorch.index';
        this.context = context;
        this._entries = entries;
    }

    // Loads every shard referenced by the index. Only the parameters that the
    // index lists are merged into one Map, which becomes the module.
    async read(metadata) {
        this.format = 'PyTorch';
        const weightMap = new Map(this._entries);
        const wanted = new Set(weightMap.keys());
        const files = [...new Set(weightMap.values())];
        const contexts = await Promise.all(files.map((file) => this.context.fetch(file)));
        this.execution = new pytorch.Execution(null, metadata);
        for (const [event, callback] of this._events) {
            this.execution.on(event, callback);
        }
        const torch = this.execution.__import__('torch');
        const archives = await Promise.all(contexts.map((shardContext) => shardContext.peek('zip')));
        const formats = new Set();
        for (const archive of archives) {
            const reader = new torch.PyTorchFileReader(archive);
            formats.add(pytorch.Utility.format('PyTorch', reader.version()));
        }
        if (formats.size === 1) {
            [this.format] = formats;
        }
        const shards = archives.map((archive) => torch.load(archive));
        const entries = new Map();
        for (const shard of shards) {
            Array.from(shard)
                .filter(([key]) => wanted.has(key))
                .forEach(([key, value]) => entries.set(key, value));
        }
        this.module = entries;
        delete this.context;
        delete this._entries;
    }
};
pytorch.Reader.ExportedProgram = class extends pytorch.Reader {

    // Accepts either a torch.export program JSON (with schema_version and
    // graph_module), or the 'archive_format' marker file of a .pt2 archive
    // whose content is 'pt2'.
    static async open(context) {
        const program = await context.peek('json');
        if (program && program.schema_version && program.graph_module) {
            return new pytorch.Reader.ExportedProgram(context, program);
        }
        if (context.identifier === 'archive_format' && context.stream && context.stream.length < 10) {
            const buffer = context.stream.peek();
            const archive_format = String.fromCharCode.apply(null, buffer);
            if (archive_format === 'pt2') {
                return new pytorch.Reader.ExportedProgram(context, null, context);
            }
        }
        return null;
    }

    constructor(context, exported_program, archive_format) {
        super();
        this.type = 'pytorch.export';
        this.context = context;
        this.archive_format = archive_format;
        this.exported_program = exported_program;
    }

    // Gathers all files of the export into `f` (archive path -> content),
    // derives the format version, and decodes inline constants. The result is
    // handed to torch.export.pt2_archive._package.load_pt2.
    async read(metadata) {
        this.format = 'PyTorch Export';
        const f = new Map();
        const exported_programs = new Map();
        if (this.archive_format) {
            await this._readArchive(f, exported_programs);
        } else {
            await this._readLegacy(f, exported_programs);
        }
        if (!this.version) {
            const defined = (value) => value !== undefined && value !== null;
            const versions = new Set();
            for (const exported_program of exported_programs.values()) {
                const schema_version = exported_program.schema_version;
                // A minor version of 0 is valid, so test for presence, not truthiness.
                if (schema_version && defined(schema_version.major) && defined(schema_version.minor)) {
                    versions.add(`${schema_version.major}.${schema_version.minor}`);
                }
            }
            if (versions.size === 1) {
                this.version = versions.values().next().value;
            }
        }
        this.format = this.version ? `${this.format} v${this.version}` : this.format;
        this.execution = new python.Execution();
        for (const event of this._events) {
            this.execution.on(event[0], event[1]);
        }
        metadata.register(this.execution);
        const torch = this.execution.__import__('torch');
        for (const exported_program of exported_programs.values()) {
            if (exported_program.graph_module.graph.constants) {
                // Inline constants are base64-encoded zip archives; replace each with its entries.
                // eslint-disable-next-line no-await-in-loop
                const zip = await import('./zip.js');
                const constants = exported_program.graph_module.graph.constants;
                for (const key of Object.keys(constants)) {
                    const value = constants[key];
                    const str = atob(value);
                    const buffer = new Uint8Array(str.length);
                    for (let i = 0; i < str.length; i++) {
                        buffer[i] = str.charCodeAt(i);
                    }
                    const archive = zip.Archive.open(buffer);
                    constants[key] = archive.entries;
                }
            }
        }
        delete this.exported_program;
        delete this.context;
        const pt2_contents = torch.export.pt2_archive._package.load_pt2(f);
        this.modules = pt2_contents.exported_programs;
    }

    // .pt2 archive: loads every models/<name>.json program together with its
    // sample inputs, weights and constants, plus the archive-wide byteorder.
    async _readArchive(f, exported_programs) {
        /* eslint-disable no-await-in-loop */
        for (const name of this.context.container.entries.keys()) {
            const match = name.match(/^models\/([^/]+)\.json$/);
            if (!match) {
                continue;
            }
            const [, model_name] = match;
            const model = await this.context.fetch(`models/${model_name}.json`);
            const exported_program = await model.read('json');
            exported_programs.set(model_name, exported_program);
            f.set(`models/${model_name}.json`, exported_program);
            const sample_inputs = await this._fetch(`data/sample_inputs/${model_name}.pt`, 'zip');
            f.set(`data/sample_inputs/${model_name}.pt`, sample_inputs);
            await this._readPayloads(f, 'weights', model_name);
            await this._readPayloads(f, 'constants', model_name);
        }
        /* eslint-enable no-await-in-loop */
        const byteorder = await this._fetch('byteorder', 'text') || 'little';
        f.set('byteorder', byteorder);
    }

    // Loads data/<kind>/ files for one model. With a
    // <model>_<kind>_config.json, each listed payload is read (pickled payloads
    // as zip, others as raw bytes). Without one, the single <model>.pt zip is read.
    async _readPayloads(f, kind, model_name) {
        const config_name = `data/${kind}/${model_name}_${kind}_config.json`;
        const config = await this._fetch(config_name, 'json');
        if (config) {
            f.set(config_name, config);
            for (const payload_meta of Object.values(config.config || {})) {
                const type = payload_meta.use_pickle ? 'zip' : 'binary';
                const name = `data/${kind}/${payload_meta.path_name}`;
                // eslint-disable-next-line no-await-in-loop
                const data = await this._fetch(name, type);
                if (data) {
                    f.set(name, data);
                }
            }
        } else {
            const name = `data/${kind}/${model_name}.pt`;
            const data = await this._fetch(name, 'zip');
            f.set(name, data);
        }
    }

    // Older single-program exports: maps the serialized_* sibling files onto the
    // paths load_pt2 expects, and reads the format version from 'version'.
    async _readLegacy(f, exported_programs) {
        const version = await this._fetch('version', 'text') || '';
        this.version = version.split('\n').shift().trim();
        const weights = await this._fetch('serialized_state_dict.pt', 'zip') || await this._fetch('serialized_state_dict.json', 'zip');
        const constants = await this._fetch('serialized_constants.pt', 'zip') || await this._fetch('serialized_constants.json', 'zip');
        const sample_inputs = await this._fetch('serialized_example_inputs.pt', 'zip');
        f.set('models/model.json', this.exported_program);
        f.set('data/weights/model.pt', weights);
        f.set('data/constants/model.pt', constants);
        f.set('data/sample_inputs/model.pt', sample_inputs);
        exported_programs.set('', this.exported_program);
    }

    // Fetches an optional sibling file and decodes it as 'zip' entries, 'json',
    // 'text' or raw 'binary' bytes. Returns null when the file is missing or
    // unreadable. An unsupported `type` is a programming error and throws
    // before the fetch, instead of being silently swallowed.
    async _fetch(name, type) {
        if (!['zip', 'json', 'text', 'binary'].includes(type)) {
            throw new pytorch.Error(`Unsupported context type '${type}'.`);
        }
        try {
            const context = await this.context.fetch(name);
            if (context) {
                switch (type) {
                    case 'zip':
                        return await context.peek('zip');
                    case 'json':
                        return await context.read('json');
                    case 'text': {
                        const reader = await context.read('text');
                        if (reader) {
                            return reader.read();
                        }
                        break;
                    }
                    case 'binary': {
                        if (context.stream) {
                            return context.stream.peek();
                        }
                        break;
                    }
                    default: {
                        break;
                    }
                }
            }
        } catch {
            // Optional files may be absent or unreadable; treat them as missing.
        }
        return null;
    }
};
pytorch.Execution = class extends python.Execution {

    // Python execution environment extended for serialized PyTorch models. It
    // adds a flatbuffer mobile-module loader and the TorchScript custom classes
    // (NNAPI, quantized packed params, XNNPACK op contexts, TensorRT engine),
    // each registered with its method schemas and attribute types.
    constructor(sources, metadata) {
        super(sources);
        this._metadata = metadata;
        // eslint-disable-next-line consistent-this
        const execution = this;
        const torch = this.torch;
        this.registerFunction('torch.jit.jit_module_from_flatbuffer', (f) => {
            const cu = new torch.jit.CompilationUnit();
            cu.execution = execution;
            const stream = f;
            const reader = flatbuffers.BinaryReader.open(stream);
            const module = torch.mobile.serialization.Module.create(reader);
            const loader = new torch._C.FlatBuffersLoader(cu);
            const cpp_module = loader.parseModule(module);
            // TODO: port parse_and_initialize_jit_module (jit sources/constants) from PyTorch.
            return torch.jit._script.wrap_cpp_module(cpp_module);
        });
        this.registerType('__torch__.torch.classes._nnapi.Compilation', class {
            constructor() {
                this.__hide__ = true;
            }
            __init__() {
            }
            init(serialized_model_tensor, parameter_buffers) {
                this.serialized_model_tensor = serialized_model_tensor;
                this.parameter_buffers = parameter_buffers;
                const buffers = parameter_buffers.map((buffer) => buffer.__source__.storage());
                const serialized_model = serialized_model_tensor.storage().data;
                this.serialized_model = new nnapi.SerializedModel(serialized_model, buffers);
            }
            run(inputs, outputs) {
                execution.variable(this.serialized_model_tensor);
                this.serialized_model_tensor.__count__ = (this.serialized_model_tensor.__count__ || 0) + 1;
                const type = new nnapi.Graph(this.serialized_model);
                const node = execution.graph.create(type, 0);
                execution.graph.insertNode(node);
                for (const tensor of inputs) {
                    const value = execution.variable(tensor);
                    node.addInput(value);
                }
                for (const tensor of outputs) {
                    execution.variable(tensor, node);
                }
            }
        });
        // Decodes version "2" of the quantized convolution packed-params state,
        // [version, [config, weight, ...], [bias]]. Here config is
        // [spatial_dim, stride x N, padding x N, dilation x N, output_padding x N, groups, transpose],
        // where N is the number of spatial dimensions (2 for Conv2d, 3 for Conv3d).
        const setConvState = (target, state, dims) => {
            if (state[0] !== '2') {
                throw new pytorch.Error(`Unsupported pack version '${state[0]}'.`);
            }
            const [/* pack_version */, tensors, opt_tensors] = state;
            const packed_config = tensors[0].tolist();
            const group = (index) => Array.from({ length: dims }, (_, i) => packed_config[1 + (index * dims) + i]);
            target.weight = tensors[1];
            target.bias = opt_tensors[0];
            target.stride = group(0);
            target.padding = group(1);
            target.dilation = group(2);
            target.output_padding = group(3);
            target.groups = packed_config[1 + (4 * dims)];
        };
        this.registerType('__torch__.torch.classes.quantized.Conv2dPackedParamsBase', class {
            __setstate__(state) {
                setConvState(this, state, 2);
            }
        });
        this.registerType('__torch__.torch.classes.quantized.Conv3dPackedParamsBase', class {
            __setstate__(state) {
                setConvState(this, state, 3);
            }
        });
        this.registerType('__torch__.torch.classes.quantized.LinearPackedParamsBase', class {
            __setstate__(state) {
                [this.weight, this.bias] = state;
            }
        });
        this.registerType('__torch__.torch.classes.quantized.EmbeddingPackedParamsBase', class {
            __setstate__(state) {
                [this.version, this.tensors, this.doubles, this.longs] = state;
            }
        });
        this.registerType('__torch__.torch.classes.rnn.CellParamsBase', class {
            __setstate__(state) {
                [this.type, this.tensors, this.doubles, this.longs, this.packed_params] = state;
            }
        });
        this.registerType('__torch__.torch.classes.xnnpack.Conv2dOpContext', class {
            __setstate__(state) {
                [this.weight, this.bias, this.stride, this.padding, this.dilation, this.groups, this.output_min, this.output_max] = state;
            }
        });
        this.registerType('__torch__.torch.classes.xnnpack.LinearOpContext', class {
            __setstate__(state) {
                [this.weight, this.bias, this.output_min, this.output_max] = state;
            }
        });
        this.registerType('__torch__.torch.classes.xnnpack.TransposeConv2dOpContext', class {
            __setstate__(state) {
                [this.weight, this.bias, this.stride, this.padding, this.output_padding, this.dilation, this.groups, this.output_min, this.output_max] = state;
            }
        });
        this.registerType('__torch__.torch.classes.tensorrt.Engine', class {
            __setstate__(state) {
                [this.abi_target, this.name, this.device, this.engine, this.input_binding_names, this.output_binding_names, this.hw_compatible, this.serialized_metadata, this.target_platform] = state;
            }
        });
        const custom_classes = [
            { name: '__torch__.torch.classes._nnapi.Compilation', methods: [
                '__init__(__torch__.torch.classes._nnapi.Compilation self) -> NoneType',
                'init(__torch__.torch.classes._nnapi.Compilation self, Tensor serialized_model_tensor, Tensor[] parameter_buffers) -> NoneType',
                'init2(__torch__.torch.classes._nnapi.Compilation self, Tensor serialized_model_tensor, Tensor[] parameter_buffers, int compilation_preference, bool relax_f32_to_f16) -> NoneType',
                'run(__torch__.torch.classes._nnapi.Compilation self, Tensor[] inputs, Tensor[] outputs) -> NoneType'
            ] },
            { name: '__torch__.torch.classes.quantized.Conv2dPackedParamsBase', attributes: 'Tensor weight, Tensor bias, int[] stride, int[] padding, int[] dilation, int groups', methods: ['unpack(__torch__.torch.classes.quantized.Conv2dPackedParamsBase self) -> ((Tensor, Tensor?))'] },
            { name: '__torch__.torch.classes.quantized.Conv3dPackedParamsBase', attributes: 'Tensor weight, Tensor bias, int[] stride, int[] padding, int[] dilation, int groups', methods: ['unpack(__torch__.torch.classes.quantized.Conv3dPackedParamsBase self) -> ((Tensor, Tensor?))'] },
            { name: '__torch__.torch.classes.quantized.LinearPackedParamsBase', attributes: 'Tensor weight, Tensor? bias' },
            { name: '__torch__.torch.classes.quantized.EmbeddingPackedParamsBase', attributes: 'int version, Tensor[] tensors, float[] doubles, int[] longs', methods: [] },
            { name: '__torch__.torch.classes.rnn.CellParamsBase', attributes: 'str type, Tensor[] tensors, float[] doubles, int[] longs, __torch__.torch.classes.quantized.LinearPackedParamsBase[] packed_params' },
            { name: '__torch__.torch.classes.xnnpack.Conv2dOpContext', attributes: 'Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, int[] output_min, int[] output_max' },
            { name: '__torch__.torch.classes.xnnpack.LinearOpContext', attributes: 'Tensor weight, Tensor bias, int[] output_min, int[] output_max' },
            { name: '__torch__.torch.classes.xnnpack.TransposeConv2dOpContext', attributes: 'Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int[] dilation, int groups, int[] output_min, int[] output_max' },
            { name: '__torch__.torch.classes.tensorrt.Engine' }
        ];
        for (const known_type of custom_classes) {
            const prefix = new torch._C.QualifiedName(known_type.name);
            const type = torch.ClassType.create(known_type.name, this._compilation_unit, false);
            for (const known_method of known_type.methods || []) {
                const schema = new torch.FunctionSchema(known_method);
                const name = new torch._C.QualifiedName(prefix, schema.name);
                const fn = new torch._C.BuiltinOpFunction(name, schema);
                type.addMethod(fn);
            }
            if (known_type.attributes) {
                // Attribute lists are parsed by wrapping them in a dummy function schema.
                const schema = new torch.FunctionSchema(`(${known_type.attributes}) -> ()`);
                for (const arg of schema.arguments) {
                    type.addAttribute(arg.name, arg.real_type);
                }
            }
            torch._C.registerCustomClass(type);
        }
    }

    // Dispatches `torch.<name>(...)` to aten ops and `ops.<module>.<name>(...)`
    // to the matching torch.ops namespace. Everything else goes to the Python evaluator.
    call(target, name, args, keywords, context) {
        const ast = this.ast;
        const torch = this.torch;
        if (target instanceof ast.Name && target.id === 'torch') {
            const fn = torch.ops.aten.__getattr__(name);
            if (fn) {
                const evalArgs = args.map((arg) => this.expression(arg, context));
                return fn.__call__(...evalArgs);
            }
        }
        if (target instanceof ast.Attribute && target.value instanceof ast.Name && target.value.id === 'ops') {
            const module = torch.ops[target.attr];
            if (!module) {
                throw new pytorch.Error(`Unknown torch.ops module '${target.attr}'.`);
            }
            const fn = module.__getattr__(name);
            if (fn) {
                const evalArgs = args.map((arg) => this.expression(arg, context));
                return fn.__call__(...evalArgs);
            }
        }
        return super.call(target, name, args, keywords, context);
    }

    // Calling an Enum subclass creates an instance holding `args` as its value.
    invoke(target, args) {
        if (target && Array.isArray(target.__bases__) && target.__bases__.length > 0 && target.__bases__[0] === this.enum.Enum) {
            const instance = new target();
            instance.value = args;
            return instance;
        }
        return super.invoke(target, args);
    }

    // Resolves a class base expression, mapping the bare name 'Enum' to enum.Enum.
    base(expr, context) {
        const ast = this.ast;
        if (expr instanceof ast.Name) {
            switch (expr.id) {
                case 'Enum': return this.enum.Enum;
                default: break;
            }
        }
        return this.expression(expr, context);
    }
};
// Reader for torch.package archives. It registers the archive's Python sources with an
// execution and then unpickles every pickle record through a PackageImporter.
pytorch.Reader.Package = class extends pytorch.Reader {
    constructor(entries) {
        super();
        this.type = 'pytorch.package';
        this.entries = entries;
    }
    // Populates `format` and `modules` (archive path -> unpickled object). `metadata` supplies
    // the operator schemas registered with the execution. Releases `entries` when done.
    async read(metadata) {
        this.execution = new pytorch.Execution(null, metadata);
        for (const event of this._events) {
            this.execution.on(event[0], event[1]);
        }
        const torch = this.execution.__import__('torch');
        const reader = new torch.PyTorchFileReader(this.entries);
        const version = reader.version();
        this.format = pytorch.Utility.format('PyTorch Package', version);
        this.modules = new Map();
        // Pickle records are those outside '.data/' that are not '.py' sources and start with
        // byte 0x80 followed by a value below 7 (pickle PROTO opcode + protocol number).
        const records = reader.get_all_records().filter((name) => {
            if (!name.startsWith('.data/') && !name.endsWith('.py')) {
                const stream = reader.get_record(name);
                if (stream && stream.length > 2) {
                    const signature = stream.peek(2);
                    if (signature[0] === 0x80 && signature[1] < 7) {
                        return true;
                    }
                }
            }
            return false;
        });
        // Split 'a/b/c.pkl' into the package name 'a.b' and the resource name 'c.pkl'.
        const entries = records.map((name) => {
            const parts = name.split('/');
            const resource = parts.pop();
            const module = parts.join('.');
            return [module, resource];
        });
        if (entries.length > 0) {
            // Register the archive's Python sources with the execution before unpickling.
            for (const name of reader.get_all_records()) {
                if (!name.startsWith('.data/') && name.endsWith('.py')) {
                    const stream = reader.get_record(name);
                    const buffer = stream.peek();
                    this.execution.add(name, buffer);
                }
            }
            metadata.register(this.execution);
            const importer = new torch.package.PackageImporter(reader);
            for (const [module, resource] of entries) {
                const value = importer.load_pickle(module, resource);
                // Every '.' in the package name stands for a '/' of the record path, so all of
                // them (not only the first) are converted back.
                const key = `${module.replace(/\./g, '/')}/${resource}`;
                this.modules.set(key, value);
            }
        }
        delete this.entries;
    }
};
// Integer code of each memory format, assigned by position:
// Contiguous = 0, Preserve = 1, ChannelsLast = 2, ChannelsLast3d = 3.
pytorch.MemoryFormat = Object.fromEntries(
    ['Contiguous', 'Preserve', 'ChannelsLast', 'ChannelsLast3d'].map((label, code) => [label, code])
);
// Integer code of each tensor layout, assigned by position: Strided = 0, Sparse = 1, Mkldnn = 2.
pytorch.Layout = Object.fromEntries(
    ['Strided', 'Sparse', 'Mkldnn'].map((label, code) => [label, code])
);
// Static helpers shared by the PyTorch readers: tensor detection, type/value formatting,
// file-format version names and heuristics that recover per-module weights from loaded objects.
pytorch.Utility = class {
    // True when `obj` is an instance of a 'torch' / 'torch.cuda' class whose name ends in
    // 'Tensor', or of 'torch.nn.parameter.Parameter'.
    static isTensor(obj) {
        const name = obj && obj.__class__ ? obj.__class__.__module__ : null;
        switch (name) {
            case 'torch':
            case 'torch.cuda':
                return obj.__class__.__name__.endsWith('Tensor');
            case 'torch.nn.parameter':
                return obj.__class__.__name__ === 'Parameter';
            default:
                return false;
        }
    }
    // Returns the tensor behind `obj`: the object itself for tensor classes, or the wrapped
    // `data` of a Parameter (copying a string `__name__` onto it). Returns null otherwise.
    static toTensor(obj) {
        const name = obj && obj.__class__ ? obj.__class__.__module__ : null;
        switch (name) {
            case 'torch':
            case 'torch.cuda':
                return obj.__class__.__name__.endsWith('Tensor') ? obj : null;
            case 'torch.nn.parameter':
                if (obj.__class__.__name__ === 'Parameter') {
                    const data = obj.data;
                    if (typeof obj.__name__ === 'string') {
                        data.__name__ = obj.__name__;
                    }
                    return data;
                }
                return null;
            default:
                return null;
        }
    }
    // Maps a TorchScript type (dispatching on `type.kind()`) to its display string.
    // Throws pytorch.Error for kinds not listed here.
    static toType(type) {
        switch (type.kind()) {
            case 'OptionalType': return `${pytorch.Utility.toType(type.getElementType())}?`;
            case 'ListType': return `${pytorch.Utility.toType(type.getElementType())}[]`;
            case 'BoolType': return 'boolean';
            case 'IntType': return 'int64';
            case 'FloatType': return 'float32';
            case 'StringType': return 'string';
            case 'ComplexType': return 'complex';
            case 'NumberType': return 'scalar';
            case 'TensorType': return 'tensor';
            case 'TupleType': return `tuple<${type.elements().map((type) => pytorch.Utility.toType(type)).join(', ')}>`;
            case 'DictType': return `map<${pytorch.Utility.toType(type.getKeyType())}, ${pytorch.Utility.toType(type.getValueType())}>`;
            case 'DeviceObjType': return 'device';
            case 'SymIntType': return 'SymInt';
            case 'ScalarTypeType': return 'ScalarType';
            case 'MemoryFormat': return 'MemoryFormat';
            case 'Layout': return 'Layout';
            case 'VarType': return type.annotation_str;
            case 'NoneType': return 'None';
            case 'AnyType': return 'object';
            case 'AnyListType': return 'list';
            case 'AnyTupleType': return 'tuple';
            case 'ClassType': return type.annotation_str;
            case 'EnumType': return type.annotation_str;
            default: throw new pytorch.Error(`Unsupported type '${type.kind()}'.`);
        }
    }
    // Converts an IValue holding an int, double, enum or (nested) list of those into a plain
    // value. Despite the name, ints and doubles come back as numbers and lists as arrays.
    // Throws pytorch.Error for any other IValue tag.
    static toString(ivalue) {
        if (ivalue.isInt()) {
            return ivalue.toInt();
        }
        if (ivalue.isDouble()) {
            return ivalue.toDouble();
        }
        if (ivalue.isEnum()) {
            return ivalue.toEnumHolder().name();
        }
        if (ivalue.isList()) {
            return ivalue.toList().map((item) => pytorch.Utility.toString(item));
        }
        throw new pytorch.Error(`Unsupported IValue '${ivalue.tag}'.`);
    }
    // Reads attribute `name` of graph node `node` using the accessor for its attribute kind.
    // Throws pytorch.Error for kinds other than s, i, f, ss and ival.
    static constant(node, name) {
        const kind = node.kindOf(name);
        switch (kind) {
            case 's': return node.s(name);
            case 'i': return node.i(name);
            case 'f': return node.f(name);
            case 'ss': return node.ss(name);
            case 'ival': return node.ival(name);
            default: throw new pytorch.Error(`Unsupported attribute kind '${kind}'.`);
        }
    }
    // Display name of a graph value: '%' followed by its debug name, or by its unique id if unnamed.
    static unique(value) {
        return value.hasDebugName() ? `%${value.debugName().toString()}` : `%${value.unique().toString()}`;
    }
    // True when the qualified class name of `obj` is one of the packed-parameter /
    // op-context custom classes listed below.
    static isObject(obj) {
        const type = obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__ ? `${obj.__class__.__module__}.${obj.__class__.__name__}` : null;
        switch (type) {
            case '__torch__.torch.classes.xnnpack.LinearOpContext':
            case '__torch__.torch.classes.xnnpack.Conv2dOpContext':
            case '__torch__.torch.classes.xnnpack.TransposeConv2dOpContext':
            case '__torch__.torch.classes.rnn.CellParamsBase':
            case '__torch__.torch.classes.rnn.CellParamsBase[]':
            case '__torch__.torch.classes.quantized.LinearPackedParamsBase':
            case '__torch__.torch.classes.quantized.Conv2dPackedParamsBase':
            case '__torch__.torch.classes.quantized.Conv3dPackedParamsBase':
            case '__torch__.torch.classes.quantized.EmbeddingPackedParamsBase':
                return true;
            default:
                return false;
        }
    }
    // True when class `value` is named `${__module__}.${__name__}` === `name`. Base classes are
    // searched only when `value` itself lacks a module/name pair.
    static isSubclass(value, name) {
        if (value && value.__module__ && value.__name__) {
            return name === `${value.__module__}.${value.__name__}`;
        } else if (value && value.__bases__) {
            return value.__bases__.some((obj) => pytorch.Utility.isSubclass(obj, name));
        }
        return false;
    }
    // True when the class of `value` satisfies isSubclass(…, name).
    static isInstance(value, name) {
        return value && value.__class__ ? pytorch.Utility.isSubclass(value.__class__, name) : false;
    }
    // Returns `${name} ${release}` for a serialized file-format version number.
    // Throws pytorch.Error for versions not in the table.
    static format(name, value) {
        // https://github.com/pytorch/pytorch/blob/master/caffe2/serialize/inline_container.h
        // kProducedFileFormatVersion
        const versions = new Map([
            ['1', 'v1.3'],
            ['2', 'v1.5'], // 7a2889b014ce36fcc333b2c6de6f29f976652f84 (#28122)
            ['3', 'v1.6'], // 2ec6a30722b0ef85632a2f3e7ce6f80da403008a (#36085)
            ['4', 'v1.6'], // 95489b590f00801bdee7f41783f30874883cf6bb (#38620)
            ['5', 'v1.7'], // cb26661fe4faf26386703180a9045e6ac6d157df (#40364)
            ['6', 'v1.9'], // 3ee7637ffa50df0d9b231c7b40778ac1c390bf4a (#59714)
            ['7', 'v1.10'], // 880098a7e34a20628f960daa8eab0eb1ad566c39 (#63651)
            ['8', 'v1.11'], // b28e696516a7f0c7a6ead6da967590ce6c1d6698 (#71486)
            ['9', 'v1.11'], // 8757e21c6a4fc00e83539aa7f9c28eb11eff53c1 (#72051)
            ['10', 'v1.12'] // 4f8b986e28736b59bc46cd0873a0f36fdaa6f5b8 (#61439)
        ]);
        value = value.toString();
        if (!versions.has(value)) {
            throw new pytorch.Error(`Unsupported '${name}' version '${value}'.`);
        }
        return `${name} ${versions.get(value)}`;
    }
    // Tries to interpret `obj` (a dict / OrderedDict / plain object, a generic Module, or a
    // RecursiveScriptModule) as weights. Returns a Map from module path to an object of named
    // tensors and scalars, or null when `obj` does not look like weights.
    static weights(obj) {
        let type = obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__ ? `${obj.__class__.__module__}.${obj.__class__.__name__}` : null;
        if (type === 'torch.jit._script.RecursiveScriptModule') {
            // Flatten the scripted module's attributes into a plain object.
            type = obj._c._type();
            const target = {};
            for (let i = 0; i < type.numAttributes(); i++) {
                const k = type.getAttributeName(i);
                target[k] = obj.__getattr__(k);
            }
            type = obj._c.qualified_name;
            obj = target;
        } else if (type && type !== 'builtins.dict' && type !== 'builtins.object' && type !== 'collections.OrderedDict' && type !== 'torch.nn.modules.module.Module' && type !== '__torch__.Module') {
            return null;
        }
        if (pytorch.Utility.isTensor(obj)) {
            return null;
        }
        // A plain object whose keys are mostly dotted/piped tensor names is treated like a state dict.
        if (obj instanceof Map === false && obj && !Array.isArray(obj) && Object(obj) === obj) {
            const entries = Object.entries(obj);
            const named = entries.filter(([name, value]) => (typeof name === 'string' && (name.indexOf('.') !== -1 || name.indexOf('|') !== -1)) && pytorch.Utility.isTensor(value));
            if (named.length > 0 && (named.length / entries.length) >= 0.8) {
                obj = new Map(entries);
            }
        }
        // Flat state dict: at least 80% of the keys are 'module.path.property' (or '|'-separated).
        if (obj instanceof Map) {
            const entries = Array.from(obj).filter(([name]) => name !== '_metadata');
            const names = entries.filter(([name]) => typeof name === 'string' && (name.indexOf('.') !== -1 || name.indexOf('|') !== -1));
            if (names.length > 1 && (names.length / entries.length) >= 0.8 &&
                (entries.every(([, value]) => !pytorch.Utility.isInstance(value, 'builtins.dict') || Array.from(value.values()).every((value) => !pytorch.Utility.isTensor(value)))) &&
                (!entries.every(([, value]) => Array.isArray(value)))) {
                const modules = new Map();
                for (const [name, value] of entries) {
                    const separator = name.indexOf('.') === -1 && name.indexOf('|') !== -1 ? '|' : '.';
                    const path = name.split(separator);
                    let property = path.pop();
                    // Keep '_packed_params.<x>' together as one property of the owning module.
                    if (path.length > 1 && path[path.length - 1] === '_packed_params') {
                        property = `${path.pop()}.${property}`;
                    }
                    const key = path.join(separator);
                    if (!modules.has(key)) {
                        modules.set(key, {});
                    }
                    const module = modules.get(key);
                    if (pytorch.Utility.isTensor(value)) {
                        value.__name__ = name;
                    }
                    module[property] = value;
                }
                return modules;
            }
        }
        // Nested form: each top-level entry is a module mapping plain names to tensors/scalars.
        if (obj && !Array.isArray(obj) && Object(obj) === obj) {
            const modules = new Map();
            const entries = obj instanceof Map ? Array.from(obj) : Object.entries(obj);
            if (entries.length > 0) {
                for (const [key, value] of entries) {
                    const name = key.toString();
                    if (!value || Object(value) !== value || pytorch.Utility.isTensor(value) || ArrayBuffer.isView(value) || value._modules instanceof Map) {
                        return null;
                    }
                    if (!modules.has(name)) {
                        modules.set(name, {});
                    }
                    const module = modules.get(name);
                    let tensor = false;
                    const entries = value instanceof Map ? value : new Map(Object.entries(value));
                    for (const [name, value] of entries) {
                        if (typeof name !== 'string') {
                            return null;
                        }
                        if (name.indexOf('.') !== -1) {
                            return null;
                        }
                        if (name === '_metadata') {
                            continue;
                        }
                        if (typeof value === 'string' || typeof value === 'number') {
                            module[name] = value;
                            continue;
                        }
                        if (pytorch.Utility.isTensor(value)) {
                            value.__name__ = name;
                            module[name] = value;
                            tensor = true;
                        }
                    }
                    // A module without any tensor means this is not a weights container.
                    if (!tensor) {
                        return null;
                    }
                }
                return modules;
            }
        }
        return null;
    }
    // True for an OrderedDict whose dict-typed values are all single-entry dicts keyed
    // 'version' (presumably the `_metadata` attached to a state dict; confirm at call sites).
    static isMetadataObject(obj) {
        if (pytorch.Utility.isInstance(obj, 'collections.OrderedDict')) {
            for (const value of obj.values()) {
                if (pytorch.Utility.isInstance(value, 'builtins.dict')) {
                    // Array.from over the dict yields [key, value] pairs. Only the key is checked
                    // because the version number may differ between module types.
                    const entries = Array.from(value);
                    if (entries.length !== 1 || !Array.isArray(entries[0]) || entries[0][0] !== 'version') {
                        return false;
                    }
                }
            }
            return true;
        }
        return false;
    }
};
// Decoder for a serialized NNAPI model blob. After a version word and table counts, it reads
// the operand, value, operation, dimension, operation-operand and model input/output tables in
// that order. `buffers` holds the storages referenced by 'numbered buffer' values.
// Throws pytorch.Error on an unknown version, unsupported value sources/types or trailing bytes.
nnapi.SerializedModel = class {
    constructor(serialized_model, buffers) {
        const reader = base.BinaryReader.open(serialized_model);
        this.version = reader.int32();
        if (this.version !== 1) {
            throw new pytorch.Error('Invalid NNAPI serialized model version.');
        }
        // Header: the number of entries in each table that follows.
        const operands = new Array(reader.int32());
        const values = new Array(reader.int32());
        this.operations = new Array(reader.int32());
        this.inputs = new Array(reader.int32());
        this.outputs = new Array(reader.int32());
        // NNAPI OperandCode -> type name; a '[]' suffix marks tensor types.
        const data_types = new Map([
            [0, 'float32'],
            [1, 'int32'],
            [2, 'uint32'],
            [3, 'float32[]'],
            [4, 'int32[]'],
            [5, 'quant8_asymm[]'],
            [6, 'boolean'],
            [7, 'quant16_symm[]'],
            [8, 'float16[]'],
            [9, 'boolean[]'],
            [10, 'float16'],
            [11, 'quant8_symm_per_channel[]'],
            [12, 'quant16_asymm[]'],
            [13, 'quant8_symm[]'],
            [14, 'quant8_asymm_signed[]'],
            [15, 'model'] // ANEURALNETWORKS_MODEL
        ]);
        // Operand table. Only the rank is known here; dimension values are read further below.
        for (let i = 0; i < operands.length; i++) {
            const data_type = reader.int32();
            operands[i] = {
                index: i,
                data_type: data_types.has(data_type) ? data_types.get(data_type) : data_type,
                dimensions: new Array(reader.uint32()),
                scale: reader.float32(),
                zero_point: reader.int32()
            };
        }
        // Value table: which operand receives a constant, and where the constant comes from.
        for (let i = 0; i < values.length; i++) {
            values[i] = {
                index: reader.int32(),
                source_type: reader.int32(),
                source_length: reader.uint32()
            };
        }
        // Operation table: operation code and operand counts; operand indices are read further below.
        for (let i = 0; i < this.operations.length; i++) {
            this.operations[i] = {
                index: reader.int32(),
                identifier: i,
                inputs: new Array(reader.uint32()),
                outputs: new Array(reader.uint32())
            };
        }
        for (const operand of operands) {
            for (let i = 0; i < operand.dimensions.length; i++) {
                operand.dimensions[i] = reader.uint32();
            }
        }
        for (const value of values) {
            const index = value.index;
            const operand = operands[index];
            switch (value.source_type) {
                case 0: { // immediate: the value bytes follow inline
                    switch (operand.data_type) {
                        case 'boolean':
                            operand.value = reader.byte() ? true : false;
                            reader.skip(3); // the flag byte is followed by 3 bytes (4 bytes total)
                            break;
                        case 'int32':
                            operand.value = reader.int32();
                            break;
                        case 'float32':
                            operand.value = reader.float32();
                            break;
                        case 'int32[]':
                            operand.data = reader.read(value.source_length);
                            break;
                        case 'float32[]':
                            operand.data = reader.read(value.source_length);
                            break;
                        default:
                            throw new pytorch.Error(`Unsupported NNAPI operand type '${operand.data_type}'.`);
                    }
                    break;
                }
                case 2: { // numbered buffer: (buffer number, byte offset, byte length) into `buffers`
                    if (value.source_length !== 12) {
                        throw new pytorch.Error('Invalid NNAPI numbered buffer source length.');
                    }
                    const number = reader.uint32();
                    const offset = reader.uint32();
                    const operand_length = reader.uint32();
                    if (number < buffers.length && buffers[number].data) {
                        const storage = buffers[number];
                        const data = storage.data.peek ? storage.data.peek() : storage.data;
                        // slice() takes an end index, not a length.
                        operand.data = data.slice(offset, offset + operand_length);
                    }
                    break;
                }
                case 3: { // numbered memory
                    throw new pytorch.Error('NNAPI numbered memory buffer not implemented.');
                }
                default: {
                    throw new pytorch.Error('Unsupported NNAPI value source type.');
                }
            }
        }
        for (const operation of this.operations) {
            for (let i = 0; i < operation.inputs.length; i++) {
                const index = reader.uint32();
                operation.inputs[i] = operands[index];
            }
            for (let i = 0; i < operation.outputs.length; i++) {
                const index = reader.uint32();
                operation.outputs[i] = operands[index];
            }
        }
        for (let i = 0; i < this.inputs.length; i++) {
            const index = reader.uint32();
            this.inputs[i] = operands[index];
        }
        for (let i = 0; i < this.outputs.length; i++) {
            const index = reader.uint32();
            this.outputs[i] = operands[index];
        }
        // The whole blob must have been consumed.
        if (reader.position !== reader.length) {
            throw new pytorch.Error('Invalid NNAPI serialized model length.');
        }
    }
};
// Graph view over an nnapi.SerializedModel: one nnapi.Node per operation, plus the model's
// inputs and outputs. Each operand becomes one pytorch.Value, created once per operand index.
nnapi.Graph = class {
    constructor(model) {
        this.name = 'torch.classes._nnapi.Compilation';
        this.nodes = [];
        this.inputs = [];
        this.outputs = [];
        // Storage element type of each NNAPI quantized operand type. QUANT8_ASYMM and
        // QUANT16_ASYMM are unsigned; QUANT8_SYMM, QUANT8_SYMM_PER_CHANNEL,
        // QUANT8_ASYMM_SIGNED and QUANT16_SYMM are signed.
        const quantizedTypes = new Map([
            ['quant8_asymm', 'uint8'],
            ['quant8_asymm_signed', 'int8'],
            ['quant8_symm', 'int8'],
            ['quant8_symm_per_channel', 'int8'],
            ['quant16_asymm', 'uint16'],
            ['quant16_symm', 'int16']
        ]);
        const values = new Map();
        // Memoized operand -> pytorch.Value lookup, exposed as `values.map` so nodes share it.
        values.map = (operand) => {
            if (!values.has(operand.index)) {
                const name = operand.index.toString();
                const dimensions = operand.dimensions;
                const shape = new pytorch.TensorShape(dimensions);
                // Tensor types carry a '[]' suffix, which is dropped for display. Unknown
                // OperandCodes are raw numbers, so stringify before replacing.
                let dataType = `${operand.data_type}`.replace('[]', '');
                let quantization = null;
                if (quantizedTypes.has(dataType)) {
                    quantization = dataType;
                    dataType = quantizedTypes.get(dataType);
                }
                const type = new pytorch.TensorType(dataType, shape);
                let initializer = null;
                if (operand.data) {
                    // Tensor-like object (dtype/size/stride/storage_offset/storage) wrapping the
                    // operand's constant bytes, handed to pytorch.Tensor.
                    const size = dimensions.reduce((a, b) => a * b, 1);
                    const tensor = {
                        dtype: { __reduce__: () => dataType },
                        size: () => dimensions,
                        stride: () => null,
                        storage_offset: () => 0,
                        storage: () => ({
                            dtype: { __reduce__: () => type.dataType },
                            data: operand.data, size: () => size
                        })
                    };
                    const context = new pytorch.Context();
                    initializer = new pytorch.Tensor(context, null, tensor);
                }
                // Attach quantization for quantized types or any non-zero scale / zero point.
                if (quantization || (operand.scale !== undefined && operand.scale !== 0) || (operand.zero_point !== undefined && operand.zero_point !== 0)) {
                    quantization = {
                        type: quantization || 'linear',
                        scale: [operand.scale],
                        offset: [operand.zero_point]
                    };
                }
                const value = new pytorch.Value(name, type, quantization, initializer);
                values.set(operand.index, value);
            }
            return values.get(operand.index);
        };
        const metadata = new nnapi.Metadata();
        for (const operation of model.operations) {
            const node = new nnapi.Node(metadata, operation, values);
            this.nodes.push(node);
        }
        for (let i = 0; i < model.inputs.length; i++) {
            const name = i.toString();
            const operand = model.inputs[i];
            const argument = new pytorch.Argument(name, [values.map(operand)]);
            this.inputs.push(argument);
        }
        for (let i = 0; i < model.outputs.length; i++) {
            const name = i.toString();
            const operand = model.outputs[i];
            const argument = new pytorch.Argument(name, [values.map(operand)]);
            this.outputs.push(argument);
        }
    }
};
// Graph node for one NNAPI operation. Operands with dimensions become tensor inputs. A
// scalar named 'activation' becomes a chained activation node. Other scalars become
// non-tensor arguments carrying their immediate value.
nnapi.Node = class {
    constructor(metadata, operation, values) {
        // Operand data types are used to pick among overloads registered for the same code.
        const signature = (operation.inputs || []).map((operand) => operand.data_type);
        this.name = '';
        this.type = metadata.type(operation.index, signature);
        this.inputs = [];
        this.outputs = [];
        this.attributes = [];
        this.chain = [];
        if (operation.identifier !== undefined) {
            this.identifier = operation.identifier.toString();
        }
        // Schema name at `position`, falling back to the position itself.
        const label = (schema, position) => (position < schema.length ? schema[position].name : position.toString());
        if (Array.isArray(operation.inputs)) {
            const schema = this.type.inputs;
            // Fused activation codes 1, 2, 3 -> operation codes of RELU, RELU1, RELU6.
            const fusedActivations = new Map([[1, 19], [2, 20], [3, 21]]);
            for (const [position, operand] of operation.inputs.entries()) {
                const name = label(schema, position);
                if (operand.dimensions.length > 0) {
                    this.inputs.push(new pytorch.Argument(name, [values.map(operand)]));
                } else if (name === 'activation') {
                    const code = fusedActivations.get(operand.value) || 0;
                    if (code !== 0) {
                        this.chain.push(new nnapi.Node(metadata, { index: code }));
                    }
                } else {
                    this.inputs.push(new pytorch.Argument(name, operand.value, operand.data_type, false));
                }
            }
        }
        if (Array.isArray(operation.outputs)) {
            const schema = this.type.outputs;
            for (const [position, operand] of operation.outputs.entries()) {
                const name = label(schema, position);
                this.outputs.push(new pytorch.Argument(name, [values.map(operand)]));
            }
        }
    }
};
// Operation schemas for NNAPI operation codes. A code may have several overloads (for example,
// explicit vs. implicit padding), kept as a list per code in registration order.
nnapi.Metadata = class {
    constructor() {
        this._types = new Map();
        // https://developer.android.com/ndk/reference/group/neural-networks
        // https://github.com/pytorch/pytorch/commits/master/torch/backends/_nnapi/serializer.py
        this.register(0, 'ADD', '', ['A', 'B'], [['activation', 'int32']], ['C']);
        this.register(1, 'AVERAGE_POOL_2D', 'Pool', ['input'], [['padding_left', 'int32'], ['padding_right', 'int32'], ['padding_top', 'int32'], ['padding_bottom', 'int32'], ['stride_x', 'int32'], ['stride_y', 'int32'], ['filter_x', 'int32'], ['filter_y', 'int32'], ['activation', 'int32'], ['nchw', 'boolean']], ['output']);
        this.register(1, 'AVERAGE_POOL_2D', 'Pool', ['input'], [['padding_scheme', 'int32'], ['stride_x', 'int32'], ['stride_y', 'int32'], ['filter_x', 'int32'], ['filter_y', 'int32'], ['activation', 'int32'], ['nchw', 'boolean']], ['output']);
        this.register(2, 'CONCATENATION');
        this.register(3, 'CONV_2D', 'Layer', ['input', 'weights', 'bias'], [['padding_left', 'int32'], ['padding_right', 'int32'], ['padding_top', 'int32'], ['padding_bottom', 'int32'], ['stride_x', 'int32'], ['stride_y', 'int32'], ['activation', 'int32'], ['nchw', 'boolean'], ['dilation_width', 'int32'], ['dilation_height', 'int32']], ['output']);
        this.register(3, 'CONV_2D', 'Layer', ['input', 'weights', 'bias'], [['padding_scheme', 'int32'], ['stride_x', 'int32'], ['stride_y', 'int32'], ['activation', 'int32'], ['nchw', 'boolean'], ['dilation_width', 'int32'], ['dilation_height', 'int32']], ['output']);
        this.register(4, 'DEPTHWISE_CONV_2D', 'Layer', ['input', 'weights', 'bias'], [['padding_left', 'int32'], ['padding_right', 'int32'], ['padding_top', 'int32'], ['padding_bottom', 'int32'], ['stride_x', 'int32'], ['stride_y', 'int32'], ['activation', 'int32'], ['nchw', 'boolean'], ['dilation_width', 'int32'], ['dilation_height', 'int32']], ['output']);
        this.register(4, 'DEPTHWISE_CONV_2D', 'Layer', ['input', 'weights', 'bias'], [['padding_scheme', 'int32'], ['stride_x', 'int32'], ['stride_y', 'int32'], ['activation', 'int32'], ['nchw', 'boolean'], ['dilation_width', 'int32'], ['dilation_height', 'int32']], ['output']);
        this.register(5, 'DEPTH_TO_SPACE');
        this.register(6, 'DEQUANTIZE');
        this.register(7, 'EMBEDDING_LOOKUP');
        this.register(8, 'FLOOR');
        this.register(9, 'FULLY_CONNECTED', 'Layer', ['input', 'weights', 'bias'], [['activation', 'int32']], ['output']);
        this.register(10, 'HASHTABLE_LOOKUP');
        this.register(11, 'L2_NORMALIZATION');
        this.register(12, 'L2_POOL_2D', 'Pool');
        this.register(13, 'LOCAL_RESPONSE_NORMALIZATION');
        this.register(14, 'LOGISTIC');
        this.register(15, 'LSH_PROJECTION');
        this.register(16, 'LSTM', 'Layer');
        this.register(17, 'MAX_POOL_2D', 'Pool');
        this.register(18, 'MUL');
        this.register(19, 'RELU', 'Activation', ['input'], [], ['output']);
        this.register(20, 'RELU1', 'Activation');
        this.register(21, 'RELU6', 'Activation');
        this.register(22, 'RESHAPE', 'Shape', ['input', 'shape'], [], ['output']);
        this.register(23, 'RESIZE_BILINEAR');
        this.register(24, 'RNN', 'Layer');
        this.register(25, 'SOFTMAX', 'Activation');
        this.register(26, 'SPACE_TO_DEPTH');
        this.register(27, 'SVDF');
        this.register(28, 'TANH');
        this.register(29, 'BATCH_TO_SPACE_ND');
        this.register(30, 'DIV');
        this.register(31, 'MEAN');
        this.register(32, 'PAD');
        this.register(33, 'SPACE_TO_BATCH_ND');
        this.register(34, 'SQUEEZE');
        this.register(35, 'STRIDED_SLICE');
        this.register(36, 'SUB');
        this.register(37, 'TRANSPOSE');
        this.register(38, 'ABS');
        this.register(39, 'ARGMAX');
        this.register(40, 'ARGMIN');
        this.register(41, 'AXIS_ALIGNED_BBOX_TRANSFORM');
        this.register(42, 'BIDIRECTIONAL_SEQUENCE_LSTM');
        this.register(43, 'BIDIRECTIONAL_SEQUENCE_RNN');
        this.register(44, 'BOX_WITH_NMS_LIMIT');
        this.register(45, 'CAST');
        this.register(46, 'CHANNEL_SHUFFLE');
        this.register(47, 'DETECTION_POSTPROCESSING');
        this.register(48, 'EQUAL');
        this.register(49, 'EXP');
        this.register(50, 'EXPAND_DIMS');
        this.register(51, 'GATHER');
        this.register(52, 'GENERATE_PROPOSALS');
        this.register(53, 'GREATER');
        this.register(54, 'GREATER_EQUAL');
        this.register(55, 'GROUPED_CONV_2D');
        this.register(56, 'HEATMAP_MAX_KEYPOINT');
        this.register(57, 'INSTANCE_NORMALIZATION');
        this.register(58, 'LESS');
        this.register(59, 'LESS_EQUAL');
        this.register(60, 'LOG');
        this.register(61, 'LOGICAL_AND');
        this.register(62, 'LOGICAL_NOT');
        this.register(63, 'LOGICAL_OR');
        this.register(64, 'LOG_SOFTMAX');
        this.register(65, 'MAXIMUM');
        this.register(66, 'MINIMUM');
        this.register(67, 'NEG');
        this.register(68, 'NOT_EQUAL');
        this.register(69, 'PAD_V2');
        this.register(70, 'POW');
        this.register(71, 'PRELU');
        this.register(72, 'QUANTIZE');
        this.register(73, 'QUANTIZED_16BIT_LSTM');
        this.register(74, 'RANDOM_MULTINOMIAL');
        this.register(75, 'REDUCE_ALL');
        this.register(76, 'REDUCE_ANY');
        this.register(77, 'REDUCE_MAX');
        this.register(78, 'REDUCE_MIN');
        this.register(79, 'REDUCE_PROD');
        this.register(80, 'REDUCE_SUM');
        this.register(81, 'ROI_ALIGN');
        this.register(82, 'ROI_POOLING');
        this.register(83, 'RSQRT');
        this.register(84, 'SELECT');
        this.register(85, 'SIN');
        this.register(86, 'SLICE');
        this.register(87, 'SPLIT');
        this.register(88, 'SQRT');
        this.register(89, 'TILE');
        this.register(90, 'TOPK_V2');
        this.register(91, 'TRANSPOSE_CONV_2D', 'Layer');
        this.register(92, 'UNIDIRECTIONAL_SEQUENCE_LSTM', 'Layer');
        this.register(93, 'UNIDIRECTIONAL_SEQUENCE_RNN', 'Layer');
        this.register(94, 'RESIZE_NEAREST_NEIGHBOR');
        this.register(95, 'QUANTIZED_LSTM', 'Layer');
        this.register(96, 'IF');
        this.register(97, 'WHILE');
        this.register(98, 'ELU', 'Activation');
        this.register(99, 'HARD_SWISH', 'Activation');
        this.register(100, 'FILL');
        this.register(101, 'RANK');
    }
    // Adds an overload for operation code `index`. Tensor inputs come first, then the
    // attribute-style inputs given as [name, data type] pairs.
    register(index, name, category, inputs, attributes, outputs) {
        inputs = inputs || [];
        outputs = outputs || [];
        attributes = attributes || [];
        const type = {};
        type.name = name;
        type.inputs = inputs.map((name) => ({ name, type: 'Tensor' }));
        type.inputs = type.inputs.concat(attributes.map(([name, type]) => ({ name, type })));
        type.outputs = outputs.map((name) => ({ name, type: 'Tensor' }));
        if (category) {
            type.category = category;
        }
        if (!this._types.has(index)) {
            this._types.set(index, []);
        }
        this._types.get(index).push(type);
    }
    // Returns the overload of code `index` whose declared inputs agree position-wise with
    // the operand data types in `signature` ('Tensor' inputs accept any type). When several
    // overloads agree, the one declaring the fewest inputs wins, because trailing inputs
    // (e.g. nchw / dilation) may be omitted. Without any match the first overload is used.
    // Unknown codes get a placeholder schema named after the code.
    type(index, signature) {
        if (!this._types.has(index)) {
            // Stored as a one-element list, like register() does, so the loop below can iterate it.
            this._types.set(index, [{ name: index.toString(), inputs: [], outputs: [], attributes: [] }]);
        }
        const types = this._types.get(index);
        let match = null;
        for (const type of types) {
            const inputs = type.inputs;
            const compatible = signature.length <= inputs.length && signature.every((dataType, i) => {
                const expected = inputs[i].type;
                return expected === undefined || expected === 'Tensor' || expected === dataType;
            });
            if (compatible && (match === null || inputs.length < match.inputs.length)) {
                match = type;
            }
        }
        return match || types[0];
    }
};
// Operator schema registry backed by 'pytorch-metadata.json'. Entries are keyed by the
// schema text before the first '(' (the operator name without its argument list).
pytorch.Metadata = class {
    // Returns the process-wide instance, loading the JSON on first use. Metadata is optional:
    // a failed request yields an empty registry.
    static async open(context) {
        if (pytorch.Metadata._metadata) {
            return pytorch.Metadata._metadata;
        }
        let data = null;
        try {
            data = await context.request('pytorch-metadata.json');
        } catch {
            // continue regardless of error
        }
        pytorch.Metadata._metadata = new pytorch.Metadata(data);
        return pytorch.Metadata._metadata;
    }
    constructor(data) {
        this._types = new Map();
        this._attributes = new Map();
        this._index = new Map();
        if (!data) {
            return;
        }
        for (const item of JSON.parse(data)) {
            const paren = item.name.indexOf('(');
            const key = paren === -1 ? item.name : item.name.substring(0, paren);
            this._types.set(key, item);
        }
    }
    add(name, value) {
        this._types.set(name, value);
    }
    type(name) {
        return this._types.get(name);
    }
    // Looks up input/attribute `name` of operator `type`. The first query for a type indexes
    // all of its inputs and attributes at once. Unknown names resolve to null.
    attribute(type, name) {
        const key = `${type}:${name}`;
        if (this._attributes.has(key)) {
            return this._attributes.get(key);
        }
        this._attributes.set(key, null);
        const schema = this.type(type);
        if (schema) {
            const members = [...(schema.inputs || []), ...(schema.attributes || [])];
            for (const member of members) {
                this._attributes.set(`${type}:${member.name}`, member);
            }
        }
        return this._attributes.get(key);
    }
    // Registers every namespaced ('ns::op') schema as an operator in the execution's torch
    // registry, then exposes one `torch.ops.<ns>` namespace per distinct prefix.
    register(execution) {
        const torch = execution.register('torch');
        const registry = torch._C.getRegistry();
        const namespaces = new Set();
        for (const [key, entry] of this._types) {
            if (key.indexOf('::') === -1) {
                continue;
            }
            const schema = torch.FunctionSchema.parse(entry.name);
            if (entry.category) {
                schema.category = entry.category;
            }
            schema.setAliasAnalysis('FROM_SCHEMA');
            registry.registerOperator(new torch._C.Operator(schema));
            namespaces.add(entry.name.split('::')[0]);
        }
        for (const ns of namespaces) {
            const opNamespace = new torch._ops._OpNamespace(ns);
            execution.register(`torch.ops.${ns}`, opNamespace);
        }
    }
};
// Snapshot of a numpy array for display: type, element strides, byte order and the payload.
// String/object/void arrays keep their elements as a flat list; others keep the raw bytes.
numpy.Tensor = class {
    constructor(array) {
        const shape = new numpy.TensorShape(array.shape);
        this.type = new numpy.TensorType(array.dtype.__name__, shape);
        // numpy strides are in bytes; store them in elements.
        this.stride = array.strides.map((bytes) => bytes / array.itemsize);
        const { dataType } = this.type;
        const isText = dataType === 'string' || dataType === 'object';
        this.encoding = isText ? '|' : array.dtype.byteorder;
        const keepElements = isText || dataType === 'void';
        this.values = keepElements ? array.flatten().tolist() : array.tobytes();
    }
};
// Element type plus shape of a numpy array; a missing data type shows as '?'.
numpy.TensorType = class {
    constructor(dataType, shape) {
        this.dataType = dataType || '?';
        this.shape = shape;
    }
    toString() {
        const suffix = this.shape.toString();
        return `${this.dataType}${suffix}`;
    }
};
// Shape of a numpy array, rendered as '[d0,d1,...]' or '' when there are no dimensions.
numpy.TensorShape = class {
    constructor(dimensions) {
        this.dimensions = dimensions;
    }
    toString() {
        const dims = this.dimensions;
        if (dims && dims.length > 0) {
            return `[${dims.join(',')}]`;
        }
        return '';
    }
};
// Error raised while reading a PyTorch model. The optional `options` argument is forwarded to
// Error (e.g. `{ cause }`) so an underlying failure can be chained; omitting it behaves as before.
pytorch.Error = class extends Error {
    constructor(message, options) {
        super(message, options);
        this.name = 'Error loading PyTorch model.';
    }
};
// Public entry points of this module.
export const { Metadata, ModelFactory } = pytorch;