| using System; |
| using System.Collections.Generic; |
| using System.Linq; |
| using Unity.Sentis; |
| using UnityEngine; |
|
|
/// <summary>
/// Zero-shot text classification with a DeBERTa-v3 NLI model (Unity Sentis).
/// For each candidate class, the hypothesis template is instantiated
/// ("This example is about {class}"), the premise/hypothesis pair is tokenized
/// into one batch row, and the model's entailment logit is turned into a score.
/// Scores are logged once from <see cref="Start"/>.
/// </summary>
public sealed class DebertaV3 : MonoBehaviour
{
    public ModelAsset model;                 // DeBERTa-v3 ONNX model asset
    public TextAsset vocabulary;             // SentencePiece vocabulary, one token per line
    public bool multipleTrueClasses;         // true: score each class independently
    public string text = "Angela Merkel is a politician in Germany and leader of the CDU";
    public string hypothesisTemplate = "This example is about {}";
    public string[] classes = { "politics", "economy", "entertainment", "environment" };

    IWorker engine;
    string[] vocabularyTokens;
    // Token string -> vocabulary row index. Built once in Start so Tokenize does
    // O(1) lookups instead of an O(V) Array.IndexOf per candidate sub-word.
    // First occurrence wins, matching Array.IndexOf semantics.
    Dictionary<string, int> vocabularyIndex;

    const int padToken = 0;
    const int startToken = 1;
    const int separatorToken = 2;
    // Model token ids are offset from raw vocabulary row indices by this amount.
    const int vocabToTokenOffset = 260;

    void Start()
    {
        if (classes.Length == 0)
        {
            Debug.LogError("There need to be more than 0 classes");
            return;
        }

        vocabularyTokens = vocabulary.text.Replace("\r", "").Split("\n");

        vocabularyIndex = new Dictionary<string, int>(vocabularyTokens.Length);
        for (int i = 0; i < vocabularyTokens.Length; i++)
        {
            vocabularyIndex.TryAdd(vocabularyTokens[i], i);
        }
        // A trailing newline in the vocab file yields an empty last line. Matching ""
        // makes no tokenization progress, so it must never resolve to a token.
        vocabularyIndex.Remove("");

        Model baseModel = ModelLoader.Load(model);
        Model modelWithScoring = Functional.Compile(
            input =>
            {
                // Raw NLI logits for every (premise, hypothesis) row in the batch.
                FunctionalTensor logits = baseModel.Forward(input)[0];

                if (multipleTrueClasses || classes.Length == 1)
                {
                    // Classes are independent: softmax over each row's own logits.
                    logits = Functional.Softmax(logits);
                }
                else
                {
                    // Exactly one class is true: softmax across the batch so the
                    // entailment scores compete with each other.
                    logits = Functional.Softmax(logits, 0);
                }

                // Keep only the entailment column (index 0) per hypothesis.
                return new []{logits[.., 0]};
            },
            InputDef.FromModel(baseModel)
        );

        engine = WorkerFactory.CreateWorker(BackendType.GPUCompute, modelWithScoring);

        string[] hypotheses = classes.Select(x => hypothesisTemplate.Replace("{}", x)).ToArray();
        Batch batch = GetTokenizedBatch(text, hypotheses);
        float[] scores = GetBatchScores(batch);

        for (int i = 0; i < scores.Length; i++)
        {
            Debug.Log($"[{classes[i]}] Entailment Score: {scores[i]}");
        }
    }

    /// <summary>
    /// Runs the compiled model on a tokenized batch and downloads the per-row
    /// entailment scores. Blocks until the GPU work completes.
    /// </summary>
    float[] GetBatchScores(Batch batch)
    {
        using var inputIds = new TensorInt(new TensorShape(batch.BatchCount, batch.BatchLength), batch.BatchedTokens);
        using var attentionMask = new TensorInt(new TensorShape(batch.BatchCount, batch.BatchLength), batch.BatchedMasks);

        // Input names must match the ONNX model's declared inputs.
        Dictionary<string, Tensor> inputs = new()
        {
            {"input_0", inputIds},
            {"input_1", attentionMask}
        };

        engine.Execute(inputs);
        TensorFloat scores = (TensorFloat)engine.PeekOutput("output_0");
        scores.CompleteOperationsAndDownload();

        return scores.ToReadOnlyArray();
    }

