Spaces:
Runtime error
Runtime error
| import * as _ from 'lodash' | |
| import * as x_ from '../etc/_Tools' | |
| import * as tp from '../etc/types' | |
| import * as tf from '@tensorflow/tfjs' | |
| /** | |
| * Notes: | |
| * | |
| * - Also encapsulate the CLS/SEP info vs. no CLS/SEP info | |
| * - When layer format changes from list, drop the index into conf.layer | |
| */ | |
// Token strings that are looked up so their positions can be zeroed out of the
// attention (the "[CLS]" / "[SEP]" style markers and their "<s>"-style counterparts).
const bpeTokens = ["[CLS]", "[SEP]", "<s>", "</s>", "<|endoftext|>"]

/**
 * Locate the entries of `x` whose `text` is one of `bpeTokens`.
 *
 * The search itself is delegated to `findAllIndexes` from `../etc/_Tools`
 * (presumably it returns every index whose element satisfies the predicate --
 * confirm against that module).
 *
 * @param x Token infos to scan; only the `text` field is read.
 * @returns Whatever `findAllIndexes` returns for the matching texts.
 */
function findBadIndexes(x: tp.FullSingleTokenInfo[]) {
    const tokenTexts = x.map(tok => tok.text)
    return x_.findAllIndexes(tokenTexts, (text) => _.includes(bpeTokens, text))
}
/**
 * Build an `AttentionWrapper` from a backend attention response.
 *
 * Reads the `'aa'` entry of the response, looks up the special-token positions
 * of its `left` and `right` token lists with the shared `findBadIndexes`
 * helper (the same helper `AttentionWrapper.updateFromNormal` uses), and hands
 * them to the wrapper as the rows / columns to zero.
 *
 * @param r Attention response from the backend.
 * @param isZeroed Whether the new wrapper should expose the zeroed attention.
 *                 Defaults to `true`, which is what the wrapper's constructor
 *                 already used when this argument was `undefined`.
 * @returns A freshly constructed `AttentionWrapper`.
 */
export function makeFromMetaResponse(r: tp.AttentionResponse, isZeroed: boolean = true) {
    const key = 'aa' // Change this if backend response changes to be simpler
    const currPair = r[key]
    const left = currPair.left as tp.FullSingleTokenInfo[]
    const right = currPair.right as tp.FullSingleTokenInfo[]

    return new AttentionWrapper(
        currPair.att,
        [findBadIndexes(left), findBadIndexes(right)],
        isZeroed
    )
}
/**
 * Holds one attention array as two tfjs tensors -- the raw values and a copy
 * with selected rows / columns zeroed -- and exposes whichever one the
 * `isZeroed` flag selects.
 *
 * Axis 0 of the tensor is indexed by head (see `_byHead` / `_byHeads`);
 * `badToks[0]` lists the axis-1 indexes and `badToks[1]` the axis-2 indexes
 * that are zeroed in the zeroed variant.
 *
 * Memory: the wrapper owns its tensors. Re-running `init` (directly or through
 * `updateFromNormal`) disposes the tensors of the previous state, so do not
 * keep a tensor obtained from `attTensor` across an update.
 */
export class AttentionWrapper {
    // Nested-array attention exactly as it was passed to `init`.
    protected _att: number[][][]
    // Tensor built from `_att` with nothing zeroed.
    protected _attTensor: tf.Tensor3D
    // Tensor built from `_att` after `zeroRowCol` zeroed the `badToks` rows / columns.
    protected _zeroedAttTensor: tf.Tensor3D

    badToks: [number[], number[]] // Indexes for the CLS and SEP tokens
    isZeroed: boolean
    // Hard-coded sizes; nothing inside this class reads them.
    nLayers = 12;
    nHeads = 12;

    /**
     * @param att Attention values, indexed [head][row][col].
     * @param badToks Row indexes and column indexes to zero in the zeroed variant.
     * @param isZeroed Whether `attTensor` / `att` expose the zeroed variant.
     */
    constructor(att: number[][][], badToks: [number[], number[]] = [[], []], isZeroed = true) {
        this.init(att, badToks, isZeroed)
    }

    /**
     * (Re)initialise the wrapper from a nested attention array.
     *
     * New tensors are built first and the tensors of any previous state are
     * disposed afterwards, so repeated updates do not accumulate tensor memory.
     */
    init(att: number[][][], badToks: [number[], number[]] = [[], []], isZeroed: boolean) {
        const staleTensors = [this._attTensor, this._zeroedAttTensor]

        // `zeroRowCol` may write through to the storage of the tensor it is handed,
        // so the zeroed variant is derived from its own scratch tensor and the raw
        // variant is built separately from `att`.
        const scratch = tf.tensor3d(att)
        const zeroedTensor = zeroRowCol(scratch, badToks[0], badToks[1])
        // Safe even if the result shares storage with `scratch`: tfjs frees the
        // underlying data only once every tensor referencing it is disposed.
        if (zeroedTensor !== scratch) scratch.dispose()
        const rawTensor = tf.tensor3d(att)

        this.isZeroed = isZeroed
        this._att = att
        this._zeroedAttTensor = zeroedTensor
        this._attTensor = rawTensor
        this.badToks = badToks

        // Undefined on the first call from the constructor.
        staleTensors.forEach(stale => {
            if (stale) stale.dispose()
        })
    }

