idontwannna committed on
Commit
e30ea5a
·
verified ·
1 Parent(s): ca685d7

Update DebertaV3.cs to support Unity Sentis 2.1.1

Browse files

Had Claude AI make the changes to upgrade DebertaV3.cs to support Unity Sentis 2.0 (2.1.1).
It appears to be working...

Files changed (1) hide show
  1. DebertaV3.cs +147 -153
DebertaV3.cs CHANGED
@@ -6,157 +6,151 @@ using UnityEngine;
6
 
7
  public sealed class DebertaV3 : MonoBehaviour
8
  {
9
- public ModelAsset model;
10
- public TextAsset vocabulary;
11
- public bool multipleTrueClasses;
12
- public string text = "Angela Merkel is a politician in Germany and leader of the CDU";
13
- public string hypothesisTemplate = "This example is about {}";
14
- public string[] classes = { "politics", "economy", "entertainment", "environment" };
15
-
16
- IWorker engine;
17
- string[] vocabularyTokens;
18
-
19
- const int padToken = 0;
20
- const int startToken = 1;
21
- const int separatorToken = 2;
22
- const int vocabToTokenOffset = 260;
23
-
24
- void Start()
25
- {
26
- if (classes.Length == 0)
27
- {
28
- Debug.LogError("There need to be more than 0 classes");
29
- return;
30
- }
31
-
32
- vocabularyTokens = vocabulary.text.Replace("\r", "").Split("\n");
33
-
34
- Model baseModel = ModelLoader.Load(model);
35
- Model modelWithScoring = Functional.Compile(
36
- input =>
37
- {
38
- // The logits represent the model's predictions for entailment and non-entailment for each example in the batch.
39
- // They are of shape [batch size, 2] i.e. with two values per example.
40
- // To obtain a single score per example, a softmax function is applied
41
- FunctionalTensor logits = baseModel.Forward(input)[0];
42
-
43
- if (multipleTrueClasses || classes.Length == 1)
44
- {
45
- // Softmax over the entailment vs. contradiction dimension for each label independently
46
- logits = Functional.Softmax(logits);
47
- }
48
- else
49
- {
50
- // Softmax over all candidate labels
51
- logits = Functional.Softmax(logits, 0);
52
- }
53
-
54
- // The scores are stored along the first column
55
- return new []{logits[.., 0]};
56
- },
57
- InputDef.FromModel(baseModel)
58
- );
59
-
60
- engine = WorkerFactory.CreateWorker(BackendType.GPUCompute, modelWithScoring);
61
-
62
- string[] hypotheses = classes.Select(x => hypothesisTemplate.Replace("{}", x)).ToArray();
63
- Batch batch = GetTokenizedBatch(text, hypotheses);
64
- float[] scores = GetBatchScores(batch);
65
-
66
- for (int i = 0; i < scores.Length; i++)
67
- {
68
- Debug.Log($"[{classes[i]}] Entailment Score: {scores[i]}");
69
- }
70
- }
71
-
72
- float[] GetBatchScores(Batch batch)
73
- {
74
- using var inputIds = new TensorInt(new TensorShape(batch.BatchCount, batch.BatchLength), batch.BatchedTokens);
75
- using var attentionMask = new TensorInt(new TensorShape(batch.BatchCount, batch.BatchLength), batch.BatchedMasks);
76
-
77
- Dictionary<string, Tensor> inputs = new()
78
- {
79
- {"input_0", inputIds},
80
- {"input_1", attentionMask}
81
- };
82
-
83
- engine.Execute(inputs);
84
- TensorFloat scores = (TensorFloat)engine.PeekOutput("output_0");
85
- scores.CompleteOperationsAndDownload();
86
-
87
- return scores.ToReadOnlyArray();
88
- }
89
-
90
- Batch GetTokenizedBatch(string prompt, string[] hypotheses)
91
- {
92
- Batch batch = new Batch();
93
-
94
- List<int> promptTokens = Tokenize(prompt);
95
- promptTokens.Insert(0, startToken);
96
-
97
- List<int>[] tokenizedHypotheses = hypotheses.Select(Tokenize).ToArray();
98
- int maxTokenLength = tokenizedHypotheses.Max(x => x.Count);
99
-
100
- // Each example in the batch follows this format:
101
- // Start Prompt Separator Hypothesis Separator Padding
102
-
103
- int[] batchedTokens = tokenizedHypotheses.SelectMany(hypothesis => promptTokens
104
- .Append(separatorToken)
105
- .Concat(hypothesis)
106
- .Append(separatorToken)
107
- .Concat(Enumerable.Repeat(padToken, maxTokenLength - hypothesis.Count)))
108
- .ToArray();
109
-
110
-
111
- // The attention masks have the same length as the tokens.
112
- // Each attention mask contains repeating 1s for each token, except for padding tokens.
113
-
114
- int[] batchedMasks = tokenizedHypotheses.SelectMany(hypothesis => Enumerable.Repeat(1, promptTokens.Count + 1)
115
- .Concat(Enumerable.Repeat(1, hypothesis.Count + 1))
116
- .Concat(Enumerable.Repeat(0, maxTokenLength - hypothesis.Count)))
117
- .ToArray();
118
-
119
- batch.BatchCount = hypotheses.Length;
120
- batch.BatchLength = batchedTokens.Length / hypotheses.Length;
121
- batch.BatchedTokens = batchedTokens;
122
- batch.BatchedMasks = batchedMasks;
123
-
124
- return batch;
125
- }
126
-
127
- List<int> Tokenize(string input)
128
- {
129
- string[] words = input.Split(null);
130
-
131
- List<int> ids = new();
132
-
133
- foreach (string word in words)
134
- {
135
- int start = 0;
136
- for(int i = word.Length; i >= 0;i--)
137
- {
138
- string subWord = start == 0 ? "▁" + word.Substring(start, i) : word.Substring(start, i-start);
139
- int index = Array.IndexOf(vocabularyTokens, subWord);
140
- if (index >= 0)
141
- {
142
- ids.Add(index + vocabToTokenOffset);
143
- if (i == word.Length) break;
144
- start = i;
145
- i = word.Length + 1;
146
- }
147
- }
148
- }
149
-
150
- return ids;
151
- }
152
-
153
- void OnDestroy() => engine?.Dispose();
154
-
155
- struct Batch
156
- {
157
- public int BatchCount;
158
- public int BatchLength;
159
- public int[] BatchedTokens;
160
- public int[] BatchedMasks;
161
- }
162
  }
 