    /// <summary>
    /// Builds one batch row per hypothesis:
    /// [start] prompt [sep] hypothesis [sep] [pad...]
    /// All rows are padded to the longest hypothesis; the attention mask is 1 for
    /// real tokens and 0 for padding.
    /// </summary>
    Batch GetTokenizedBatch(string prompt, string[] hypotheses)
    {
        Batch batch = new Batch();

        List<int> promptTokens = Tokenize(prompt);
        promptTokens.Insert(0, startToken);

        List<int>[] tokenizedHypotheses = hypotheses.Select(Tokenize).ToArray();
        int maxTokenLength = tokenizedHypotheses.Max(x => x.Count);

        // The prompt prefix is identical for every row, so rows only differ in
        // hypothesis length; padding to maxTokenLength equalizes row lengths.
        int[] batchedTokens = tokenizedHypotheses.SelectMany(hypothesis => promptTokens
                .Append(separatorToken)
                .Concat(hypothesis)
                .Append(separatorToken)
                .Concat(Enumerable.Repeat(padToken, maxTokenLength - hypothesis.Count)))
            .ToArray();

        // Mask mirrors the token layout: 1 for prompt+sep and hypothesis+sep, 0 for padding.
        int[] batchedMasks = tokenizedHypotheses.SelectMany(hypothesis => Enumerable.Repeat(1, promptTokens.Count + 1)
                .Concat(Enumerable.Repeat(1, hypothesis.Count + 1))
                .Concat(Enumerable.Repeat(0, maxTokenLength - hypothesis.Count)))
            .ToArray();

        batch.BatchCount = hypotheses.Length;
        batch.BatchLength = batchedTokens.Length / hypotheses.Length;
        batch.BatchedTokens = batchedTokens;
        batch.BatchedMasks = batchedMasks;

        return batch;
    }

    /// <summary>
    /// Greedy longest-prefix-first SentencePiece-style tokenization. The first
    /// piece of each word carries the "▁" word-boundary marker. Sub-words not in
    /// the vocabulary cause the remainder of that word to be dropped (the
    /// original sample's behavior), but the loop is now guaranteed to terminate:
    /// the previous implementation could throw on a negative-length Substring or
    /// spin forever when a zero-progress match (e.g. a blank vocab line) occurred.
    /// </summary>
    List<int> Tokenize(string input)
    {
        List<int> ids = new();

        // Split(null) splits on any whitespace; consecutive separators yield
        // empty entries, which we skip.
        foreach (string word in input.Split(null))
        {
            if (word.Length == 0) continue;

            int start = 0;
            int end = word.Length;
            while (start < word.Length)
            {
                // Candidate piece is always non-empty here (end > start invariant).
                string subWord = start == 0 ? "▁" + word[..end] : word[start..end];
                if (vocabularyIndex.TryGetValue(subWord, out int index))
                {
                    ids.Add(index + vocabToTokenOffset);
                    start = end;           // consume the matched piece
                    end = word.Length;     // retry longest-first on the remainder
                }
                else
                {
                    end--;
                    if (end <= start)
                    {
                        // Not even a single character matches: out-of-vocabulary.
                        // Drop the rest of the word rather than loop or throw.
                        break;
                    }
                }
            }
        }

        return ids;
    }

    void OnDestroy() => engine?.Dispose();

    // Flattened, padded token/mask arrays for a whole batch of hypotheses.
    struct Batch
    {
        public int BatchCount;      // number of rows (hypotheses)
        public int BatchLength;     // tokens per row after padding
        public int[] BatchedTokens; // BatchCount * BatchLength token ids
        public int[] BatchedMasks;  // 1 = real token, 0 = padding
    }
}