    /**
     * Re-initialise from a backend response: reads its `'aa'` entry and zeroes
     * the positions `findBadIndexes` reports for the left / right tokens.
     */
    updateFromNormal(r: tp.AttentionResponse, isZeroed: boolean) {
        const key = 'aa' // Change this if backend response changes to be simpler
        const currPair = r[key]
        const left = currPair.left as tp.FullSingleTokenInfo[]
        const right = currPair.right as tp.FullSingleTokenInfo[]

        this.init(currPair.att, [findBadIndexes(left), findBadIndexes(right)], isZeroed)
    }

    /** The tensor selected by `isZeroed`. Owned by the wrapper -- do not dispose it. */
    get attTensor() {
        return this.isZeroed ? this._zeroedAttTensor : this._attTensor
    }

    /** The currently selected attention as nested arrays. */
    get att() {
        return this.attTensor.arraySync()
    }

    /** Getter/setter: with no argument returns `isZeroed`, otherwise sets it and returns `this`. */
    zeroed(): boolean
    zeroed(val: boolean): this
    zeroed(val?: boolean): boolean | this {
        if (val == null) return this.isZeroed
        this.isZeroed = val
        return this
    }

    /** Flip the `isZeroed` flag. */
    toggleZeroing() {
        this.zeroed(!this.zeroed())
    }

    /**
     * Sum of the selected heads (axis 0) of the current tensor; an all-zero
     * matrix shaped like a single head when `heads` is empty.
     * The caller owns the returned tensor; intermediates are released by `tf.tidy`.
     */
    protected _byHeads(heads: number[]): tf.Tensor2D {
        if (heads.length === 0) {
            return tf.tidy(() => tf.zerosLike(this._byHead(0)))
        }
        return tf.tidy(() => this.attTensor.gather(heads, 0).sum(0) as tf.Tensor2D)
    }

    /**
     * The slice of the current tensor for one head (axis 0).
     * The caller owns the returned tensor; intermediates are released by `tf.tidy`.
     */
    protected _byHead(head: number): tf.Tensor2D {
        return tf.tidy(() => this.attTensor.gather([head], 0).squeeze([0]) as tf.Tensor2D)
    }

    /** `_byHeads` as nested arrays; the temporary tensor is disposed after reading. */
    byHeads(heads: number[]): number[][] {
        const summed = this._byHeads(heads)
        try {
            return summed.arraySync()
        } finally {
            summed.dispose()
        }
    }

    /** `_byHead` as nested arrays; the temporary tensor is disposed after reading. */
    byHead(head: number): number[][] {
        const single = this._byHead(head)
        try {
            return single.arraySync()
        } finally {
            single.dispose()
        }
    }
}
/**
 * Return a new tensor equal to `tens` except that, for every index along
 * axis 0, the listed rows (axis 1) and the listed columns (axis 2) are set to 0.
 *
 * The input is left untouched: the values are read once, copied, edited in
 * the copy, and a new tensor is created from that copy. This avoids relying on
 * writes to a `bufferSync()` buffer showing up in a `clone()` of the tensor,
 * which only holds when the backend hands out its live storage.
 *
 * @param tens Source tensor; not modified.
 * @param rows Axis-1 indexes to zero. Indexes outside the tensor are ignored.
 * @param cols Axis-2 indexes to zero. Indexes outside the tensor are ignored.
 * @returns A new tensor with the same shape and dtype; the caller owns it.
 */
function zeroRowCol(tens: tf.Tensor3D, rows: number[], cols: number[]): tf.Tensor3D {
    const [nOuter, nRows, nCols] = tens.shape

    // De-duplicate and drop indexes that cannot address an element.
    const usable = (limit: number) => (idx: number) =>
        Number.isInteger(idx) && idx >= 0 && idx < limit
    const rowIdx = Array.from(new Set(rows)).filter(usable(nRows))
    const colIdx = Array.from(new Set(cols)).filter(usable(nCols))

    // `dataSync()` yields the values flattened in row-major order; `slice()`
    // gives a private copy so the source tensor's storage is never written to.
    const values = tens.dataSync().slice()

    for (let outer = 0; outer < nOuter; outer++) {
        const outerOffset = outer * nRows * nCols

        // A row is one contiguous run of `nCols` values.
        for (const i of rowIdx) {
            const rowStart = outerOffset + i * nCols
            values.fill(0, rowStart, rowStart + nCols)
        }

        // A column is one value in every row.
        if (colIdx.length > 0) {
            for (let i = 0; i < nRows; i++) {
                const rowStart = outerOffset + i * nCols
                for (const j of colIdx) {
                    values[rowStart + j] = 0
                }
            }
        }
    }

    return tf.tensor3d(values, tens.shape, tens.dtype)
}