6
 
7
  public sealed class DebertaV3 : MonoBehaviour
8
  {
9
+ public ModelAsset model;
10
+ public TextAsset vocabulary;
11
+ public bool multipleTrueClasses;
12
+ public string text = "Angela Merkel is a politician in Germany and leader of the CDU";
13
+ public string hypothesisTemplate = "This example is about {}";
14
+ public string[] classes = { "politics", "economy", "entertainment", "environment" };
15
+
16
+ Worker engine;
17
+ string[] vocabularyTokens;
18
+
19
+ const int padToken = 0;
20
+ const int startToken = 1;
21
+ const int separatorToken = 2;
22
+ const int vocabToTokenOffset = 260;
23
+
24
+ void Start()
25
+ {
26
+ if (classes.Length == 0)
27
+ {
28
+ Debug.LogError("There need to be more than 0 classes");
29
+ return;
30
+ }
31
+
32
+ vocabularyTokens = vocabulary.text.Replace("\r", "").Split("\n");
33
+
34
+ Model baseModel = ModelLoader.Load(model);
35
+
36
+ // Create the engine with the base model using the updated constructor
37
+ engine = new Worker(baseModel, BackendType.GPUCompute);
38
+
39
+ string[] hypotheses = classes.Select(x => hypothesisTemplate.Replace("{}", x)).ToArray();
40
+ Batch batch = GetTokenizedBatch(text, hypotheses);
41
+ float[] scores = GetBatchScores(batch);
42
+
43
+ for (int i = 0; i < scores.Length; i++)
44
+ {
45
+ Debug.Log($"[{classes[i]}] Entailment Score: {scores[i]}");
46
+ }
47
+ }
48
+
49
+ float[] GetBatchScores(Batch batch)
50
+ {
51
+ using var inputIds = new Tensor<int>(new TensorShape(batch.BatchCount, batch.BatchLength), batch.BatchedTokens);
52
+ using var attentionMask = new Tensor<int>(new TensorShape(batch.BatchCount, batch.BatchLength), batch.BatchedMasks);
53
+
54
+ // Schedule the execution with the inputs as array
55
+ engine.Schedule(new Tensor[] { inputIds, attentionMask });
56
+
57
+ // Get the output tensor
58
+ var output = engine.PeekOutput(0);
59
+ var scores = new float[batch.BatchCount];
60
+
61
+ // Get the raw data from tensor using the new method
62
+ if (output is Tensor<float> floatOutput)
63
+ {
64
+ var shape = floatOutput.shape;
65
+ var data = floatOutput.DownloadToArray();
66
+
67
+ // Apply softmax manually
68
+ for (int i = 0; i < batch.BatchCount; i++)
69
+ {
70
+ float val1 = data[i * 2];
71
+ float val2 = data[i * 2 + 1];
72
+ float maxVal = Math.Max(val1, val2);
73
+
74
+ float exp1 = (float)Math.Exp(val1 - maxVal);
75
+ float exp2 = (float)Math.Exp(val2 - maxVal);
76
+ float sum = exp1 + exp2;
77
+
78
+ scores[i] = exp1 / sum; // Normalized probability for the first class
79
+ }
80
+ }
81
+
82
+ return scores;
83
+ }
84
+
85
+ Batch GetTokenizedBatch(string prompt, string[] hypotheses)
86
+ {
87
+ Batch batch = new Batch();
88
+
89
+ List<int> promptTokens = Tokenize(prompt);
90
+ promptTokens.Insert(0, startToken);
91
+
92
+ List<int>[] tokenizedHypotheses = hypotheses.Select(Tokenize).ToArray();
93
+ int maxTokenLength = tokenizedHypotheses.Max(x => x.Count);
94
+
95
+ // Each example in the batch follows this format:
96
+ // Start Prompt Separator Hypothesis Separator Padding
97
+
98
+ int[] batchedTokens = tokenizedHypotheses.SelectMany(hypothesis => promptTokens
99
+ .Append(separatorToken)
100
+ .Concat(hypothesis)
101
+ .Append(separatorToken)
102
+ .Concat(Enumerable.Repeat(padToken, maxTokenLength - hypothesis.Count)))
103
+ .ToArray();
104
+
105
+ // The attention masks have the same length as the tokens.
106
+ // Each attention mask contains repeating 1s for each token, except for padding tokens.
107
+
108
+ int[] batchedMasks = tokenizedHypotheses.SelectMany(hypothesis => Enumerable.Repeat(1, promptTokens.Count + 1)
109
+ .Concat(Enumerable.Repeat(1, hypothesis.Count + 1))
110
+ .Concat(Enumerable.Repeat(0, maxTokenLength - hypothesis.Count)))
111
+ .ToArray();
112
+
113
+ batch.BatchCount = hypotheses.Length;
114
+ batch.BatchLength = batchedTokens.Length / hypotheses.Length;
115
+ batch.BatchedTokens = batchedTokens;
116
+ batch.BatchedMasks = batchedMasks;
117
+
118
+ return batch;
119
+ }
120
+
121
+ List<int> Tokenize(string input)
122
+ {
123
+ string[] words = input.Split(null);
124
+
125
+ List<int> ids = new();
126
+
127
+ foreach (string word in words)
128
+ {
129
+ int start = 0;
130
+ for(int i = word.Length; i >= 0; i--)
131
+ {
132
+ string subWord = start == 0 ? "▁" + word.Substring(start, i) : word.Substring(start, i-start);
133
+ int index = Array.IndexOf(vocabularyTokens, subWord);
134
+ if (index >= 0)
135
+ {
136
+ ids.Add(index + vocabToTokenOffset);
137
+ if (i == word.Length) break;
138
+ start = i;
139
+ i = word.Length + 1;
140
+ }
141
+ }
142
+ }
143
+
144
+ return ids;
145
+ }
146
+
147
+ void OnDestroy() => engine?.Dispose();
148
+
149
+ struct Batch
150
+ {
151
+ public int BatchCount;
152
+ public int BatchLength;
153
+ public int[] BatchedTokens;
154
+ public int[] BatchedMasks;
155
+ }
 
 
 
 
 
 
156
  }