nabeix committed on
Commit
db3e1cf
·
1 Parent(s): a2d01e0

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +18 -0
  2. Dockerfile +42 -0
  3. README.md +13 -0
  4. __init__.py +4 -0
  5. data/figures/framework.jpeg +3 -0
  6. data/figures/inputids.png +3 -0
  7. data/figures/samples.png +3 -0
  8. data/figures/title_new.png +3 -0
  9. data/figures/training.jpeg +3 -0
  10. data/omni2-demo.mp4 +3 -0
  11. data/samples/output1.wav +0 -0
  12. data/samples/output2.wav +3 -0
  13. data/samples/output3.wav +0 -0
  14. data/samples/output4.wav +0 -0
  15. data/samples/output5.wav +3 -0
  16. data/samples/vision_qa_audio.wav +3 -0
  17. hotkey.txt +1 -0
  18. inference.py +710 -0
  19. inference_vision.py +263 -0
  20. litgpt/__init__.py +19 -0
  21. litgpt/config.py +181 -0
  22. litgpt/generate/__init__.py +0 -0
  23. litgpt/generate/base.py +795 -0
  24. litgpt/model.py +655 -0
  25. litgpt/tokenizer.py +132 -0
  26. litgpt/utils.py +641 -0
  27. models/README.md +143 -0
  28. models/ViT-B-32.pt +3 -0
  29. models/data/figures/framework.jpeg +3 -0
  30. models/data/figures/inputids.png +3 -0
  31. models/data/figures/samples.png +3 -0
  32. models/data/figures/title.png +3 -0
  33. models/data/figures/training.jpeg +3 -0
  34. models/data/omni2-demo.mp4 +3 -0
  35. models/hub/.locks/models--hubertsiuzdak--snac_24khz/4b8164cc6606bfa627f1a784734c1e539891518f1191ed9194fe1e3b9b4bff40.lock +0 -0
  36. models/hub/.locks/models--hubertsiuzdak--snac_24khz/a9e7ef62bf7e1eb94d2713721029837aacab3b55.lock +0 -0
  37. models/hub/models--hubertsiuzdak--snac_24khz/blobs/4b8164cc6606bfa627f1a784734c1e539891518f1191ed9194fe1e3b9b4bff40 +3 -0
  38. models/hub/models--hubertsiuzdak--snac_24khz/blobs/a9e7ef62bf7e1eb94d2713721029837aacab3b55 +13 -0
  39. models/hub/models--hubertsiuzdak--snac_24khz/refs/main +1 -0
  40. models/hub/models--hubertsiuzdak--snac_24khz/snapshots/d73ad176a12188fcf4f360ba3bf2c2fbbe8f58ec/config.json +13 -0
  41. models/hub/models--hubertsiuzdak--snac_24khz/snapshots/d73ad176a12188fcf4f360ba3bf2c2fbbe8f58ec/pytorch_model.bin +3 -0
  42. models/hub/version.txt +1 -0
  43. models/lit_model.pth +3 -0
  44. models/model_config.yaml +43 -0
  45. models/small.pt +3 -0
  46. models/tokenizer.json +0 -0
  47. models/tokenizer_config.json +40 -0
  48. requirements.txt +18 -0
  49. server.py +164 -0
  50. utils/__init__.py +0 -0
.gitattributes CHANGED
@@ -33,3 +33,21 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ audio_qa_out_cache.wav filter=lfs diff=lfs merge=lfs -text
37
+ data/figures/framework.jpeg filter=lfs diff=lfs merge=lfs -text
38
+ data/figures/inputids.png filter=lfs diff=lfs merge=lfs -text
39
+ data/figures/samples.png filter=lfs diff=lfs merge=lfs -text
40
+ data/figures/title_new.png filter=lfs diff=lfs merge=lfs -text
41
+ data/figures/training.jpeg filter=lfs diff=lfs merge=lfs -text
42
+ data/omni2-demo.mp4 filter=lfs diff=lfs merge=lfs -text
43
+ data/samples/output2.wav filter=lfs diff=lfs merge=lfs -text
44
+ data/samples/output5.wav filter=lfs diff=lfs merge=lfs -text
45
+ data/samples/vision_qa_audio.wav filter=lfs diff=lfs merge=lfs -text
46
+ models/data/figures/framework.jpeg filter=lfs diff=lfs merge=lfs -text
47
+ models/data/figures/inputids.png filter=lfs diff=lfs merge=lfs -text
48
+ models/data/figures/samples.png filter=lfs diff=lfs merge=lfs -text
49
+ models/data/figures/title.png filter=lfs diff=lfs merge=lfs -text
50
+ models/data/figures/training.jpeg filter=lfs diff=lfs merge=lfs -text
51
+ models/data/omni2-demo.mp4 filter=lfs diff=lfs merge=lfs -text
52
+ models/hub/models--hubertsiuzdak--snac_24khz/blobs/4b8164cc6606bfa627f1a784734c1e539891518f1191ed9194fe1e3b9b4bff40 filter=lfs diff=lfs merge=lfs -text
53
+ vision_qa_out_cache.wav filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM nvidia/cuda:12.3.2-cudnn9-devel-ubuntu22.04
2
+
3
+
4
+ # Set environment variables
5
+ ENV PYTHONUNBUFFERED=1 \
6
+ DEBIAN_FRONTEND=noninteractive \
7
+ CUDA_HOME=/usr/local/cuda \
8
+ PATH=/usr/local/cuda/bin:$PATH \
9
+ LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH \
10
+ NVIDIA_VISIBLE_DEVICES=all \
11
+ NVIDIA_DRIVER_CAPABILITIES=compute,utility \
12
+ HF_HOME=/app/models
13
+
14
+ # Install system dependencies
15
+ RUN apt-get update && apt-get install -y --no-install-recommends \
16
+ python3 \
17
+ python3-pip \
18
+ python3-dev \
19
+ build-essential \
20
+ git \
21
+ ffmpeg \
22
+ libsndfile1 \
23
+ curl \
24
+ && rm -rf /var/lib/apt/lists/*
25
+
26
+
27
+ # Upgrade pip and install build tools
28
+ RUN python3 -m pip install --upgrade pip setuptools wheel uv
29
+
30
+ WORKDIR /app
31
+
32
+ COPY requirements.txt .
33
+
34
+
35
+ # Install other requirements
36
+ RUN python3 -m uv pip install --no-cache-dir -r requirements.txt --prerelease=allow
37
+
38
+ COPY . .
39
+
40
+ EXPOSE 8000
41
+
42
+ CMD ["python3", "server.py"]
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ tags:
4
+ - any-to-any
5
+ - omega
6
+ - omegalabs
7
+ - bittensor
8
+ - agi
9
+ ---
10
+
11
+ This is an Any-to-Any model checkpoint for the OMEGA Labs x Bittensor Any-to-Any subnet.
12
+
13
+ Check out the [git repo](https://github.com/omegalabsinc/omegalabs-anytoany-bittensor) and find OMEGA on X: [@omegalabsai](https://x.com/omegalabsai).
__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ import sys
2
+ import os
3
+
4
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
data/figures/framework.jpeg ADDED

Git LFS Details

  • SHA256: bc668450030500a62ddbb7cf6ea170f0b53da7e3e5506d01a0dc6f2ec690fd1a
  • Pointer size: 131 Bytes
  • Size of remote file: 406 kB
data/figures/inputids.png ADDED

Git LFS Details

  • SHA256: ad4cf663684c53f72952b13f52ea93fcbe19e287301b3decfcd917de9e23f312
  • Pointer size: 131 Bytes
  • Size of remote file: 335 kB
data/figures/samples.png ADDED

Git LFS Details

  • SHA256: e63a8cbc2859304cb9c50b831366ac8804ad0326b6ae4897d08f8ab0e1eb63c6
  • Pointer size: 132 Bytes
  • Size of remote file: 2.57 MB
data/figures/title_new.png ADDED

Git LFS Details

  • SHA256: fd327145b6368a08a713164af9de7b1f9fc15a9077090586a1e65d915a82b538
  • Pointer size: 131 Bytes
  • Size of remote file: 355 kB
data/figures/training.jpeg ADDED

Git LFS Details

  • SHA256: fd49f75dbe5838a3e28f02c8f853dec34d0aad8573911d52bd827ab6dae8f9a1
  • Pointer size: 131 Bytes
  • Size of remote file: 353 kB
data/omni2-demo.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0c2098124af391dca9c48854f5686143c137cc069f08b5e457675b9ba744bd2f
3
+ size 11784395
data/samples/output1.wav ADDED
Binary file (62.2 kB). View file
 
data/samples/output2.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b50c4df6f508a4367e5a49e90f974f8786c6d9ffb2599a8abcd25e693399735a
3
+ size 105176
data/samples/output3.wav ADDED
Binary file (70.4 kB). View file
 
data/samples/output4.wav ADDED
Binary file (67.6 kB). View file
 
data/samples/output5.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:33f58e7cc49a4e4fd4809d20cde2fb22855054cf61558be8ffef347fc35ce8f2
3
+ size 114732
data/samples/vision_qa_audio.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:18eba79742ad8074074a113e6df56410bdf66e34a645d619a4ad7b8171f6d7d7
3
+ size 150572
hotkey.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ 5HQz3QrsWkgAYJRXF3xUQVWrjmLWiqpWsj5uv4bQdXFyGhCy
inference.py ADDED
@@ -0,0 +1,710 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import lightning as L
3
+ import torch
4
+ import glob
5
+ import time
6
+ from snac import SNAC
7
+ from litgpt import Tokenizer
8
+ from litgpt.utils import (
9
+ num_parameters,
10
+ )
11
+
12
+ from litgpt.generate.base import (
13
+ generate_AA,
14
+ generate_ASR,
15
+ generate_TA,
16
+ generate_TT,
17
+ generate_AT,
18
+ generate_TA_BATCH,
19
+ next_token_image_batch
20
+ )
21
+
22
+ # print(hihi)
23
+ import soundfile as sf
24
+ from litgpt.model import GPT, Config
25
+ from lightning.fabric.utilities.load import _lazy_load as lazy_load
26
+ from utils.snac_utils import layershift, reconscruct_snac, reconstruct_tensors, get_time_str
27
+ from utils.snac_utils import get_snac, generate_audio_data
28
+ import whisper
29
+ from tqdm import tqdm
30
+
31
+ from huggingface_hub import snapshot_download
32
+
33
+
34
torch.set_printoptions(sci_mode=False)


# TODO
# Vocabulary layout: the combined token space is text tokens first, then the
# seven layer-shifted audio codebooks.  Special tokens are appended after the
# base vocabularies, so each id below is base-size + offset.
text_vocabsize = 151936
text_specialtokens = 64
audio_vocabsize = 4096
audio_specialtokens = 64

# Sizes including the reserved special-token ranges.
padded_text_vocabsize = text_vocabsize + text_specialtokens
padded_audio_vocabsize = audio_vocabsize + audio_specialtokens

# Special text-token ids (offsets into the text special range).
_eot = text_vocabsize          # end of text
_pad_t = text_vocabsize + 1    # text padding
_input_t = text_vocabsize + 2  # start of text input
_answer_t = text_vocabsize + 3 # start of text answer
_asr = text_vocabsize + 4      # ASR-task marker

# Special audio-token ids (offsets into the audio special range).
_eoa = audio_vocabsize         # end of audio
_pad_a = audio_vocabsize + 1   # audio padding
_input_a = audio_vocabsize + 2 # start of audio input
_answer_a = audio_vocabsize + 3  # start of audio answer
_split = audio_vocabsize + 4   # segment separator
_image = audio_vocabsize + 5   # start of image embedding span
_eoimage = audio_vocabsize + 6 # end of image embedding span
59
+
60
+
61
def get_input_ids_TA(text, text_tokenizer):
    """Build the 8-layer input ids for the text-in / audio-out (T1A2) task.

    Layers 0-6 are audio codebook layers holding layer-shifted audio pad
    tokens followed by the layer-shifted "answer audio" token; layer 7 is
    the text layer: <input_t> + text tokens + <eot> + <answer_t>.
    Returns a list of 8 tensors, each of shape (1, L).
    """
    tokens = text_tokenizer.encode(text)
    seq = [None] * 8
    for layer in range(7):
        audio_row = [layershift(_pad_a, layer)] * (len(tokens) + 2)
        audio_row.append(layershift(_answer_a, layer))
        seq[layer] = torch.tensor(audio_row).unsqueeze(0)
    text_row = [_input_t, *tokens.tolist(), _eot, _answer_t]
    seq[-1] = torch.tensor(text_row).unsqueeze(0)
    return seq
72
+
73
+
74
def get_input_ids_TT(text, text_tokenizer):
    """Build the 8-layer input ids for the text-in / text-out (T1T2) task.

    Audio layers 0-6 are pure layer-shifted padding (length = text len + 3);
    layer 7 carries <input_t> + text tokens + <eot> + <answer_t>.
    Returns a list of 8 tensors, each of shape (1, L).
    """
    tokens = text_tokenizer.encode(text).tolist()
    rows = []
    for layer in range(7):
        pad_row = [layershift(_pad_a, layer)] * (len(tokens) + 3)
        rows.append(torch.tensor(pad_row).unsqueeze(0))
    rows.append(torch.tensor([_input_t, *tokens, _eot, _answer_t]).unsqueeze(0))
    return rows
86
+
87
+
88
def get_input_ids_whisper(
    mel, leng, whispermodel, device,
    special_token_a=_answer_a, special_token_t=_answer_t,
):
    """Embed a mel spectrogram with whisper and build matching input ids.

    The trailing special tokens select the task: e.g. (_answer_a, _answer_t)
    for answering, (_pad_a, _asr) for transcription.
    Returns (audio_feature of shape (1, T, D), list of 8 (1, T+3) id tensors).
    """
    with torch.no_grad():
        batched_mel = mel.unsqueeze(0).to(device)
        # audio_feature = whisper.decode(whispermodel,mel, options).audio_features
        audio_feature = whispermodel.embed_audio(batched_mel)[0][:leng]

    T = audio_feature.size(0)
    rows = []
    for layer in range(7):
        row = [layershift(_input_a, layer)]
        row.extend([layershift(_pad_a, layer)] * T)
        row.extend([layershift(_eoa, layer), layershift(special_token_a, layer)])
        rows.append(torch.tensor(row).unsqueeze(0))
    text_row = torch.tensor([_input_t] + [_pad_t] * T + [_eot, special_token_t])
    rows.append(text_row.unsqueeze(0))
    return audio_feature.unsqueeze(0), rows
109
+
110
+
111
def get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device):
    """Build a two-row batch of inputs for audio-question inference.

    Row 0 (AA) prompts for an audio answer, row 1 (AT) prompts for a text
    answer; both rows share the same whisper audio embedding.
    Returns (audio_feature stacked twice along batch, list of 8 (2, T+3)
    stacked id tensors).
    """
    with torch.no_grad():
        mel = mel.unsqueeze(0).to(device)
        # audio_feature = whisper.decode(whispermodel,mel, options).audio_features
        audio_feature = whispermodel.embed_audio(mel)[0][:leng]
    T = audio_feature.size(0)
    # Branch 1 (AA): audio layers end with the "answer audio" special token.
    input_ids_AA = []
    for i in range(7):
        input_ids_item = []
        input_ids_item.append(layershift(_input_a, i))
        input_ids_item += [layershift(_pad_a, i)] * T
        input_ids_item += [(layershift(_eoa, i)), layershift(_answer_a, i)]
        input_ids_AA.append(torch.tensor(input_ids_item))
    input_id_T = torch.tensor([_input_t] + [_pad_t] * T + [_eot, _answer_t])
    input_ids_AA.append(input_id_T)

    # Branch 2 (AT): audio layers end with padding, so only text is generated.
    input_ids_AT = []
    for i in range(7):
        input_ids_item = []
        input_ids_item.append(layershift(_input_a, i))
        input_ids_item += [layershift(_pad_a, i)] * T
        input_ids_item += [(layershift(_eoa, i)), layershift(_pad_a, i)]
        input_ids_AT.append(torch.tensor(input_ids_item))
    input_id_T = torch.tensor([_input_t] + [_pad_t] * T + [_eot, _answer_t])
    input_ids_AT.append(input_id_T)

    # Stack the two branches along the batch dimension, layer by layer.
    input_ids = [input_ids_AA, input_ids_AT]
    stacked_inputids = [[] for _ in range(8)]
    for i in range(2):
        for j in range(8):
            stacked_inputids[j].append(input_ids[i][j])
    stacked_inputids = [torch.stack(tensors) for tensors in stacked_inputids]
    return torch.stack([audio_feature, audio_feature]), stacked_inputids
145
+
146
+
147
def load_audio(path):
    """Load a wav file and convert it to a whisper mel spectrogram.

    Returns (mel, frame_count) where frame_count is the number of 20 ms
    frames covering the original (pre-padding) duration, plus one.
    """
    raw = whisper.load_audio(path)
    duration_ms = len(raw) / 16000 * 1000  # whisper audio is 16 kHz mono
    mel = whisper.log_mel_spectrogram(whisper.pad_or_trim(raw))
    return mel, int(duration_ms / 20) + 1
153
+
154
+
155
def A1_A2_batch(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
                snacmodel, out_dir=None):
    """Audio in, batched audio+text out.

    Runs the two-row (A1A2 + A1T2) batch, decodes the audio codebooks with
    SNAC, writes ``{out_dir}/A1-A2-batch/{step:02d}.wav`` and returns the
    decoded answer text.
    """
    with fabric.init_tensor():
        model.set_kv_cache(batch_size=2)
    tokenlist = generate_TA_BATCH(
        model,
        audio_feature,
        input_ids,
        [leng, leng],
        ["A1A2", "A1T2"],
        max_returned_tokens=2048,
        temperature=0.9,
        top_k=1,
        eos_id_a=_eoa,
        eos_id_t=_eot,
        pad_id_t=_pad_t,
        shift=padded_text_vocabsize,
        include_prompt=True,
        generate_text=True,
    )
    text_tokenlist = tokenlist[-1]
    if text_vocabsize in text_tokenlist:
        # Truncate at the end-of-text token.
        text_tokenlist = text_tokenlist[: text_tokenlist.index(text_vocabsize)]
    text = text_tokenizer.decode(torch.tensor(text_tokenlist)).strip()

    audio_tokenlist = tokenlist[:-1]
    audiolist = reconscruct_snac(audio_tokenlist)
    audio = reconstruct_tensors(audiolist)
    if out_dir is None:
        out_dir = "./output/default/A1-A2-batch"
    else:
        out_dir = out_dir + "/A1-A2-batch"
    # exist_ok avoids the race between the old exists() check and makedirs().
    os.makedirs(out_dir, exist_ok=True)
    with torch.inference_mode():
        audio_hat = snacmodel.decode(audio)
    sf.write(
        f"{out_dir}/{step:02d}.wav",
        audio_hat.squeeze().cpu().numpy(),
        24000,
    )
    model.clear_kv_cache()
    return text
198
+
199
+
200
def A1_T2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
    """Audio in, text out: answer a spoken question with text only.

    Returns the decoded answer string.
    """
    with fabric.init_tensor():
        model.set_kv_cache(batch_size=1)
    tokenlist = generate_AT(
        model,
        audio_feature,
        input_ids,
        [leng],
        ["AT"],
        max_returned_tokens=2048,
        temperature=0.9,
        top_k=1,
        eos_id_a=_eoa,
        eos_id_t=_eot,
        pad_id_t=_pad_t,
        shift=padded_text_vocabsize,
        include_prompt=True,
        generate_text=True,
    )
    # Release the KV cache like every other task helper does; the original
    # omitted this, leaking the cache allocation between calls.
    model.clear_kv_cache()
    return text_tokenizer.decode(torch.tensor(tokenlist)).strip()
220
+
221
+
222
def A1_A2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
          snacmodel, out_dir=None):
    """Audio in, audio+text out: answer a spoken question with speech.

    Writes ``{out_dir}/A1-A2/{step:02d}.wav`` (SNAC-decoded at 24 kHz) and
    returns the decoded answer text.
    """
    with fabric.init_tensor():
        model.set_kv_cache(batch_size=1)
    tokenlist = generate_AA(
        model,
        audio_feature,
        input_ids,
        [leng],
        ["A1T2"],
        max_returned_tokens=2048,
        temperature=0.9,
        top_k=1,
        eos_id_a=_eoa,
        eos_id_t=_eot,
        pad_id_t=_pad_t,
        shift=padded_text_vocabsize,
        include_prompt=True,
        generate_text=True,
    )
    audiolist = reconscruct_snac(tokenlist)
    tokenlist = tokenlist[-1]
    if text_vocabsize in tokenlist:
        # Truncate the text stream at the end-of-text token.
        tokenlist = tokenlist[: tokenlist.index(text_vocabsize)]
    if out_dir is None:
        out_dir = "./output/default/A1-A2"
    else:
        out_dir = out_dir + "/A1-A2"
    # exist_ok avoids the race between the old exists() check and makedirs().
    os.makedirs(out_dir, exist_ok=True)

    audio = reconstruct_tensors(audiolist)
    with torch.inference_mode():
        audio_hat = snacmodel.decode(audio)
    sf.write(
        f"{out_dir}/{step:02d}.wav",
        audio_hat.squeeze().cpu().numpy(),
        24000,
    )
    model.clear_kv_cache()
    return text_tokenizer.decode(torch.tensor(tokenlist)).strip()
263
+
264
+
265
def A1_T1(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
    """Run the ASR task (audio in, transcript out) and return the text."""
    with fabric.init_tensor():
        model.set_kv_cache(batch_size=1)
    token_ids = generate_ASR(
        model,
        audio_feature,
        input_ids,
        [leng],
        ["A1T1"],
        max_returned_tokens=2048,
        temperature=0.9,
        top_k=1,
        eos_id_a=_eoa,
        eos_id_t=_eot,
        pad_id_t=_pad_t,
        shift=padded_text_vocabsize,
        include_prompt=True,
        generate_text=True,
    )
    model.clear_kv_cache()
    decoded = text_tokenizer.decode(torch.tensor(token_ids))
    return decoded.strip()
286
+
287
+
288
def T1_A2(fabric, input_ids, model, text_tokenizer, step,
          snacmodel, out_dir=None):
    """Text in, audio+text out: speak an answer to a text question.

    Writes ``{out_dir}/T1-A2/{step:02d}.wav`` (SNAC-decoded at 24 kHz) and
    returns the decoded answer text.
    """
    with fabric.init_tensor():
        model.set_kv_cache(batch_size=1)
    tokenlist = generate_TA(
        model,
        None,
        input_ids,
        None,
        ["T1A2"],
        max_returned_tokens=2048,
        temperature=0.9,
        top_k=1,
        eos_id_a=_eoa,
        eos_id_t=_eot,
        pad_id_t=_pad_t,
        shift=padded_text_vocabsize,
        include_prompt=True,
        generate_text=True,
    )

    audiolist = reconscruct_snac(tokenlist)
    tokenlist = tokenlist[-1]

    if text_vocabsize in tokenlist:
        # Truncate the text stream at the end-of-text token.
        tokenlist = tokenlist[: tokenlist.index(text_vocabsize)]
    audio = reconstruct_tensors(audiolist)
    if out_dir is None:
        out_dir = "./output/default/T1-A2"
    else:
        out_dir = out_dir + "/T1-A2"
    # exist_ok avoids the race between the old exists() check and makedirs().
    os.makedirs(out_dir, exist_ok=True)

    with torch.inference_mode():
        audio_hat = snacmodel.decode(audio)
    sf.write(
        f"{out_dir}/{step:02d}.wav",
        audio_hat.squeeze().cpu().numpy(),
        24000,
    )
    model.clear_kv_cache()
    return text_tokenizer.decode(torch.tensor(tokenlist)).strip()
331
+
332
+
333
def T1_T2(fabric, input_ids, model, text_tokenizer, step):
    """Text in, text out: answer a text question with text only."""
    with fabric.init_tensor():
        model.set_kv_cache(batch_size=1)
    token_ids = generate_TT(
        model,
        None,
        input_ids,
        None,
        ["T1T2"],
        max_returned_tokens=2048,
        temperature=0.9,
        top_k=1,
        eos_id_a=_eoa,
        eos_id_t=_eot,
        pad_id_t=_pad_t,
        shift=padded_text_vocabsize,
        include_prompt=True,
        generate_text=True,
    )
    model.clear_kv_cache()
    decoded = text_tokenizer.decode(torch.tensor(token_ids))
    return decoded.strip()
355
+
356
+
357
def load_model(ckpt_dir, device):
    """Load all inference components from a checkpoint directory.

    Returns (fabric, gpt_model, text_tokenizer, snac_codec, whisper_encoder).
    """
    # SNAC neural audio codec used to decode generated audio tokens.
    snacmodel = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval().to(device)
    whisper_model_path = ckpt_dir + "/small.pt"
    if not os.path.exists(whisper_model_path):
        # Fall back to whisper's named "small" model (downloaded on demand).
        whisper_model_path = "small"
    whispermodel = whisper.load_model(whisper_model_path).to(device)
    text_tokenizer = Tokenizer(ckpt_dir)
    fabric = L.Fabric(devices=1, strategy="auto")
    config = Config.from_file(ckpt_dir + "/model_config.yaml")
    config.post_adapter = False  # inference path does not use the post adapter

    with fabric.init_module(empty_init=False):
        model = GPT(config)

    model = fabric.setup(model)
    state_dict = lazy_load(ckpt_dir + "/lit_model.pth")
    model.load_state_dict(state_dict, strict=True)
    model.to(device).eval()

    return fabric, model, text_tokenizer, snacmodel, whispermodel
377
+
378
+
379
def download_model(ckpt_dir):
    """Fetch the mini-omni2 checkpoint snapshot from the HF Hub into ckpt_dir."""
    snapshot_download("gpt-omni/mini-omni2", local_dir=ckpt_dir, revision="main")
382
+
383
+
384
def get_text_stream(list_output, index, text_tokenizer):
    """Decode the text tokens accumulated since ``index``.

    Returns (decoded_text, new_index, is_text_end); is_text_end is True once
    the end-of-text token has been seen.
    """
    fresh = list_output[-1][index:]
    index += len(fresh)
    ended = False
    if text_vocabsize in fresh:
        fresh = fresh[:fresh.index(text_vocabsize)]
        ended = True
    if not fresh:
        return "", index, ended
    return text_tokenizer.decode(torch.tensor(fresh)), index, ended
395
+
396
+
397
class OmniInference:
    """End-to-end streaming inference wrapper: audio question in, interleaved
    (audio chunk, text chunk) stream out."""

    def __init__(self, ckpt_dir='./checkpoint', device='cuda:0'):
        # Download the checkpoint on first use, then load all components.
        self.device = device
        if not os.path.exists(ckpt_dir):
            print(f"checkpoint directory {ckpt_dir} not found, downloading from huggingface")
            download_model(ckpt_dir)
        self.fabric, self.model, self.text_tokenizer, self.snacmodel, self.whispermodel = load_model(ckpt_dir, device)

    def warm_up(self, sample='./data/samples/output1.wav'):
        # Run one full generation to trigger lazy initialization / kernels.
        for _ in self.run_AT_batch_stream(sample):
            pass

    @torch.inference_mode()
    def run_AT_batch_stream(self,
                            audio_path,
                            stream_stride=4,
                            max_returned_tokens=2048,
                            temperature=0.9,
                            top_k=1,
                            top_p=1.0,
                            eos_id_a=_eoa,
                            eos_id_t=_eot,
                            save_path=None,
                            sample_rate=24000,
                            ):
        """Stream an answer to a spoken question.

        Yields (audio_bytes, text_chunk) pairs every ``stream_stride`` decoded
        steps; optionally writes the full waveform to ``save_path`` at the end.
        """
        assert os.path.exists(audio_path), f"audio file {audio_path} not found"
        model = self.model

        with self.fabric.init_tensor():
            model.set_kv_cache(batch_size=2, device=self.device)

        mel, leng = load_audio(audio_path)
        audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, self.whispermodel, self.device)
        T = input_ids[0].size(1)  # prompt length (audio frames + specials)
        device = input_ids[0].device

        assert max_returned_tokens > T, f"max_returned_tokens {max_returned_tokens} should be greater than audio length {T}"

        if model.max_seq_length < max_returned_tokens - 1:
            raise NotImplementedError(
                f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
            )

        # Prefill: run the whole prompt once to get the first token per layer.
        input_pos = torch.tensor([T], device=device)
        list_output = [[] for i in range(8)]  # per-layer generated token ids
        tokens_A, token_T = next_token_image_batch(
            model,
            audio_feature.to(torch.float32).to(model.device),
            None,
            input_ids,
            [T - 3, T - 3],
            ["A1T2", "A1T2"],
            input_pos=torch.arange(0, T, device=device),
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )

        for i in range(7):
            list_output[i].append(tokens_A[i].tolist()[0])
        list_output[7].append(token_T.tolist()[0])

        # Re-embed the sampled tokens as next-step inputs.  Audio tokens are
        # shifted into the combined vocab; 4097 is presumably the per-layer
        # pad/placeholder fed to the second batch row — TODO confirm.
        model_input_ids = [[] for i in range(8)]
        for i in range(7):
            tokens_A[i] = tokens_A[i].clone() + padded_text_vocabsize + i * padded_audio_vocabsize
            model_input_ids[i].append(tokens_A[i].clone().to(device).to(torch.int32))
            model_input_ids[i].append(torch.tensor([layershift(4097, i)], device=device))
            model_input_ids[i] = torch.stack(model_input_ids[i])

        model_input_ids[-1].append(token_T.clone().to(torch.int32))
        model_input_ids[-1].append(token_T.clone().to(torch.int32))
        model_input_ids[-1] = torch.stack(model_input_ids[-1])

        text_end = False          # set once the text stream hits eos
        index = 1                 # number of decode steps completed
        nums_generate = stream_stride
        begin_generate = False    # audio layers have a 7-step warm-up delay
        current_index = 0

        text_index = 0            # cursor into the decoded text stream
        is_text_end = False

        # Autoregressive decode loop: one combined (audio+text) step per turn.
        for _ in tqdm(range(2, max_returned_tokens - T + 1)):
            tokens_A, token_T = next_token_image_batch(
                model,
                None,
                None,
                model_input_ids,
                None,
                None,
                input_pos=input_pos,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
            )

            if text_end:
                # After text eos, keep padding the text layer.
                token_T = torch.tensor([_pad_t], device=device)

            if tokens_A[-1] == eos_id_a:
                break  # audio stream finished — stop generating entirely

            if token_T == eos_id_t:
                text_end = True

            for i in range(7):
                list_output[i].append(tokens_A[i].tolist()[0])
            list_output[7].append(token_T.tolist()[0])

            # Prepare next-step inputs exactly as after the prefill.
            model_input_ids = [[] for i in range(8)]
            for i in range(7):
                tokens_A[i] = tokens_A[i].clone() + padded_text_vocabsize + i * padded_audio_vocabsize
                model_input_ids[i].append(tokens_A[i].clone().to(device).to(torch.int32))
                model_input_ids[i].append(
                    torch.tensor([layershift(4097, i)], device=device)
                )
                model_input_ids[i] = torch.stack(model_input_ids[i])

            model_input_ids[-1].append(token_T.clone().to(torch.int32))
            model_input_ids[-1].append(token_T.clone().to(torch.int32))
            model_input_ids[-1] = torch.stack(model_input_ids[-1])

            if index == 7:
                # All 7 SNAC codebook layers now have at least one token.
                begin_generate = True

            if begin_generate:
                current_index += 1
                if current_index == nums_generate:
                    # Emit one audio chunk (and any new text) every
                    # nums_generate steps.
                    current_index = 0
                    snac = get_snac(list_output, index, nums_generate)
                    audio_stream = generate_audio_data(snac, self.snacmodel, self.device)
                    if is_text_end:
                        text_stream = ""
                    else:
                        text_stream, text_index, is_text_end = get_text_stream(list_output, text_index, self.text_tokenizer)

                    yield (audio_stream, text_stream)

            input_pos = input_pos.add_(1)
            index += 1
        text = self.text_tokenizer.decode(torch.tensor(list_output[-1]))
        print(f"text output: {text}")

        if save_path is not None:
            # Optionally decode the whole answer to a single wav file.
            audiolist = reconscruct_snac(list_output)
            audio = reconstruct_tensors(audiolist)
            with torch.inference_mode():
                audio_hat = self.snacmodel.decode(audio)
            sf.write(save_path, audio_hat.squeeze().cpu().numpy(), sample_rate)

        model.clear_kv_cache()
        return list_output
551
+
552
+
553
def test_infer():
    """Smoke-test every supported task (A1A2, ASR, T1A2, T1T2, AT, AA-BATCH)
    against the bundled sample audio/text prompts, writing generated audio
    under ./output/<timestamp>/."""
    device = "cuda:0"
    out_dir = f"./output/{get_time_str()}"
    ckpt_dir = "./checkpoint"  # plain string: the old f-string had no placeholder
    if not os.path.exists(ckpt_dir):
        print(f"checkpoint directory {ckpt_dir} not found, downloading from huggingface")
        download_model(ckpt_dir)

    fabric, model, text_tokenizer, snacmodel, whispermodel = load_model(ckpt_dir, device)

    task = ['A1A2', 'asr', "T1A2", "AA-BATCH", 'T1T2', 'AT']

    # prepare test data
    # TODO
    test_audio_list = sorted(glob.glob('./data/samples/output*.wav'))
    test_audio_transcripts = [
        "What is your name?",
        "what are your hobbies?",
        "Do you like beijing",
        "How are you feeling today?",
        "what is the weather like today?",
    ]
    test_text_list = [
        "What is your name?",
        "How are you feeling today?",
        "Can you describe your surroundings?",
        "What did you do yesterday?",
        "What is your favorite book and why?",
        "How do you make a cup of tea?",
        "What is the weather like today?",
        "Can you explain the concept of time?",
        "Can you tell me a joke?",
    ]

    # LOAD MODEL
    with torch.no_grad():
        if "A1A2" in task:
            print("===============================================================")
            print("                       testing A1A2")
            print("===============================================================")
            step = 0
            for path in test_audio_list:
                try:
                    mel, leng = load_audio(path)
                    audio_feature, input_ids = get_input_ids_whisper(mel, leng, whispermodel, device)
                    text = A1_A2(
                        fabric,
                        audio_feature,
                        input_ids,
                        leng,
                        model,
                        text_tokenizer,
                        step,
                        snacmodel,
                        out_dir=out_dir,
                    )
                    print(f"input: {test_audio_transcripts[step]}")
                    print(f"output: {text}")
                    step += 1
                    print(
                        "+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++"
                    )
                except Exception as e:
                    # Was a bare `except:` which also swallowed KeyboardInterrupt
                    # and hid the actual failure; report the error instead.
                    print(f"[error] failed to process {path}: {e}")
            print("===============================================================")

        if 'asr' in task:
            print("===============================================================")
            print("                       testing asr")
            print("===============================================================")

            index = 0
            for path in test_audio_list:
                mel, leng = load_audio(path)
                audio_feature, input_ids = get_input_ids_whisper(mel, leng, whispermodel, device, special_token_a=_pad_a, special_token_t=_asr)
                output = A1_T1(fabric, audio_feature, input_ids, leng, model, text_tokenizer, index).lower().replace(',', '').replace('.', '').replace('?', '')
                print(f"audio_path: {path}")
                print(f"audio transcript: {test_audio_transcripts[index]}")
                print(f"asr output: {output}")
                print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
                index += 1

        if "T1A2" in task:
            step = 0
            print("\n")
            print("===============================================================")
            print("                       testing T1A2")
            print("===============================================================")
            for text in test_text_list:
                input_ids = get_input_ids_TA(text, text_tokenizer)
                text_output = T1_A2(fabric, input_ids, model, text_tokenizer, step,
                                    snacmodel, out_dir=out_dir)
                print(f"input: {text}")
                print(f"output: {text_output}")
                print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
                step += 1
            print("===============================================================")

        if "T1T2" in task:
            step = 0
            print("\n")
            print("===============================================================")
            print("                       testing T1T2")
            print("===============================================================")

            for text in test_text_list:
                input_ids = get_input_ids_TT(text, text_tokenizer)
                text_output = T1_T2(fabric, input_ids, model, text_tokenizer, step)
                print(f" Input: {text}")
                print(f"Output: {text_output}")
                print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
            print("===============================================================")

        if "AT" in task:
            print("===============================================================")
            print("                       testing A1T2")
            print("===============================================================")
            step = 0
            for path in test_audio_list:
                mel, leng = load_audio(path)
                audio_feature, input_ids = get_input_ids_whisper(
                    mel, leng, whispermodel, device,
                    special_token_a=_pad_a, special_token_t=_answer_t
                )
                text = A1_T2(
                    fabric, audio_feature, input_ids, leng, model, text_tokenizer, step
                )
                print(f"input: {test_audio_transcripts[step]}")
                print(f"output: {text}")
                step += 1
                print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
            print("===============================================================")

        if "AA-BATCH" in task:
            print("===============================================================")
            print("                       testing A1A2-BATCH")
            print("===============================================================")
            step = 0
            for path in test_audio_list:
                mel, leng = load_audio(path)
                audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device)
                text = A1_A2_batch(
                    fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
                    snacmodel, out_dir=out_dir
                )
                print(f"input: {test_audio_transcripts[step]}")
                print(f"output: {text}")
                step += 1
                print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
            print("===============================================================")

    print("*********************** test end *****************************")
707
+
708
+
709
# Run the end-to-end smoke test when executed as a script.
if __name__ == "__main__":
    test_infer()
inference_vision.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from litgpt.generate.base import next_token_image_batch
4
+ import soundfile as sf
5
+ from utils.snac_utils import layershift, reconscruct_snac, reconstruct_tensors, get_time_str
6
+ from utils.snac_utils import get_snac, generate_audio_data
7
+ import clip
8
+ import inference
9
+ from tqdm import tqdm
10
+ from inference import OmniInference, load_model, load_audio, download_model
11
+ from inference import text_vocabsize, padded_text_vocabsize, get_text_stream
12
+ from PIL import Image
13
+
14
+
15
# Print tensors with plain decimal notation (easier to eyeball token ids).
torch.set_printoptions(sci_mode=False)


# Re-export the special token ids defined in `inference` so this module builds
# prompts with the same vocabulary layout as the audio-only pipeline.
_image = inference._image        # start-of-image marker
_eoimage = inference._eoimage    # end-of-image marker
_pad_t = inference._pad_t        # text padding token
_input_t = inference._input_t    # start-of-text-input marker
_answer_t = inference._answer_t  # start-of-text-answer marker
_eot = inference._eot            # end-of-text token
_eoa = inference._eoa            # end-of-audio token
_pad_a = inference._pad_a        # audio padding token
_input_a = inference._input_a    # start-of-audio-input marker
_answer_a = inference._answer_a  # start-of-audio-answer marker
28
+
29
+
30
def get_input_ids_ImageQA_ATBatch(mel, leng, whispermodel, device):
    """Build a two-sample batch of layered input ids for image+audio QA.

    Both batch items share the same prompt layout: 50 reserved image positions
    framed by `_image`/`_eoimage` markers, then the audio span framed by
    `_input_a`/`_eoa`.  They differ only in the final position of the seven
    audio layers: item 0 ends with `_answer_a` (it will produce the audio
    answer), item 1 ends with an extra `_pad_a` (it will produce the text
    answer).  Returns ``(audio_features, stacked_inputids)`` where
    ``audio_features`` duplicates the Whisper embedding for both items and
    ``stacked_inputids`` is a list of 8 tensors of shape (2, T), one per layer.
    """
    with torch.no_grad():
        # Whisper embedding of the mel spectrogram, truncated to the true length.
        mel = mel.unsqueeze(0).to(device)
        audio_feature = whispermodel.embed_audio(mel)[0][:leng]

    audio_len = audio_feature.size(0)

    input_ids = []

    # --- batch item 0: audio-answer sample -------------------------------
    input_ids_item = [[] for i in range(8)]
    for i in range(7):
        # Each of the 7 audio layers: image slot (52 tokens) + audio span.
        input_ids_item[i] = [layershift(_image,i)] + [layershift(_pad_a,i)] * 50 + [layershift(_eoimage,i)]
        input_ids_item[i] += [layershift(_input_a,i)]+[layershift(_pad_a,i)]*(audio_len)+[layershift(_eoa,i)]
        input_ids_item[i] += [layershift(_answer_a,i)]

    # Text layer: padding across the whole prompt, then the answer marker.
    input_ids_item[-1] = [_pad_t]* (52 + 2 + audio_len) + [_answer_t]
    input_ids_item = [torch.tensor(item) for item in input_ids_item]

    input_ids.append(input_ids_item)

    # --- batch item 1: text-answer sample (audio layers end in padding) ---
    input_ids_item = [[] for i in range(8)]
    for i in range(7):
        input_ids_item[i] = [layershift(_image,i)] + [layershift(_pad_a,i)] * 50 + [layershift(_eoimage,i)]
        input_ids_item[i] += [layershift(_input_a,i)]+[layershift(_pad_a,i)]*(audio_len)+[layershift(_eoa,i)] + [layershift(_pad_a,i)]

    input_ids_item[-1] = [_pad_t]* (52 + 2 + audio_len) + [_answer_t]

    input_ids_item = [torch.tensor(item) for item in input_ids_item]
    input_ids.append(input_ids_item)

    # Stack the two items layer-by-layer into 8 tensors of shape (2, T).
    stacked_inputids = [[] for _ in range(8)]
    for i in range(2):
        for j in range(8):
            stacked_inputids[j].append(input_ids[i][j])
    stacked_inputids = [torch.stack(tensors) for tensors in stacked_inputids]

    # Duplicate the audio features so each batch item has its own copy.
    return torch.stack([audio_feature,audio_feature]), stacked_inputids
68
+
69
+
70
+
71
+
72
def load_clip_model(ckpt_dir, device):
    """Load the CLIP ViT-B/32 model and its image preprocessor.

    Prefers the local checkpoint ``<ckpt_dir>/ViT-B-32.pt``; if that file is
    missing, falls back to the model name "ViT-B/32" so ``clip.load`` can
    download it.

    Args:
        ckpt_dir: Directory that may contain a pre-downloaded ``ViT-B-32.pt``.
        device: Device string/handle to load the model onto.

    Returns:
        Tuple ``(clipmodel, clippreprocess)`` as returned by ``clip.load``.
    """
    # os.path.join is platform-safe and tolerates a trailing slash in ckpt_dir,
    # unlike the previous string concatenation.
    clip_model_path = os.path.join(ckpt_dir, "ViT-B-32.pt")
    if not os.path.exists(clip_model_path):
        # Fall back to the published model name; clip.load will fetch it.
        clip_model_path = "ViT-B/32"
    clipmodel, clippreprocess = clip.load(clip_model_path, device=device)
    return clipmodel, clippreprocess
78
+
79
+
80
class OmniVisionInference(OmniInference):
    """Streaming audio+image question answering on top of OmniInference.

    Adds a CLIP image encoder to the base audio pipeline and runs a 2-sample
    batch where one batch row generates audio tokens and the other generates
    text tokens (see `next_token_image_batch`).
    """

    def __init__(self, ckpt_dir='./checkpoint', device='cuda:0'):
        # Downloads the checkpoint on first use, then loads all sub-models.
        self.device = device
        if not os.path.exists(ckpt_dir):
            print(f"checkpoint directory {ckpt_dir} not found, downloading from huggingface")
            download_model(ckpt_dir)
        self.fabric, self.model, self.text_tokenizer, self.snacmodel, self.whispermodel = load_model(ckpt_dir, device)
        self.clipmodel, self.clippreprocess = load_clip_model(ckpt_dir, device)

    def warm_up(self,
                audio_sample='./data/samples/vision_qa_audio.wav',
                image_sample='./data/samples/vision_qa_image.jpg'
                ):
        """Run one truncated generation pass to trigger lazy initialization."""
        for _ in self.run_vision_AA_batch_stream(audio_sample, image_sample,
                                                 save_path="./data/samples/vision_qa_output.wav",
                                                 warm_up=True):
            pass

    @torch.inference_mode()
    def run_vision_AA_batch_stream(self, audio_path, image_path,
                                   stream_stride=4,
                                   max_returned_tokens=2048,
                                   temperature=0.9,
                                   top_k=1,
                                   top_p=1.0,
                                   eos_id_a=_eoa,
                                   eos_id_t=_eot,
                                   pad_id=_pad_t,
                                   save_path=None,
                                   warm_up=False
                                   ):
        """Generate an answer to (audio question, image) as a stream.

        Yields ``(audio_stream, text_stream)`` tuples every `stream_stride`
        decoding steps, where `audio_stream` is raw audio bytes/data from the
        SNAC decoder and `text_stream` the newly decoded text fragment.  When
        `save_path` is set, the full answer waveform is also written there at
        24 kHz.  With `warm_up=True` the loop stops after the first yield.
        """
        with self.fabric.init_tensor():
            # Batch of 2: row 0 -> audio answer, row 1 -> text answer.
            self.model.set_kv_cache(batch_size=2)

        model = self.model

        mel, leng = load_audio(audio_path)
        img = Image.open(image_path)

        audio_feature, input_ids = get_input_ids_ImageQA_ATBatch(mel, leng, self.whispermodel, self.device)
        ima = self.clippreprocess(img).unsqueeze(0).to(self.device)
        ima_feature = self.clipmodel.encode_image(ima).squeeze(0).to(self.device)

        # Duplicate the image feature for both batch rows.
        ima_feature = torch.stack([ima_feature.clone(), ima_feature.clone()]).to(self.device)
        leng = [leng, leng]
        task = ['ImageQA_A', 'ImageQA_AT']

        T = input_ids[0].size(1)
        assert max_returned_tokens > T, f"max_returned_tokens {max_returned_tokens} should be greater than audio length {T}"

        if model.max_seq_length < max_returned_tokens - 1:
            raise NotImplementedError(
                f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
            )

        # One list per layer: 7 audio codebooks + 1 text stream.
        list_output = [[] for i in range(8)]

        # Prefill: process the whole prompt in one forward pass.
        tokens_A, token_T = next_token_image_batch(
            model,
            audio_feature.to(torch.float32).to(self.device),
            ima_feature.to(torch.float32).to(self.device),
            input_ids,
            whisper_lens=leng,
            task=task,
            input_pos=torch.arange(0, T, device=self.device),
            temperature=temperature,
            top_k=top_k,
            top_p=top_p
        )
        for i in range(7): list_output[i].append(tokens_A[i].tolist()[0])
        list_output[7].append(token_T.tolist()[0])

        text_end = False
        index = 1
        nums_generate = stream_stride
        begin_generate = False
        current_index = 0
        input_pos = torch.tensor([T], device=self.device)

        # Build the next-step input: audio tokens are shifted into their
        # per-layer vocab range (4160 ids per codebook layer); the second
        # position holds a per-layer end-of-audio marker (id 4097).
        # NOTE(review): 4160/4097 presumably match snac_config's padded vocab
        # size and end-of-audio id — confirm against utils.snac_utils.
        model_input_ids = [[] for i in range(8)]
        for i in range(7):
            tokens_A[i] = tokens_A[i].clone() + padded_text_vocabsize + i * 4160
            model_input_ids[i].append(tokens_A[i].clone().to(self.device).to(torch.int32))
            model_input_ids[i].append(torch.tensor([layershift(4097, i)], device=self.device))
            model_input_ids[i] = torch.stack(model_input_ids[i])

        # Text layer feeds the same token to both batch rows.
        model_input_ids[-1].append(token_T.clone().to(torch.int32))
        model_input_ids[-1].append(token_T.clone().to(torch.int32))
        model_input_ids[-1] = torch.stack(model_input_ids[-1])

        text_index = 0
        is_text_end = False

        for _ in tqdm(range(2, max_returned_tokens - T + 1)):

            tokens_A, token_T = next_token_image_batch(model, None, None,
                                                       input_ids=model_input_ids,
                                                       whisper_lens=None,
                                                       task=None,
                                                       input_pos=input_pos,
                                                       temperature=temperature,
                                                       top_k=top_k,
                                                       top_p=top_p)

            # Once the text stream has ended, keep feeding text padding.
            if text_end:
                token_T = torch.tensor([_pad_t], device=self.device)

            if tokens_A[-1] == eos_id_a:
                break
            if token_T == eos_id_t:
                text_end = True

            for i in range(7): list_output[i].append(tokens_A[i].tolist()[0])
            list_output[7].append(token_T.tolist()[0])

            # The 7 SNAC layers are delayed against each other; only after 7
            # steps is a full frame available to decode.
            if index == 7:
                begin_generate = True

            if begin_generate:
                current_index += 1
                if current_index == nums_generate:
                    # Flush a chunk of `nums_generate` frames as audio + text.
                    current_index = 0
                    snac = get_snac(list_output, index, nums_generate)
                    audio_stream = generate_audio_data(snac, self.snacmodel, self.device)
                    if is_text_end:
                        text_stream = ""
                    else:
                        text_stream, text_index, is_text_end = get_text_stream(list_output, text_index, self.text_tokenizer)

                    yield (audio_stream, text_stream)

                    if warm_up:
                        break

            # Prepare inputs for the next single-token step.
            input_pos = input_pos.add_(1)
            model_input_ids = [[] for i in range(8)]
            for i in range(7):
                tokens_A[i] = tokens_A[i].clone() + padded_text_vocabsize + i * 4160
                model_input_ids[i].append(tokens_A[i].clone().to(self.device).to(torch.int32))
                model_input_ids[i].append(torch.tensor([layershift(4097, i)], device=self.device))
                model_input_ids[i] = torch.stack(model_input_ids[i])

            model_input_ids[-1].append(token_T.clone().to(torch.int32))
            model_input_ids[-1].append(token_T.clone().to(torch.int32))
            model_input_ids[-1] = torch.stack(model_input_ids[-1])

            index += 1

        # Decode the full text answer (truncated at the first EOS-text id).
        text_tokens = list_output[-1]
        if text_vocabsize in text_tokens:
            text_tokens = text_tokens[:text_tokens.index(text_vocabsize)]
        res_text = self.text_tokenizer.decode(torch.tensor(text_tokens))
        print(f"text output: {res_text}")

        if save_path is not None:
            # Re-decode the whole audio token sequence into one waveform.
            audiolist = reconscruct_snac(list_output)
            audio = reconstruct_tensors(audiolist)
            with torch.inference_mode():
                audio_hat = self.snacmodel.decode(audio)
            sf.write(save_path, audio_hat.squeeze().cpu().numpy(), 24000)

        model.clear_kv_cache()
244
+
245
+
246
def test_vision_infer():
    """Run one audio+image QA example end to end and print the decoded text."""
    client = OmniVisionInference()
    client.warm_up()

    audio_in = './data/samples/vision_qa_audio.wav'
    image_in = './data/samples/vision_qa_image.jpg'

    # Collect the streamed text fragments; audio chunks are written to disk
    # by the generator via save_path, so they are ignored here.
    fragments = []
    for _audio_chunk, text_chunk in client.run_vision_AA_batch_stream(
        audio_in,
        image_in,
        save_path="./vision_qa_output.wav"
    ):
        fragments.append(text_chunk)

    res_text = "".join(fragments)
    print(f"text_output: {res_text}")


if __name__ == "__main__":
    test_vision_infer()
litgpt/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

import logging
import re
from litgpt.model import GPT  # needs to be imported before config
from litgpt.config import Config
from litgpt.tokenizer import Tokenizer

# Suppress excessive warnings, see https://github.com/pytorch/pytorch/issues/111632
pattern = re.compile(".*Profiler function .* will be ignored")
logging.getLogger("torch._dynamo.variables.torch").addFilter(
    lambda record: not pattern.search(record.getMessage())
)

# Avoid printing state-dict profiling output at the WARNING level when saving a checkpoint
logging.getLogger("torch.distributed.fsdp._optim_utils").disabled = True
logging.getLogger("torch.distributed.fsdp._debug_utils").disabled = True

# Public package API.
__all__ = ["GPT", "Config", "Tokenizer"]
litgpt/config.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2
+
3
+ from copy import deepcopy
4
+ from dataclasses import dataclass, field
5
+ from pathlib import Path
6
+ from typing import Any, Literal, Optional, Type, Union
7
+
8
+ import torch
9
+ import yaml
10
+ from typing_extensions import Self
11
+
12
+ import litgpt.model
13
+ from litgpt.utils import find_multiple
14
+
15
+
16
@dataclass
class Config:
    """Model hyper-parameter configuration for the litgpt GPT model,
    extended with mini-omni's audio/vision vocabulary and adapter settings.

    Derived fields (`padded_vocab_size`, `head_size`, `n_query_groups`,
    `intermediate_size`, `rope_n_elem`, `add_qkv_bias`) are filled in by
    `__post_init__` when left unset.
    """

    name: str = ""
    hf_config: dict = field(default_factory=dict)
    scale_embeddings: bool = False
    block_size: int = 4096            # maximum sequence length
    vocab_size: int = 50254
    padding_multiple: int = 512       # pad vocab up to a multiple of this
    padded_vocab_size: Optional[int] = None
    n_layer: int = 16
    n_head: int = 32
    head_size: Optional[int] = None   # derived: n_embd // n_head if None
    n_embd: int = 4096
    rotary_percentage: float = 0.25   # fraction of head dims that get RoPE
    parallel_residual: bool = True
    bias: bool = True
    lm_head_bias: bool = False
    # to use multi-head attention (MHA), set this to `n_head` (default)
    # to use multi-query attention (MQA), set this to 1
    # to use grouped-query attention (GQA), set this to a value in between
    # Example with `n_head=4`
    # ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
    # │ v ││ v ││ v ││ v │     │ v │    │ v │             │ v │
    # └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
    #   │    │    │    │         │        │                 │
    # ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
    # │ k ││ k ││ k ││ k │     │ k │    │ k │             │ k │
    # └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
    #   │    │    │    │      ┌──┴──┐  ┌──┴──┐      ┌────┬──┴─┬────┐
    # ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐
    # │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │
    # └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘
    # ◀──────────────────▶ ◀──────────────────▶ ◀──────────────────▶
    #         MHA                  GQA                   MQA
    #   n_query_groups=4     n_query_groups=2     n_query_groups=1
    #
    # credit https://arxiv.org/pdf/2305.13245.pdf
    n_query_groups: Optional[int] = None
    shared_attention_norm: bool = False
    norm_class_name: Literal["LayerNorm", "RMSNorm"] = "LayerNorm"
    norm_eps: float = 1e-5
    mlp_class_name: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE"] = (
        "GptNeoxMLP"
    )
    gelu_approximate: str = "none"
    intermediate_size: Optional[int] = None  # derived: 4 * n_embd if None
    rope_condense_ratio: int = 1
    rope_base: int = 10000
    n_expert: int = 0                 # MoE: number of experts (0 = dense)
    n_expert_per_token: int = 0

    add_qkv_bias: Optional[bool] = None  # derived: falls back to `bias`
    prompt_vocab_size: Optional[int] = None
    attn_dropout: float = 0.0
    pos_type: str = "rope"
    force_align: bool = False
    use_pretrain_phoneme_emb: bool = False
    tie_word_embeddings: bool = False

    # setting for mini-omni
    text_vocab_size: int = 152000
    cat_audio_vocab_size: int = 29120
    audio_vocab_size: int = 4160
    whisper_adapter_dim: int = 768   # matches the Whisper encoder feature dim
    vision_adapter_dim: int = 512    # matches the CLIP ViT-B/32 feature dim

    post_adapter: bool = False
    post_adapter_layers: int = 6
    asr_adapter: str = "llamamlp"

    def __post_init__(self):
        # Fall back to the HF name when no explicit name was given.
        if not self.name:
            self.name = self.hf_config.get("name", self.name)

        if self.head_size is None:
            assert self.n_embd % self.n_head == 0
            self.head_size = self.n_embd // self.n_head

        # vocab size should be a power of 2 to be optimal on hardware. compute the closest value
        if self.padded_vocab_size is None:
            self.padded_vocab_size = find_multiple(
                self.vocab_size, self.padding_multiple
            )
        else:
            # vocab size shouldn't be larger than padded vocab size
            self.vocab_size = min(self.vocab_size, self.padded_vocab_size)

        # compute the number of query groups
        if self.n_query_groups is not None:
            assert self.n_head % self.n_query_groups == 0
        else:
            self.n_query_groups = self.n_head

        # compute the intermediate size for MLP if not set
        if self.intermediate_size is None:
            if self.mlp_class_name == "LLaMAMLP":
                raise ValueError(
                    f"The config {self.name!r}, needs to set the `intermediate_size`"
                )
            self.intermediate_size = 4 * self.n_embd

        # number of rotary dimensions per attention head
        self.rope_n_elem = int(self.rotary_percentage * self.head_size)

        if self.add_qkv_bias is None:
            self.add_qkv_bias = self.bias

    @classmethod
    def from_name(cls, name: str, **kwargs: Any) -> Optional[Self]:
        """Look up a named config (by litgpt name or HF org/name) and apply overrides."""
        if name not in name_to_config:
            # search through all `config['hf_config']['name']`
            try:
                conf_dict = next(
                    config
                    for config in configs
                    if name == config["hf_config"]["name"]
                    or config["hf_config"]["org"] + "/" + config["hf_config"]["name"]
                    == name
                )
            except StopIteration:
                raise ValueError(f"{name!r} is not a supported config name")
        else:
            conf_dict = name_to_config[name]

        conf_dict = conf_dict.copy()
        conf_dict.update(kwargs)
        return cls(**conf_dict)

    @classmethod
    def from_file(cls, path: Union[str, Path], **kwargs: Any) -> Self:
        """Load a config from a YAML file, with keyword overrides."""
        with open(path, encoding="utf-8") as fp:
            file_kwargs = yaml.safe_load(fp)
            if file_kwargs is None:
                raise ValueError(f"{path} is empty which is likely unexpected.")
        file_kwargs.update(kwargs)
        return cls(**file_kwargs)

    @classmethod
    def from_checkpoint(cls, path: Path, **kwargs: Any) -> Self:
        """Automatically load `model_config.yaml` and if it doesn't exist - a matching config from `litgpt/config.py`."""
        if (config_path := path / "model_config.yaml").is_file():
            return cls.from_file(config_path, **kwargs)
        if (model_name := path.name) in name_to_config:
            return cls.from_name(model_name, **kwargs)
        raise FileNotFoundError(
            f"For {str(path)!r} neither 'model_config.yaml' nor matching config exists."
        )

    @property
    def mlp_class(self) -> Type:
        # `self.mlp_class_name` cannot be the type to keep the config serializable
        return getattr(litgpt.model, self.mlp_class_name)

    @property
    def norm_class(self) -> Type:
        # `self.norm_class_name` cannot be the type to keep the config serializable
        if self.norm_class_name == "RMSNorm":
            from functools import partial

            from litgpt.model import RMSNorm

            return partial(RMSNorm, add_unit_offset="Gemma" in self.name)
        return getattr(torch.nn, self.norm_class_name)
178
+
179
+
180
# Registry of built-in model configs. Empty in this fork: configs are expected
# to come from a checkpoint's `model_config.yaml` (see Config.from_checkpoint).
configs = []
name_to_config = {config["name"]: config for config in configs}
litgpt/generate/__init__.py ADDED
File without changes
litgpt/generate/base.py ADDED
@@ -0,0 +1,795 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2
+
3
+ from typing import Any, Literal, Optional
4
+
5
+ import torch
6
+ # import torch._dynamo.config
7
+ # import torch._inductor.config
8
+
9
+ from litgpt.model import GPT
10
+ from utils.snac_utils import layershift, snac_config
11
+ from tqdm import tqdm
12
+
13
+
14
+ def multinomial_num_samples_1(probs: torch.Tensor) -> torch.Tensor:
15
+ if torch._dynamo.is_compiling():
16
+ # Faster alternative to `torch.multinomial(probs, num_samples=1)` that is also CUDAGraph friendly
17
+ distribution = torch.empty_like(probs).exponential_(1)
18
+ return torch.argmax(probs / distribution, dim=-1, keepdim=True)
19
+ return torch.multinomial(probs, num_samples=1)
20
+
21
+
22
def sample_top_p(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    """Nucleus (top-p) filtering: set every logit outside the smallest set of
    tokens with cumulative probability > `top_p` to -inf and return the result.
    """
    # Sort ascending so low-probability tokens accumulate first.
    asc_logits, asc_indices = torch.sort(logits, descending=False)
    cum_probs = asc_logits.softmax(dim=-1).cumsum(dim=-1)
    # Example:
    # sorted_probs=[0.1, 0.15, 0.2, 0.25, 0.3] -> cumprobs=[0.1, 0.25, 0.45, 0.7, 1.0]
    # with top_p=0.7 the first two entries fall below 1 - top_p and are dropped.
    drop_sorted = cum_probs <= (1 - top_p)
    # Always keep at least one token: the most probable one (last in ascending
    # order) is never removed.
    drop_sorted[-1:] = 0
    # Map the drop mask back from sorted order to vocabulary order.
    drop_mask = drop_sorted.scatter(0, asc_indices, drop_sorted)
    return logits.masked_fill(drop_mask, float("-inf"))
37
+
38
+
39
def sample(
    logits: torch.Tensor,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: float = 1.0,
) -> torch.Tensor:
    """Pick the next token id from the last position of `logits` (1, T, vocab).

    Applies top-k truncation, then temperature scaling, then top-p filtering,
    and samples from the resulting distribution; when both `temperature` and
    `top_p` are zero, decoding is greedy argmax. Returns a 1-element tensor.
    """
    if top_p < 0.0 or top_p > 1.0:
        raise ValueError(f"top_p must be in [0, 1], got {top_p}")
    last = logits[0, -1]
    if top_k is not None:
        # Keep only the k largest logits; everything else becomes -inf.
        kept, kept_idx = torch.topk(last, min(top_k, last.size(-1)))
        # torch.where (as in nanogpt) would repeat top-k collisions, hence scatter_.
        last = torch.full_like(last, float("-inf")).scatter_(-1, kept_idx, kept)
    if temperature <= 0.0 and top_p <= 0.0:
        # Fully greedy decoding.
        return torch.argmax(last, dim=-1, keepdim=True)
    if temperature > 0.0:
        last = last / temperature
    if top_p < 1.0:
        # Restrict to the smallest set of logits with cumulative mass > top_p.
        last = sample_top_p(last, top_p)
    probs = torch.nn.functional.softmax(last, dim=-1)
    return multinomial_num_samples_1(probs)
63
+
64
+
65
def next_token(
    model: GPT, input_pos: torch.Tensor, x: list, **kwargs: Any
) -> torch.Tensor:
    """One decoding step: feed the 8 layered token tensors `x` at `input_pos`
    and sample 7 audio tokens plus 1 text token from the model's heads.
    Extra kwargs (temperature/top_k/top_p) are forwarded to `sample`.
    """
    input_pos = input_pos.to(model.device)
    logits_a, logit_t = model(None, x, None, input_pos)

    dtype = x[0].dtype
    next_audio_tokens = [sample(la, **kwargs).to(dtype=dtype) for la in logits_a]
    next_t = sample(logit_t, **kwargs).to(dtype=dtype)
    return next_audio_tokens, next_t
77
+
78
+
79
def next_token_asr(
    model: GPT,
    input_pos: torch.Tensor,
    audio_features: torch.tensor,
    lens: int,
    input_ids: list,
    **kwargs: Any,
) -> torch.Tensor:
    """One ASR decoding step with Whisper features: returns (7 audio tokens,
    1 text token), each sampled via `sample` with the forwarded kwargs.
    """
    input_pos = input_pos.to(model.device)
    input_ids = [tok.to(model.device) for tok in input_ids]
    logits_a, logit_t = model(audio_features, input_ids, None, input_pos, whisper_lens=lens)

    dtype = input_ids[0].dtype
    next_audio_tokens = [sample(la, **kwargs).to(dtype=dtype) for la in logits_a]
    next_t = sample(logit_t, **kwargs).to(dtype=dtype)
    return next_audio_tokens, next_t
97
+
98
+
99
def next_token_A1T2(
    model: GPT,
    audio_features: torch.tensor,
    input_ids: list,
    whisper_lens: int,
    task: list,
    input_pos: torch.Tensor,
    **kwargs: Any,
) -> torch.Tensor:
    """One A1T2 decoding step (audio in, audio+text out): samples and returns
    7 audio tokens and 1 text token from the task-conditioned forward pass.
    """
    input_pos = input_pos.to(model.device)
    input_ids = [tok.to(model.device) for tok in input_ids]
    logits_a, logit_t = model(
        audio_features, input_ids, None, input_pos, whisper_lens=whisper_lens, task=task
    )

    dtype = input_ids[0].dtype
    next_audio_tokens = [sample(la, **kwargs).to(dtype=dtype) for la in logits_a]
    next_t = sample(logit_t, **kwargs).to(dtype=dtype)
    return next_audio_tokens, next_t
120
+
121
+
122
def next_token_A1T1(
    model: GPT,
    audio_features: torch.tensor,
    input_ids: list,
    whisper_lens: int,
    task: list,
    input_pos: torch.Tensor,
    **kwargs: Any,
) -> torch.Tensor:
    """One A1T1 decoding step (audio in, text out): samples and returns only
    the next text token; the audio logits are ignored.
    """
    input_pos = input_pos.to(model.device)
    input_ids = [tok.to(model.device) for tok in input_ids]
    logits_a, logit_t = model(
        audio_features, input_ids, None, input_pos, whisper_lens=whisper_lens, task=task
    )
    return sample(logit_t, **kwargs).to(dtype=input_ids[0].dtype)
138
+
139
+
140
def next_token_image_batch(model: GPT,
                           audio_features: torch.tensor,
                           clip_features: torch.tensor,
                           input_ids: list,
                           whisper_lens: int,
                           task: list,
                           input_pos: torch.Tensor,
                           **kwargs: Any) -> torch.Tensor:
    """One decoding step for the two-sample image-QA batch.

    Audio logits are taken from batch row 0 and text logits from batch row 1,
    so each modality is generated by its dedicated batch entry. Returns
    (7 audio tokens, 1 text token).
    """
    input_pos = input_pos.to(model.device)
    input_ids = [tok.to(model.device) for tok in input_ids]
    logits_a, logit_t = model(audio_features, input_ids, clip_features,
                              input_pos, whisper_lens=whisper_lens, task=task)

    # Row 0 drives the audio streams, row 1 drives the text stream.
    logits_a = [la[0].unsqueeze(0) for la in logits_a]
    logit_t = logit_t[1].unsqueeze(0)

    dtype = input_ids[0].dtype
    next_audio_tokens = [sample(la, **kwargs).to(dtype=dtype) for la in logits_a]
    next_t = sample(logit_t, **kwargs).to(dtype=dtype)
    return next_audio_tokens, next_t
163
+
164
+
165
+ # torch._dynamo.config.automatic_dynamic_shapes = True
166
+ # torch._inductor.config.triton.unique_kernel_names = True
167
+ # torch._inductor.config.coordinate_descent_tuning = True
168
+ # next_token = torch.compile(next_token, mode="reduce-overhead")
169
+
170
+
171
@torch.inference_mode()
def generate(
    model: GPT,
    input_ids: list,
    max_returned_tokens: int,
    *,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: float = 1.0,
    eos_id_a: Optional[int] = None,
    eos_id_t: Optional[int] = None,
    pad_id: Optional[int] = None,
    shift: Optional[int] = None,
    include_prompt: bool = True,
    generate_text=False,
) -> torch.Tensor:
    """
    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
    The implementation of this function is modified from A. Karpathy's nanoGPT.

    Args:
        model: The model to use.
        input_ids: 8 prompt tensors (7 audio layers + 1 text layer), each of shape (T,).
        max_returned_tokens: The maximum number of tokens to return (given plus generated).
            NOTE(review): currently overridden to 1000 inside the loop setup below.
        temperature: Scales the predicted logits by 1 / temperature.
        top_k: If specified, only sample among the tokens with the k highest probabilities.
        top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process.
            In top-p sampling, the next token is sampled from the highest probability tokens
            whose cumulative probability exceeds the threshold `top_p`. When specified,
            it must be `0 <= top_p <= 1`. Here, `top_p=0` is equivalent
            to sampling the most probable token, while `top_p=1` samples from the whole distribution.
            It can be used in conjunction with `top_k` and `temperature` with the following order
            of application:

            1. `top_k` sampling
            2. `temperature` scaling
            3. `top_p` sampling

            For more details, see https://arxiv.org/abs/1904.09751
            or https://huyenchip.com/2024/01/16/sampling.html#top_p
        eos_id_a: Stop once this end-of-audio token is generated on the last audio layer.
        eos_id_t: Once this end-of-text token appears, pad (or stop, if `generate_text`) the text stream.
        pad_id: Text padding token fed after the text stream has ended.
        shift: Offset added to sampled audio tokens to map them into the model's layered audio vocab.
        include_prompt: If true (default) prepends the prompt (after applying the prompt style) to the output.
            NOTE(review): this flag is never read in the body below.
        generate_text: If true, stop generation entirely at end-of-text instead of continuing audio.

    Returns:
        A list of 8 concatenated tensors (prompt + generated), one per layer.
    """
    T = input_ids[0].size(0)
    device = input_ids[0].device
    assert max_returned_tokens > T
    if model.max_seq_length < max_returned_tokens - 1:
        # rolling the kv cache based on the `input_pos` value would be necessary. However, doing so would introduce a
        # data dependency on the `input_pos` tensor and impact model compilation. Since this setting is uncommon, we do
        # not support it to avoid negatively impacting the overall speed
        raise NotImplementedError(
            f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
        )

    # NOTE(review): this loop only rebinds the loop variable, so it has no
    # effect — it looks like leftover code.
    for input_id in input_ids:
        input_id = [input_id]
    (
        tokens_A1,
        tokens_A2,
        tokens_A3,
        tokens_A4,
        tokens_A5,
        tokens_A6,
        tokens_A7,
        tokens_T,
    ) = input_ids

    # Per-layer output accumulators, seeded with the prompt tokens.
    tokens_A1_output = [tokens_A1]
    tokens_A2_output = [tokens_A2]
    tokens_A3_output = [tokens_A3]
    tokens_A4_output = [tokens_A4]
    tokens_A5_output = [tokens_A5]
    tokens_A6_output = [tokens_A6]
    tokens_A7_output = [tokens_A7]
    tokens_T_output = [tokens_T]

    list_output = [
        tokens_A1_output,
        tokens_A2_output,
        tokens_A3_output,
        tokens_A4_output,
        tokens_A5_output,
        tokens_A6_output,
        tokens_A7_output,
        tokens_T_output,
    ]

    input_pos = torch.tensor([T], device=device)
    model_input_ids = [
        tokens_A1.view(1, -1),
        tokens_A2.view(1, -1),
        tokens_A3.view(1, -1),
        tokens_A4.view(1, -1),
        tokens_A5.view(1, -1),
        tokens_A6.view(1, -1),
        tokens_A7.view(1, -1),
        tokens_T.view(1, -1),
    ]

    # Prefill: run the whole prompt through the model in one pass.
    tokens_A, token_T = next_token(
        model,
        torch.arange(0, T, device=device),
        model_input_ids,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
    )
    for i in range(7):
        list_output[i].append(tokens_A[i].clone())
    list_output[7].append(token_T.clone())

    # prepare the input for the next iteration: shift each audio token into
    # its layer's slice of the audio vocabulary
    for i in range(7):
        tokens_A[i] = tokens_A[i].clone() + shift + i * snac_config.padded_vocab_size
    token_T = token_T.clone()

    text_end = False
    # NOTE(review): hard-coded override of the caller-supplied
    # `max_returned_tokens` — looks like leftover debug code; the parameter
    # above is effectively ignored from this point on. Confirm before relying
    # on values other than 1000.
    max_returned_tokens = 1000
    for _ in tqdm(range(2, max_returned_tokens - T + 1)):
        model_input_ids = [
            token_a.view(1, -1).to(torch.int32) for token_a in tokens_A
        ] + [token_T.view(1, -1).to(torch.int32)]
        tokens_A, token_T = next_token(
            model,
            input_pos,
            model_input_ids,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )
        # After the text stream has finished, keep emitting text padding.
        if text_end:
            token_T = torch.tensor([pad_id], device=device)

        for i in range(7):
            list_output[i].append(tokens_A[i].clone())
        list_output[7].append(token_T.clone())

        # End of audio on the last layer terminates generation entirely.
        if tokens_A[-1] == eos_id_a:
            break
        if token_T == eos_id_t:
            if generate_text:
                break
            text_end = True

        # Shift the sampled audio tokens for the next model input.
        for i in range(7):
            tokens_A[i] = tokens_A[i].clone() + shift + i * snac_config.padded_vocab_size
        token_T = token_T.clone()
        input_pos = input_pos.add_(1)

    # Concatenate prompt + generated tokens per layer.
    for i in range(len(list_output)):
        list_output[i] = torch.cat(list_output[i])
    return list_output
326
+
327
+
328
@torch.inference_mode()
def generate_TA_BATCH(
    model: GPT,
    audio_features: torch.Tensor,
    input_ids: list,
    leng,
    task,
    max_returned_tokens: int = 1000,
    *,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: float = 1.0,
    eos_id_a: Optional[int] = None,
    eos_id_t: Optional[int] = None,
    pad_id_t: Optional[int] = None,
    shift: Optional[int] = None,
    include_prompt: bool = True,
    generate_text=False,
) -> torch.Tensor:
    """Batched (batch of 2) audio+text decoding for the "A1T2" task.

    Decodes 7 SNAC audio codebook streams and 1 text stream in lockstep from a
    single model, using the KV cache after the prefill step.

    Args:
        model: GPT whose KV cache has already been set up via ``set_kv_cache``.
        audio_features: Whisper encoder features for the audio prompt.
        input_ids: 8 prompt tensors (7 audio layers + 1 text layer), each (B, T).
        leng, task: unused here; kept for a uniform generate_* signature.
        max_returned_tokens: hard cap on prompt + generated length.
        temperature / top_k / top_p: sampling hyper-parameters.
        eos_id_a / eos_id_t: end-of-sequence ids for audio / text streams.
        pad_id_t: text pad id emitted after the text stream has finished.
        shift: base offset added to raw audio tokens to map them into the
            model's combined vocabulary (one ``snac_config.padded_vocab_size``
            stride per codebook layer).

    Returns:
        list of 8 lists: sampled raw token ids for the 7 audio layers and the
        text layer (un-shifted values as returned by the sampler).
    """

    T = input_ids[0].size(1)
    device = input_ids[0].device
    assert max_returned_tokens > T
    if model.max_seq_length < max_returned_tokens - 1:
        raise NotImplementedError(
            f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
        )

    input_pos = torch.tensor([T], device=device)
    model_input_ids = input_ids

    list_output = [[] for i in range(8)]

    # Prefill: run the whole prompt through the model once and sample the
    # first audio/text tokens. NOTE(review): audio length [T - 3] and the
    # duplicated "A1T2" task entries assume a fixed 2-sample batch — confirm.
    tokens_A, token_T = next_token_image_batch(
        model,
        audio_features.to(torch.float32).to(model.device),
        None,
        input_ids,
        [T - 3, T - 3],
        ["A1T2", "A1T2"],
        input_pos=torch.arange(0, T, device=device),
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
    )

    for i in range(7):
        list_output[i].append(tokens_A[i].tolist()[0])
    list_output[7].append(token_T.tolist()[0])

    # Re-encode the sampled tokens as the next model input: shift each audio
    # token into its per-layer vocabulary range, and pair it with the
    # layer-shifted end_of_audio token for the second batch element.
    model_input_ids = [[] for i in range(8)]
    for i in range(7):
        tokens_A[i] = tokens_A[i].clone() + shift + i * snac_config.padded_vocab_size
        model_input_ids[i].append(tokens_A[i].clone().to(device).to(torch.int32))
        model_input_ids[i].append(torch.tensor([layershift(snac_config.end_of_audio, i)], device=device))
        model_input_ids[i] = torch.stack(model_input_ids[i])

    model_input_ids[-1].append(token_T.clone().to(torch.int32))
    model_input_ids[-1].append(token_T.clone().to(torch.int32))
    model_input_ids[-1] = torch.stack(model_input_ids[-1])

    text_end = False

    # Incremental decode: one token per stream per step, via the KV cache.
    for _ in range(2, max_returned_tokens - T + 1):
        tokens_A, token_T = next_token_image_batch(
            model,
            None,
            None,
            model_input_ids,
            None,
            None,
            input_pos=input_pos,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )

        # Once the text stream has emitted EOS, keep feeding/recording pads so
        # the audio stream can continue on its own.
        if text_end:
            token_T = torch.tensor([pad_id_t], device=device)

        if tokens_A[-1] == eos_id_a:
            break
        if token_T == eos_id_t:
            text_end = True

        for i in range(7):
            list_output[i].append(tokens_A[i].tolist()[0])
        list_output[7].append(token_T.tolist()[0])

        model_input_ids = [[] for i in range(8)]
        for i in range(7):
            tokens_A[i] = tokens_A[i].clone() + shift + i * snac_config.padded_vocab_size
            model_input_ids[i].append(tokens_A[i].clone().to(device).to(torch.int32))
            model_input_ids[i].append(
                torch.tensor([layershift(snac_config.end_of_audio, i)], device=device)
            )
            model_input_ids[i] = torch.stack(model_input_ids[i])

        model_input_ids[-1].append(token_T.clone().to(torch.int32))
        model_input_ids[-1].append(token_T.clone().to(torch.int32))
        model_input_ids[-1] = torch.stack(model_input_ids[-1])

        input_pos = input_pos.add_(1)

    return list_output
433
+
434
+
435
@torch.inference_mode()
def generate_TT(
    model: GPT,
    audio_features: torch.Tensor,
    input_ids: list,
    leng,
    task,
    max_returned_tokens: int = 2048,
    *,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: float = 1.0,
    eos_id_a: Optional[int] = None,
    eos_id_t: Optional[int] = None,
    pad_id_t: Optional[int] = None,
    shift: Optional[int] = None,
    include_prompt: bool = True,
    generate_text=False,
) -> torch.Tensor:
    """Text-in / text-out decoding (task "T1T2").

    Only the text stream is sampled; the 7 audio input channels are fed the
    constant layer-shifted ``end_of_audio`` token at every step.

    Args:
        model: GPT with a prepared KV cache.
        input_ids: 8 prompt tensors (7 audio layers + 1 text layer), each (1, T).
        audio_features, leng, task, eos_id_a, pad_id_t, shift, include_prompt,
            generate_text: unused here; kept for a uniform generate_* signature.
        max_returned_tokens: hard cap on prompt + generated length.
        temperature / top_k / top_p: sampling hyper-parameters.
        eos_id_t: text end-of-sequence id that stops decoding.

    Returns:
        list[int]: generated text token ids (EOS excluded).
    """

    T = input_ids[0].size(1)
    device = input_ids[0].device

    output = []
    # Prefill the prompt and sample the first text token.
    token_T = next_token_A1T1(
        model,
        None,
        input_ids,
        None,
        None,
        input_pos=torch.arange(0, T, device=device),
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
    )

    output.append(token_T.clone().tolist()[0])
    input_pos = torch.tensor([T], device=device)

    for _ in tqdm(range(2, max_returned_tokens - T + 1)):
        # Audio channels carry the constant end-of-audio token per codebook;
        # only the text channel advances with the previously sampled token.
        model_input_ids = []
        for i in range(7):
            model_input_ids.append(
                torch.tensor([layershift(snac_config.end_of_audio, i)])
                .view(1, -1)
                .to(torch.int32)
                .to(device)
            )
        model_input_ids.append(token_T.clone().view(1, -1).to(torch.int32).to(device))
        token_T = next_token_A1T1(
            model,
            None,
            model_input_ids,
            None,
            None,
            input_pos=input_pos,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )
        if token_T == eos_id_t:
            break
        output.append(token_T.clone().tolist()[0])
        input_pos = input_pos.add_(1)
    return output
500
+
501
+
502
@torch.inference_mode()
def generate_AT(
    model: GPT,
    audio_features: torch.Tensor,
    input_ids: list,
    leng,
    task,
    max_returned_tokens: int = 2048,
    *,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: float = 1.0,
    eos_id_a: Optional[int] = None,
    eos_id_t: Optional[int] = None,
    pad_id_t: Optional[int] = None,
    shift: Optional[int] = None,
    include_prompt: bool = True,
    generate_text=False,
) -> torch.Tensor:
    """Audio-in / text-out decoding (task "AT").

    Prefills the model with the prompt plus Whisper ``audio_features``, then
    decodes text tokens one at a time while feeding the 7 audio channels the
    constant layer-shifted ``end_of_audio`` token.

    Args:
        model: GPT with a prepared KV cache (``set_kv_cache``).
        audio_features: Whisper encoder features for the prompt audio.
        input_ids: 8 prompt tensors (7 audio layers + 1 text layer), each (1, T).
        leng, task, eos_id_a, pad_id_t, shift, include_prompt, generate_text:
            unused here; kept for a uniform generate_* signature.
        max_returned_tokens: hard cap on prompt + generated length.
        temperature / top_k / top_p: sampling hyper-parameters.
        eos_id_t: text end-of-sequence id that stops decoding.

    Returns:
        list[int]: generated text token ids (EOS excluded).
    """

    T = input_ids[0].size(1)
    device = input_ids[0].device

    output = []
    # Prefill: consume the prompt; the audio-feature span length is T - 3
    # (presumably prompt-template framing tokens — confirm against caller).
    token_T = next_token_A1T1(
        model,
        audio_features.to(torch.float32).to(model.device),
        input_ids,
        [T - 3],
        ["AT"],
        input_pos=torch.arange(0, T, device=device),
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
    )
    output.append(token_T.clone().tolist()[0])
    input_pos = torch.tensor([T], device=device)
    # NOTE: removed dead local `text_end = False` — it was assigned but never
    # read anywhere in this function.
    for _ in tqdm(range(2, max_returned_tokens - T + 1)):
        # Audio channels carry the constant end-of-audio token per codebook;
        # only the text channel advances with the previously sampled token.
        model_input_ids = []
        for i in range(7):
            model_input_ids.append(
                torch.tensor([layershift(snac_config.end_of_audio, i)])
                .view(1, -1)
                .to(torch.int32)
                .to(device)
            )
        model_input_ids.append(token_T.clone().view(1, -1).to(torch.int32).to(device))
        token_T = next_token_A1T1(
            model,
            None,
            model_input_ids,
            None,
            None,
            input_pos=input_pos,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )
        if token_T == eos_id_t:
            break
        output.append(token_T.clone().tolist()[0])
        input_pos = input_pos.add_(1)
    return output
566
+
567
+
568
@torch.inference_mode()
def generate_TA(
    model: GPT,
    audio_features: torch.Tensor,
    input_ids: list,
    leng,
    task,
    max_returned_tokens: int = 2048,
    *,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: float = 1.0,
    eos_id_a: Optional[int] = None,
    eos_id_t: Optional[int] = None,
    pad_id_t: Optional[int] = None,
    shift: Optional[int] = None,
    include_prompt: bool = True,
    generate_text=False,
) -> torch.Tensor:
    """Text-in / audio+text-out decoding (task "T1A2").

    Samples 7 SNAC audio codebook streams and 1 text stream in lockstep.
    After the text stream emits ``eos_id_t`` the text channel is padded with
    ``pad_id_t`` while audio decoding continues until ``eos_id_a``.

    Args:
        model: GPT with a prepared KV cache.
        input_ids: 8 prompt tensors (7 audio layers + 1 text layer), each (1, T).
        audio_features, leng, task, shift, include_prompt, generate_text:
            unused here; kept for a uniform generate_* signature.
        max_returned_tokens: hard cap on prompt + generated length.
        temperature / top_k / top_p: sampling hyper-parameters.
        eos_id_a / eos_id_t / pad_id_t: audio EOS, text EOS and text pad ids.

    Returns:
        list of 8 lists of raw sampled token ids (7 audio layers + text layer).
    """

    T = input_ids[0].size(1)
    device = input_ids[0].device

    output = [[] for _ in range(8)]
    # Prefill the text prompt and sample the first audio/text tokens.
    tokens_A, token_T = next_token_A1T2(
        model,
        None,
        input_ids,
        None,
        None,
        input_pos=torch.arange(0, T, device=device),
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
    )
    for i in range(7):
        output[i].append(tokens_A[i].clone().tolist()[0])
    output[7].append(token_T.clone().tolist()[0])

    input_pos = torch.tensor([T], device=device)
    text_end = False
    for _ in tqdm(range(2, max_returned_tokens - T + 1)):

        # Feed back the sampled tokens: each audio token is layer-shifted into
        # its codebook's vocabulary range; the text token is passed as-is.
        model_input_ids = []
        for i in range(7):
            model_input_ids.append(
                layershift(tokens_A[i].clone(), i)
                .view(1, -1)
                .to(torch.int32)
                .to(device)
            )
        model_input_ids.append(token_T.clone().view(1, -1).to(torch.int32).to(device))

        tokens_A, token_T = next_token_A1T2(
            model,
            None,
            model_input_ids,
            None,
            None,
            input_pos=input_pos,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )

        # After text EOS, substitute pads so audio can keep generating.
        if text_end:
            token_T = torch.tensor([pad_id_t], device=device)

        if tokens_A[-1] == eos_id_a:
            break

        if token_T == eos_id_t:
            text_end = True

        for i in range(7):
            output[i].append(tokens_A[i].clone().tolist()[0])
        output[7].append(token_T.clone().tolist()[0])
        input_pos = input_pos.add_(1)

    return output
648
+
649
+
650
@torch.inference_mode()
def generate_AA(
    model: GPT,
    audio_features: torch.Tensor,
    input_ids: list,
    leng,
    task,
    max_returned_tokens: int = 2048,
    *,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: float = 1.0,
    eos_id_a: Optional[int] = None,
    eos_id_t: Optional[int] = None,
    pad_id_t: Optional[int] = None,
    shift: Optional[int] = None,
    include_prompt: bool = True,
    generate_text=False,
) -> torch.Tensor:
    """Audio-in / audio+text-out decoding (task "A1T2").

    Same decode loop as :func:`generate_TA`, but the prefill also injects
    Whisper ``audio_features`` into the prompt.

    Args:
        model: GPT with a prepared KV cache.
        audio_features: Whisper encoder features for the prompt audio.
        input_ids: 8 prompt tensors (7 audio layers + 1 text layer), each (1, T).
        leng, task, shift, include_prompt, generate_text: unused here; kept
            for a uniform generate_* signature.
        max_returned_tokens: hard cap on prompt + generated length.
        temperature / top_k / top_p: sampling hyper-parameters.
        eos_id_a / eos_id_t / pad_id_t: audio EOS, text EOS and text pad ids.

    Returns:
        list of 8 lists of raw sampled token ids (7 audio layers + text layer).
    """

    T = input_ids[0].size(1)
    device = input_ids[0].device

    output = [[] for _ in range(8)]
    # Prefill with the audio prompt; feature span length is T - 3 (presumably
    # prompt-template framing tokens — confirm against caller).
    tokens_A, token_T = next_token_A1T2(
        model,
        audio_features.to(torch.float32).to(model.device),
        input_ids,
        [T - 3],
        ["A1T2"],
        input_pos=torch.arange(0, T, device=device),
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
    )
    for i in range(7):
        output[i].append(tokens_A[i].clone().tolist()[0])
    output[7].append(token_T.clone().tolist()[0])

    input_pos = torch.tensor([T], device=device)

    text_end = False
    for _ in tqdm(range(2, max_returned_tokens - T + 1)):

        # Feed back the sampled tokens, layer-shifting each audio codebook
        # token into its own vocabulary range.
        model_input_ids = []
        for i in range(7):
            model_input_ids.append(
                layershift(tokens_A[i].clone(), i)
                .view(1, -1)
                .to(torch.int32)
                .to(device)
            )
        model_input_ids.append(token_T.clone().view(1, -1).to(torch.int32).to(device))

        tokens_A, token_T = next_token_A1T2(
            model,
            None,
            model_input_ids,
            None,
            None,
            input_pos=input_pos,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )

        # After text EOS, substitute pads so audio can keep generating.
        if text_end:
            token_T = torch.tensor([pad_id_t], device=device)

        if tokens_A[-1] == eos_id_a:
            break
        if token_T == eos_id_t:
            text_end = True

        for i in range(7):
            output[i].append(tokens_A[i].clone().tolist()[0])
        output[7].append(token_T.clone().tolist()[0])
        input_pos = input_pos.add_(1)

    return output
731
+
732
+
733
@torch.inference_mode()
def generate_ASR(
    model: GPT,
    audio_features: torch.Tensor,
    input_ids: list,
    leng,
    task,
    max_returned_tokens: int = 1200,
    *,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: float = 1.0,
    eos_id_a: Optional[int] = None,
    eos_id_t: Optional[int] = None,
    pad_id_t: Optional[int] = None,
    shift: Optional[int] = None,
    include_prompt: bool = True,
    generate_text=False,
) -> torch.Tensor:
    """Speech recognition decoding (task "asr"): audio in, transcript out.

    Identical loop shape to :func:`generate_AT` but with the "asr" task tag
    and a smaller default token budget.

    Args:
        model: GPT with a prepared KV cache (``set_kv_cache``).
        audio_features: Whisper encoder features for the input speech.
        input_ids: 8 prompt tensors (7 audio layers + 1 text layer), each (1, T).
        leng, task, eos_id_a, pad_id_t, shift, include_prompt, generate_text:
            unused here; kept for a uniform generate_* signature.
        max_returned_tokens: hard cap on prompt + generated length.
        temperature / top_k / top_p: sampling hyper-parameters.
        eos_id_t: text end-of-sequence id that stops decoding.

    Returns:
        list[int]: generated transcript token ids (EOS excluded).
    """

    T = input_ids[0].size(1)
    device = input_ids[0].device
    output = []
    # Prefill: consume the prompt; audio-feature span length is T - 3
    # (presumably prompt-template framing tokens — confirm against caller).
    token_T = next_token_A1T1(
        model,
        audio_features.to(torch.float32).to(model.device),
        input_ids,
        [T - 3],
        ["asr"],
        input_pos=torch.arange(0, T, device=device),
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
    )
    output.append(token_T.clone().tolist()[0])
    input_pos = torch.tensor([T], device=device)
    # NOTE: removed dead local `text_end = False` — it was assigned but never
    # read anywhere in this function.
    for _ in tqdm(range(2, max_returned_tokens - T + 1)):
        # Audio channels carry the constant end-of-audio token per codebook;
        # only the text channel advances with the previously sampled token.
        model_input_ids = []
        for i in range(7):
            model_input_ids.append(
                torch.tensor([layershift(snac_config.end_of_audio, i)])
                .view(1, -1)
                .to(torch.int32)
                .to(device)
            )
        model_input_ids.append(token_T.clone().view(1, -1).to(torch.int32).to(device))
        token_T = next_token_A1T1(
            model,
            None,
            model_input_ids,
            None,
            None,
            input_pos=input_pos,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )
        if token_T == eos_id_t:
            break
        output.append(token_T.clone().tolist()[0])
        input_pos = input_pos.add_(1)
    return output
litgpt/model.py ADDED
@@ -0,0 +1,655 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2
+
3
+ """Full definition of a decoder-only transformer-based language model, all of it in this single file.
4
+
5
+ Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT and
6
+ https://github.com/EleutherAI/gpt-neox/tree/main/megatron/model.
7
+ """
8
+
9
+ import math
10
+ from typing import Any, Optional, Tuple
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from typing_extensions import Self
15
+ from litgpt.config import Config
16
+
17
+
18
+
19
class GPT(nn.Module):
    """Decoder-only transformer with multimodal adapters.

    Extends a standard litgpt-style GPT with:
      * a Whisper adapter projecting ASR encoder features to ``n_embd``,
      * a vision adapter (``visionMLP``) projecting CLIP features,
      * 8 parallel input streams (7 SNAC audio codebooks + 1 text) whose
        embeddings are averaged before the transformer stack,
      * an optional ``post_adapter`` tower with a separate audio LM head.
    """

    def __init__(self, config: Config) -> None:
        super().__init__()
        assert config.padded_vocab_size is not None
        self.config = config
        # ASR-feature adapter: a single linear layer or a gated (LLaMA-style) MLP.
        if self.config.asr_adapter == "mlp":
            print("Using MLP adapter for ASR feature")
            self.whisper_adapter = nn.Linear(config.whisper_adapter_dim, config.n_embd)
        elif self.config.asr_adapter == "llamamlp":
            print("using LLAMA MLP adapter for ASR feature")
            self.whisper_adapter = whisperMLP(config=config)
        else:
            raise ValueError("asr_adapter should be mlp or llamamlp")
        self.lm_head = nn.Linear(
            config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias
        )

        self.vision_adapter = visionMLP(config = config)
        if config.post_adapter:
            # Extra transformer blocks + dedicated norm/LM head for audio logits.
            self.transformer = nn.ModuleDict(
                dict(
                    wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
                    h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
                    post_adapter=nn.ModuleList(
                        Block(config) for _ in range(config.post_adapter_layers)
                    ),
                    ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
                    post_adapter_audio_ln=config.norm_class(
                        config.n_embd, eps=config.norm_eps
                    ),
                    post_adapter_audio_lm_head=nn.Linear(
                        config.n_embd, config.cat_audio_vocab_size, bias=config.lm_head_bias
                    ),
                )
            )
        else:
            self.transformer = nn.ModuleDict(
                dict(
                    wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
                    h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
                    ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
                )
            )
        # Setter also builds the RoPE cos/sin cache on first assignment.
        self.max_seq_length = self.config.block_size
        self.mask_cache: Optional[torch.Tensor] = None
        if config.tie_word_embeddings:
            self.lm_head.weight = self.transformer.wte.weight

    @property
    def max_seq_length(self) -> int:
        return self._max_seq_length

    @max_seq_length.setter
    def max_seq_length(self, value: int) -> None:
        """
        When doing inference, the sequences used might be shorter than the model's context length.
        This allows setting a smaller number to avoid allocating unused memory
        """
        if value > self.config.block_size:
            raise ValueError(
                f"Cannot attend to {value}, block size is only {self.config.block_size}"
            )
        self._max_seq_length = value
        if not hasattr(self, "cos"):
            # first call
            cos, sin = self.rope_cache()
            self.register_buffer("cos", cos, persistent=False)
            self.register_buffer("sin", sin, persistent=False)
        # override
        elif value != self.cos.size(0):
            self.cos, self.sin = self.rope_cache(device=self.cos.device)
        # the mask and kv cache size will get updated on `set_kv_cache`. we cannot update it here because we don't know
        # if the kv cache is expected

    def reset_parameters(self) -> None:
        # Trigger resetting the rope-cache
        self.cos, self.sin = self.rope_cache(device=self.cos.device)

    def _init_weights(self, module: nn.Module) -> None:
        """Meant to be used with `gpt.apply(gpt._init_weights)`."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def concat_feat(self, audio_feature, clip_feature, input_ids, T, task):
        """Splice adapter features into the per-layer input embeddings, in place.

        For audio tasks the Whisper features overwrite positions 1..T[j];
        for image tasks the 50 CLIP feature vectors overwrite positions 1..50
        (with audio following at 52..52+T[j] for ImageQA_A/ImageQA_AT).
        NOTE(review): offsets 1/51/52 presumably correspond to special prompt
        tokens — confirm against the prompt-building code.
        """

        for j in range(len(T)):
            # Plain audio tasks: everything that is not text-only or image-based.
            if task[j] != 'T1T2' and task[j] != 'T1A2' and task[j]!='ImageQA_T' and not task[j] == 'ImageCAP' and not task[j] == 'ImageQA_A' and not task[j] == 'ImageQA_AT':
                for i in range(7):
                    input_ids[i][j,1:T[j]+1,:] = audio_feature[j][:T[j]].clone()
                assert task[j] != 'ImageQ', "ImageQ should be concat with audio feature"

            elif task[j] == 'ImageQA_A' or task[j] == 'ImageQA_AT':
                print("concat ImageQA_A feature")
                for i in range(7):
                    input_ids[i][j,1:51,:] = clip_feature[j].clone()

                    input_ids[i][j,52 : 52 + T[j],:] = audio_feature[j][:T[j]].clone()

            elif task[j] == 'ImageQA_T' or task[j] =='ImageCAP':
                for i in range(7):
                    input_ids[i][j,1:51,:] = clip_feature[j].clone()

        return input_ids

    def forward(
        self,
        audio_features: torch.Tensor,
        input_ids: torch.Tensor,
        clip_features: torch.Tensor,
        input_pos: Optional[torch.Tensor] = None,
        whisper_lens: Optional[list] = None,
        task: Optional[str] = None,
    ) -> torch.Tensor:
        """Run the model; returns ``(xa, xt)``: a list of 7 audio logit
        tensors (one per SNAC codebook) and the text logits.

        When ``input_pos`` is given, the KV cache is used (incremental decode);
        otherwise the full prompt is processed (prefill).
        """

        show = False  # NOTE(review): dead local, never read.
        T = input_ids[0].size(1)
        if self.max_seq_length < T:
            raise ValueError(
                f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}."
            )

        if input_pos is not None:  # use the kv cache
            cos = self.cos.index_select(0, input_pos)
            sin = self.sin.index_select(0, input_pos)
            if self.mask_cache is None:
                raise TypeError("You need to call `gpt.set_kv_cache()`")
            mask = self.mask_cache.index_select(2, input_pos)
        else:
            cos = self.cos[:T]
            sin = self.sin[:T]
            mask = None

        if audio_features is not None:
            # get whisper feature
            x_a = self.whisper_adapter(audio_features)
            if clip_features is not None:
                x_v = self.vision_adapter(clip_features)
            else:
                x_v = None
            # get input_ids embedding
            x0, x1, x2, x3, x4, x5, x6, x7 = input_ids

            x0 = self.transformer.wte(x0)
            x1 = self.transformer.wte(x1)
            x2 = self.transformer.wte(x2)
            x3 = self.transformer.wte(x3)
            x4 = self.transformer.wte(x4)
            x5 = self.transformer.wte(x5)
            x6 = self.transformer.wte(x6)
            x7 = self.transformer.wte(x7)

            # concat whisper feature
            input_emb = self.concat_feat(
                x_a, x_v, [x0, x1, x2, x3, x4, x5, x6, x7], whisper_lens, task
            )
            x0, x1, x2, x3, x4, x5, x6, x7 = input_emb

        else:
            x0, x1, x2, x3, x4, x5, x6, x7 = input_ids

            x0 = self.transformer.wte(x0)
            x1 = self.transformer.wte(x1)
            x2 = self.transformer.wte(x2)
            x3 = self.transformer.wte(x3)
            x4 = self.transformer.wte(x4)
            x5 = self.transformer.wte(x5)
            x6 = self.transformer.wte(x6)
            x7 = self.transformer.wte(x7)

        # Fuse the 8 parallel streams by averaging their embeddings.
        x = (x0 + x1 + x2 + x3 + x4 + x5 + x6 + x7) / 8

        if self.config.scale_embeddings:
            x = x * (self.config.n_embd**0.5)

        for block in self.transformer.h:
            x = block(x, cos, sin, mask, input_pos)


        text_vocab_size = self.config.text_vocab_size
        audio_vocab_size = self.config.audio_vocab_size

        # Text logits come from the shared LM head; the combined vocabulary is
        # laid out as [text | audio layer 0 | ... | audio layer 6].
        x_ori = x
        x_ori = self.transformer.ln_f(x_ori)
        x_ori = self.lm_head(x_ori)  # (b, t, vocab_size)
        xt = x_ori[..., :text_vocab_size]

        if self.config.post_adapter:
            # Audio logits from the dedicated post-adapter tower instead.
            for block in self.transformer.post_adapter:
                x = block(x, cos, sin, mask, input_pos)
            x = self.transformer.post_adapter_audio_ln(x)
            x = self.transformer.post_adapter_audio_lm_head(x)  # (b, t, vocab_size)
            xa = []
            for i in range(7):
                xa.append(x[..., audio_vocab_size * i : audio_vocab_size * (i + 1)])
        else:
            xa = []
            for i in range(7):
                xa.append(x_ori[..., text_vocab_size + audio_vocab_size * i : text_vocab_size + audio_vocab_size * (i + 1)])

        return xa, xt

    @classmethod
    def from_name(cls, name: str, **kwargs: Any) -> Self:
        # Alternate constructor: build from a named preset configuration.
        return cls(Config.from_name(name, **kwargs))

    def rope_cache(
        self, device: Optional[torch.device] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Build the rotary-embedding cos/sin tables for the current max length.
        return build_rope_cache(
            seq_len=self.max_seq_length,
            n_elem=self.config.rope_n_elem,
            device=device,
            condense_ratio=self.config.rope_condense_ratio,
            base=self.config.rope_base,
        )

    def set_kv_cache(
        self,
        batch_size: int,
        rope_cache_length: Optional[int] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        """Allocate per-block KV caches (and the causal mask cache) for inference."""
        if rope_cache_length is None:
            rope_cache_length = self.cos.size(-1)
        max_seq_length = self.max_seq_length

        # initialize the kv cache for all blocks
        for block in self.transformer.h:
            block.attn.kv_cache = block.attn.build_kv_cache(
                batch_size, max_seq_length, rope_cache_length, device, dtype
            )
        if self.config.post_adapter:
            for block in self.transformer.post_adapter:
                block.attn.kv_cache = block.attn.build_kv_cache(
                    batch_size, max_seq_length, rope_cache_length, device, dtype
                )

        if self.mask_cache is None or self.mask_cache.size(3) != max_seq_length:
            # passing `attn_mask` to SDPA disables the flash implementation. since we only need the mask
            # for the kv-cache support (only during inference), we only create it in that situation
            self.mask_cache = build_mask_cache(max_seq_length, device)

    def clear_kv_cache(self) -> None:
        # Drop mask and per-block caches so they can be rebuilt (or freed).
        self.mask_cache = None
        for block in self.transformer.h:
            block.attn.kv_cache = None
270
+
271
+
272
class visionMLP(nn.Module):
    """Gated (SwiGLU-style) MLP that projects CLIP vision features into the
    model's embedding space (``n_embd``)."""

    def __init__(self, config: Config) -> None:
        super().__init__()
        in_dim = config.vision_adapter_dim
        # Two parallel up-projections feeding the SiLU gate, then a down-projection.
        self.fc_1 = nn.Linear(in_dim, config.intermediate_size, bias=config.bias)
        self.fc_2 = nn.Linear(in_dim, config.intermediate_size, bias=config.bias)
        self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)

        self.config = config

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # silu(fc_1(x)) acts as the gate on fc_2(x).
        gated = torch.nn.functional.silu(self.fc_1(x)) * self.fc_2(x)
        return self.proj(gated)
287
+
288
+
289
class Block(nn.Module):
    """One transformer layer: norm → attention → norm → MLP, wired either as
    a parallel or a non-parallel residual depending on the config."""

    def __init__(self, config: Config) -> None:
        super().__init__()
        if not config.parallel_residual and config.shared_attention_norm:
            raise NotImplementedError(
                "No checkpoint amongst the ones we support uses this configuration"
                " (non-parallel residual and shared attention norm)."
            )

        self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
        self.attn = CausalSelfAttention(config)
        # norm_2 is omitted when the attention norm is shared with the MLP path.
        self.norm_2 = (
            None
            if config.shared_attention_norm
            else config.norm_class(config.n_embd, eps=config.norm_eps)
        )
        self.mlp = config.mlp_class(config)

        self.config = config

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        input_pos: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Non-parallel residual       Parallel residual
           ┌─ x                     ┌─ x ────────────┐             Note: if `shared_attention_norm` is True,
           │  ↓                     │  ↓             ↓             the output from `norm_1` is reused
           │  norm_1                │  norm_1  ───►  norm_2
           │  ↓                     │  ↓             ↓
           │  attn                  │  attn          mlp
           │  ↓                     │  ↓             │
        ┌─ └► +                     └► + ◄───────────┘
        │     norm_2
        │     ↓
        │     mlp
        │     ↓
        └───► +
        """

        x_normed = self.norm_1(x)
        attention_output = self.attn(x_normed, cos, sin, mask, input_pos)

        if self.config.parallel_residual:
            x_normed = x_normed if self.config.shared_attention_norm else self.norm_2(x)
            x = self.mlp(x_normed) + attention_output + x
        else:
            x = attention_output + x
            x = self.mlp(self.norm_2(x)) + x
        return x
344
+
345
+
346
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention supporting MHA, GQA and MQA via
    ``config.n_query_groups``, with RoPE and an optional KV cache."""

    def __init__(self, config: Config) -> None:
        super().__init__()
        shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
        # key, query, value projections for all heads, but in a batch
        self.attn = nn.Linear(config.n_embd, shape, bias=config.add_qkv_bias)
        # output projection
        # if `head_size` is explicitly specified in the config, `n_emd` might not be equal to `head_size * n_head`
        self.proj = nn.Linear(
            config.head_size * config.n_head, config.n_embd, bias=config.bias
        )
        # disabled by default
        self.kv_cache: Optional[KVCache] = None

        self.config = config

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        input_pos: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        B, T, C = (
            x.size()
        )  # batch size, sequence length, embedding dimensionality (n_embd)

        qkv = self.attn(x)

        # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`)
        q_per_kv = self.config.n_head // self.config.n_query_groups
        total_qkv = q_per_kv + 2  # each group has 1+ queries, 1 key, and 1 value
        qkv = qkv.view(
            B, T, self.config.n_query_groups, total_qkv, self.config.head_size
        )
        qkv = qkv.permute(0, 2, 3, 1, 4)  # (B, n_query_groups, total_qkv, T, hs)

        # split batched computation into three
        q, k, v = qkv.split((q_per_kv, 1, 1), dim=2)

        # maybe repeat k and v if for the non multi-head attention cases
        # training: flash attention requires it
        # inference: multi-query would require a full kv cache so avoid it to limit its memory usage
        if self.config.n_query_groups != self.config.n_head and (
            input_pos is None or self.config.n_query_groups != 1
        ):
            k = k.expand(
                B, self.config.n_query_groups, q_per_kv, T, self.config.head_size
            )
            v = v.expand(
                B, self.config.n_query_groups, q_per_kv, T, self.config.head_size
            )

        q = q.reshape(B, -1, T, self.config.head_size)  # (B, nh_q, T, hs)
        k = k.reshape(B, -1, T, self.config.head_size)  # (B, nh_k, T, hs)
        v = v.reshape(B, -1, T, self.config.head_size)  # (B, nh_v, T, hs)

        # RoPE is applied only to the first `rope_n_elem` channels of q and k.
        q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin)
        k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin)
        q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1)
        k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1)

        if input_pos is not None:
            # Incremental decoding: write the new k/v at `input_pos` and read
            # back the full cached sequence.
            if not isinstance(self.kv_cache, KVCache):
                raise TypeError("You need to call `gpt.set_kv_cache()`")
            k, v = self.kv_cache(input_pos, k, v)

        y = self.scaled_dot_product_attention(q, k, v, mask)

        y = y.reshape(
            B, T, self.config.head_size * self.config.n_head
        )  # re-assemble all head outputs side by side

        # output projection
        return self.proj(y)

    def scaled_dot_product_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # With no explicit mask, rely on SDPA's built-in causal masking.
        scale = 1.0 / math.sqrt(self.config.head_size)
        y = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=mask, dropout_p=0.0, scale=scale, is_causal=mask is None
        )
        return y.transpose(1, 2)

    def build_kv_cache(
        self,
        batch_size: int,
        max_seq_length: int,
        rope_cache_length: Optional[int] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> "KVCache":
        """Create a KVCache sized for this layer (1 head for MQA, n_head otherwise)."""
        heads = 1 if self.config.n_query_groups == 1 else self.config.n_head
        v_shape = (batch_size, heads, max_seq_length, self.config.head_size)
        if rope_cache_length is None:
            if self.config.rotary_percentage != 1.0:
                raise TypeError(
                    "Please pass the `rope_cache_length=gpt.cos.size(-1)` value"
                )
            k_shape = v_shape
        else:
            # Key width accounts for partial-RoPE layouts.
            k_shape = (
                batch_size,
                heads,
                max_seq_length,
                rope_cache_length + self.config.head_size - self.config.rope_n_elem,
            )
        return KVCache(k_shape, v_shape, device=device, dtype=dtype)
460
+
461
+
462
class GptNeoxMLP(nn.Module):
    """GPT-NeoX-style feed-forward block: up-project, GELU, down-project."""

    def __init__(self, config: Config) -> None:
        super().__init__()
        self.fc = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
        self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)

        self.config = config

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = torch.nn.functional.gelu(
            self.fc(x), approximate=self.config.gelu_approximate
        )
        return self.proj(hidden)
474
+
475
+
476
class LLaMAMLP(nn.Module):
    """LLaMA-style gated feed-forward (SwiGLU): silu(fc_1(x)) * fc_2(x) → proj."""

    def __init__(self, config: Config) -> None:
        super().__init__()
        self.fc_1 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
        self.fc_2 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
        self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)

        self.config = config

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = torch.nn.functional.silu(self.fc_1(x))
        return self.proj(gate * self.fc_2(x))
490
+
491
+
492
class whisperMLP(nn.Module):
    """Gated (SwiGLU-style) MLP that maps Whisper encoder features
    (``whisper_adapter_dim``) into the model's embedding space (``n_embd``)."""

    def __init__(self, config: Config) -> None:
        super().__init__()
        in_dim = config.whisper_adapter_dim
        self.fc_1 = nn.Linear(in_dim, config.intermediate_size, bias=config.bias)
        self.fc_2 = nn.Linear(in_dim, config.intermediate_size, bias=config.bias)
        self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)

        self.config = config

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = torch.nn.functional.silu(self.fc_1(x))
        return self.proj(gate * self.fc_2(x))
506
+
507
+
508
class GemmaMLP(LLaMAMLP):
    """Gemma variant of the gated MLP: uses GELU (with the configured
    approximation) instead of SiLU as the gate activation."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = torch.nn.functional.gelu(
            self.fc_1(x), approximate=self.config.gelu_approximate
        )
        return self.proj(gate * self.fc_2(x))
517
+
518
+
519
class LLaMAMoE(nn.Module):
    """Sparse mixture-of-experts feed-forward: a router gate picks
    ``n_expert_per_token`` experts per token and blends their outputs."""

    def __init__(self, config: Config) -> None:
        super().__init__()
        self.gate = nn.Linear(config.n_embd, config.n_expert, bias=False)
        self.experts = nn.ModuleList(LLaMAMLP(config) for _ in range(config.n_expert))

        self.config = config

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Derived from: https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219
        See also figure 1 in https://arxiv.org/abs/2211.15841
        """
        B, T, C = (
            x.size()
        )  # batch size, sequence length, embedding dimensionality (n_embd)
        x = x.view(-1, C)  # (B*T, C)
        router = self.gate(x)  # (B*T, n_expert)
        probs, indices = torch.topk(
            router, self.config.n_expert_per_token
        )  # (B*T, n_expert_per_token)
        # Softmax over the selected experts only, computed in float for stability.
        probs = probs.softmax(dim=1, dtype=torch.float).to(dtype=x.dtype)
        masks = indices.unsqueeze(-1) == torch.arange(
            self.config.n_expert, device=x.device
        )
        masks = masks.permute(2, 0, 1)  # (n_expert, B*T, n_expert_per_token)
        y = torch.zeros_like(x)  # (B*T, C)
        # Each expert processes only the tokens routed to it; outputs are
        # accumulated weighted by the router probabilities.
        for mask, expert in zip(masks, self.experts):
            token_idx, expert_idx = torch.where(mask)
            y[token_idx] += probs[token_idx, expert_idx, None] * expert(x[token_idx])
        return y.view(B, T, C)
550
+
551
+
552
def build_rope_cache(
    seq_len: int,
    n_elem: int,
    device: Optional[torch.device] = None,
    base: int = 10000,
    condense_ratio: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Enhanced Transformer with Rotary Position Embedding.

    Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
    transformers/rope/__init__.py. MIT License:
    https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.

    Returns `(cos, sin)` tables of shape `(seq_len, n_elem)`.
    """
    # Frequencies: theta_i = base^(-2(i-1)/d) for i in [1, d/2]
    inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem))

    # Position indices [0, seq_len), optionally condensed for context extension
    positions = torch.arange(seq_len, device=device) / condense_ratio

    # Angle per (position, frequency); repeated so both halves of the head share it
    angles = torch.outer(positions, inv_freq).repeat(1, 2)

    return torch.cos(angles), torch.sin(angles)
575
+
576
+
577
def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embedding to `x` (B, nh, T, hs) using cos/sin tables."""
    half = x.size(-1) // 2
    lo = x[..., :half]  # (B, nh, T, hs/2)
    hi = x[..., half:]  # (B, nh, T, hs/2)
    # Rotate the two halves: (lo, hi) -> (-hi, lo)
    rotated = torch.cat((-hi, lo), dim=-1)  # (B, nh, T, hs)
    return ((x * cos) + (rotated * sin)).to(dtype=x.dtype)
584
+
585
+
586
class KVCache(nn.Module):
    """Pre-allocated key/value buffers updated in place during decoding."""

    def __init__(
        self,
        k_shape: Tuple[int, int, int, int],
        v_shape: Tuple[int, int, int, int],
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        super().__init__()
        # Non-persistent: the cache is transient state, not checkpoint content.
        for name, shape in (("k", k_shape), ("v", v_shape)):
            self.register_buffer(
                name, torch.zeros(shape, device=device, dtype=dtype), persistent=False
            )

    def forward(
        self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Keep the buffers in the activation dtype (relevant when AMP is used).
        self.k = self.k.to(k.dtype)
        self.v = self.v.to(v.dtype)
        # Write the new entries at `input_pos` along the sequence axis (dim 2)
        # and return the full caches.
        return (
            self.k.index_copy_(2, input_pos, k),
            self.v.index_copy_(2, input_pos, v),
        )

    def reset_parameters(self) -> None:
        """Zero both caches (e.g. between generation runs)."""
        torch.nn.init.zeros_(self.k)
        torch.nn.init.zeros_(self.v)
616
+
617
+
618
def build_mask_cache(
    max_seq_length: int, device: Optional[torch.device] = None
) -> torch.Tensor:
    """Return a `(1, 1, T, T)` lower-triangular boolean causal attention mask."""
    mask = torch.tril(
        torch.ones((max_seq_length, max_seq_length), device=device, dtype=torch.bool)
    )
    # Add broadcastable batch and head dimensions.
    return mask[None, None, :, :]
623
+
624
+
625
class RMSNorm(torch.nn.Module):
    """Root Mean Square Layer Normalization.

    Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
    https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
    """

    def __init__(
        self, size: int, dim: int = -1, eps: float = 1e-6, add_unit_offset: bool = False
    ) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(size))
        self.eps = eps
        self.dim = dim
        # Gemma checkpoints store weights centred on 0, so scaling uses (1 + w):
        # https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L176
        self.add_unit_offset = add_unit_offset

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        in_dtype = x.dtype
        # Normalize in float32 for numerical stability.
        # NOTE: the original RMSNorm paper implementation is not equivalent.
        xf = x.float()
        mean_sq = torch.mean(xf * xf, dim=self.dim, keepdim=True)
        normed = (xf * torch.rsqrt(mean_sq + self.eps)).to(dtype=in_dtype)
        scale = (1 + self.weight) if self.add_unit_offset else self.weight
        return normed * scale

    def reset_parameters(self) -> None:
        torch.nn.init.ones_(self.weight)
litgpt/tokenizer.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2
+
3
+ import json
4
+ from pathlib import Path
5
+ from typing import Optional, Union
6
+
7
+ import torch
8
+
9
+
10
class Tokenizer:
    """Unified wrapper over HF `tokenizers` and `sentencepiece` checkpoints.

    Loads `tokenizer.json` (HF backend, preferred) or `tokenizer.model`
    (sentencepiece backend) from `checkpoint_dir` and resolves the BOS/EOS
    token ids from the accompanying config files.
    """

    def __init__(self, checkpoint_dir: Union[Path, str]) -> None:
        checkpoint_dir = Path(checkpoint_dir)
        if not checkpoint_dir.exists():
            raise NotADirectoryError(
                f"The checkpoint directory does not exist: {str(checkpoint_dir)}"
            )

        self.use_bos = self.check_if_bos_token_used(checkpoint_dir)
        self.bos_id = None
        self.eos_id = None

        # some checkpoints have both files, `.json` takes precedence
        if (vocabulary_path := checkpoint_dir / "tokenizer.json").is_file():
            from tokenizers import Tokenizer as HFTokenizer

            self.processor = HFTokenizer.from_file(str(vocabulary_path))
            self.backend = "huggingface"

            if (
                special_tokens_path := checkpoint_dir / "tokenizer_config.json"
            ).is_file():
                with open(special_tokens_path, encoding="utf-8") as fp:
                    config = json.load(fp)
                bos_token = config.get("bos_token")
                eos_token = config.get("eos_token")
                # Special tokens may be plain strings or {"content": ...} dicts
                # depending on the tokenizer_config version.
                if bos_token is not None and isinstance(bos_token, dict):
                    bos_token = bos_token.get("content")
                if eos_token is not None and isinstance(eos_token, dict):
                    eos_token = eos_token.get("content")
                self.bos_id = (
                    self.token_to_id(bos_token) if bos_token is not None else None
                )
                self.eos_id = (
                    self.token_to_id(eos_token) if eos_token is not None else None
                )
            if (
                special_tokens_path := checkpoint_dir / "generation_config.json"
            ).is_file():
                # Fall back to ids from generation_config.json when not set above.
                with open(special_tokens_path, encoding="utf-8") as fp:
                    config = json.load(fp)
                if self.bos_id is None:
                    self.bos_id = config.get("bos_token_id")
                if self.eos_id is None:
                    self.eos_id = config.get("eos_token_id")

        elif (vocabulary_path := checkpoint_dir / "tokenizer.model").is_file():
            from sentencepiece import SentencePieceProcessor

            self.processor = SentencePieceProcessor(model_file=str(vocabulary_path))
            self.backend = "sentencepiece"
            self.bos_id = self.processor.bos_id()
            self.eos_id = self.processor.eos_id()
        else:
            raise NotImplementedError

    @property
    def vocab_size(self) -> int:
        """Base vocabulary size (excluding added tokens for the HF backend)."""
        if self.backend == "huggingface":
            return self.processor.get_vocab_size(with_added_tokens=False)
        if self.backend == "sentencepiece":
            return self.processor.vocab_size()
        raise RuntimeError

    def token_to_id(self, token: str) -> int:
        """Map a token string to its integer id.

        Raises:
            ValueError: if the token is not in the vocabulary.
        """
        if self.backend == "huggingface":
            id_ = self.processor.token_to_id(token)
        elif self.backend == "sentencepiece":
            id_ = self.processor.piece_to_id(token)
        else:
            raise RuntimeError
        if id_ is None:
            raise ValueError(f"token {token!r} not found in the collection.")
        return id_

    def check_if_bos_token_used(self, checkpoint_dir: Path) -> bool:
        """Whether encoded sequences should be prefixed with the BOS token."""
        if not (
            tokenizer_config_path := checkpoint_dir / "tokenizer_config.json"
        ).is_file():
            return False
        with open(tokenizer_config_path, encoding="utf-8") as fp:
            config = json.load(fp)
        if "add_bos_token" in config:
            return config["add_bos_token"]
        # if `add_bos_token` isn't in the config file, but LLaMA tokenizer is used - return True.
        # ex: https://huggingface.co/stabilityai/StableBeluga2/blob/main/tokenizer_config.json#L2
        return config.get("tokenizer_class") == "LlamaTokenizer"

    def encode(
        self,
        string: str,
        device: Optional[torch.device] = None,
        bos: Optional[bool] = None,
        eos: bool = False,
        max_length: int = -1,
    ) -> torch.Tensor:
        """Tokenize `string` into an int tensor, optionally adding BOS/EOS and truncating.

        Args:
            string: text to tokenize.
            device: device of the returned tensor.
            bos: force adding (True) / omitting (False) the BOS token;
                defaults to the checkpoint's setting.
            eos: append the EOS token if not already present.
            max_length: truncate to this many tokens when positive.
        """
        if self.backend == "huggingface":
            tokens = self.processor.encode(string).ids
        elif self.backend == "sentencepiece":
            tokens = self.processor.encode(string)
        else:
            raise RuntimeError
        # Fix: validate the backend result *before* using it (this check
        # previously ran after `tokens[0]` had already been indexed).
        if tokens is None:
            raise ValueError("`tokens` is None")

        if bos or (bos is None and self.use_bos):
            if self.bos_id is None:
                raise NotImplementedError(
                    "This tokenizer does not have a defined bos token"
                )
            # Fix: guard the empty-token case to avoid an IndexError on tokens[0].
            if not tokens or tokens[0] != self.bos_id:
                tokens = [self.bos_id] + tokens
        if eos and (not tokens or tokens[-1] != self.eos_id):
            tokens = tokens + [self.eos_id]
        if max_length > 0:
            tokens = tokens[:max_length]
        return torch.tensor(tokens, dtype=torch.int, device=device)

    def decode(self, tensor: torch.Tensor) -> str:
        """Convert a 0-d or 1-d tensor of token ids back into text."""
        tokens = [tensor.item()] if tensor.ndim == 0 else tensor.tolist()
        return self.processor.decode(tokens)
litgpt/utils.py ADDED
@@ -0,0 +1,641 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2
+
3
+ """Utility functions for training and inference."""
4
+ import inspect
5
+ import math
6
+ import os
7
+ import pickle
8
+ import shutil
9
+ import sys
10
+ from dataclasses import asdict, is_dataclass
11
+ from io import BytesIO
12
+ from pathlib import Path
13
+ from typing import (
14
+ TYPE_CHECKING,
15
+ Any,
16
+ Dict,
17
+ Iterable,
18
+ List,
19
+ Literal,
20
+ Mapping,
21
+ Optional,
22
+ TypeVar,
23
+ Union,
24
+ )
25
+
26
+ import lightning as L
27
+ import torch
28
+ import torch.nn as nn
29
+ import torch.utils._device
30
+ import yaml
31
+ from lightning.fabric.loggers import CSVLogger, TensorBoardLogger
32
+ from lightning.fabric.strategies import FSDPStrategy
33
+ from lightning.fabric.utilities.load import _lazy_load as lazy_load
34
+ from lightning.pytorch.loggers import WandbLogger
35
+ from lightning.pytorch.cli import instantiate_class
36
+ from torch.serialization import normalize_storage_type
37
+ from typing_extensions import Self
38
+
39
+ if TYPE_CHECKING:
40
+ from litgpt import GPT, Config
41
+
42
+
43
def init_out_dir(out_dir: Path) -> Path:
    """Resolve a relative output dir against `LIGHTNING_ARTIFACTS_DIR` when set."""
    artifacts_dir = os.environ.get("LIGHTNING_ARTIFACTS_DIR")
    if artifacts_dir is not None and not out_dir.is_absolute():
        return Path(artifacts_dir) / out_dir
    return out_dir
47
+
48
+
49
def find_resume_path(
    resume: Union[bool, Literal["auto"], Path], out_dir: Path
) -> Optional[Path]:
    """Resolve which checkpoint to resume from.

    Args:
        resume: a concrete checkpoint ``Path``, ``True``/"auto" to pick the
            latest ``step-*`` checkpoint in `out_dir`, or ``False`` to disable.
        out_dir: directory searched for ``step-*/*.pth`` checkpoints.

    Raises:
        FileNotFoundError: when ``resume is True`` but no checkpoint exists.
    """
    # Explicit path (use as-is) or resuming disabled (pass the falsy value through).
    if not resume or isinstance(resume, Path):
        return resume

    # Latest checkpoint = highest step number among `step-*` directories.
    resume_path = max(
        out_dir.rglob("step-*/*.pth"),
        key=(lambda p: int(p.parent.name.split("-")[1])),
        default=None,
    )
    if resume == "auto":
        return resume_path
    if resume is True and resume_path is None:
        # Fix: "checkpont" typo in the user-facing error message.
        raise FileNotFoundError(
            f"You passed `--resume=True`, but no checkpoint file was found in `--out_dir={out_dir}`."
        )
    return resume_path
67
+
68
+
69
def find_multiple(n: int, k: int) -> int:
    """Round `n` up to the nearest multiple of `k` (requires `k > 0`)."""
    assert k > 0
    remainder = n % k
    return n if remainder == 0 else n + (k - remainder)
74
+
75
+
76
def num_parameters(module: nn.Module, requires_grad: Optional[bool] = None) -> int:
    """Count `module`'s parameters, optionally filtered by `requires_grad` state."""
    total = 0
    for param in module.parameters():
        if requires_grad is not None and param.requires_grad != requires_grad:
            continue
        if hasattr(param, "quant_state"):
            # bitsandbytes 4-bit params report their logical shape via quant_state
            total += math.prod(param.quant_state.shape)
        else:
            total += param.numel()
    return total
86
+
87
+
88
def reset_parameters(module: nn.Module) -> None:
    """Calls `reset_parameters` on the module and all its submodules."""
    for submodule in module.modules():
        reset_fn = getattr(submodule, "reset_parameters", None)
        if callable(reset_fn):
            reset_fn()
93
+
94
+
95
def check_valid_checkpoint_dir(
    checkpoint_dir: Path,
    model_filename: str = "lit_model.pth",
    verbose: bool = True,
    raise_error: bool = False,
) -> None:
    """Validate that `checkpoint_dir` holds a complete converted checkpoint.

    Checks for the model weights, ``model_config.yaml``, a tokenizer file, and
    the tokenizer config. On failure it prints download instructions to stderr
    (when `verbose`) and either raises ``FileNotFoundError`` (when
    `raise_error`) or exits the process via ``SystemExit(1)``.
    """
    # Map of required file (description) -> whether it is present.
    files = {
        model_filename: (checkpoint_dir / model_filename).is_file(),
        "model_config.yaml": (checkpoint_dir / "model_config.yaml").is_file(),
        "tokenizer.json OR tokenizer.model": (
            checkpoint_dir / "tokenizer.json"
        ).is_file()
        or (checkpoint_dir / "tokenizer.model").is_file(),
        "tokenizer_config.json": (checkpoint_dir / "tokenizer_config.json").is_file(),
    }
    if checkpoint_dir.is_dir():
        if all(files.values()):
            # we're good
            return
        problem = f" is missing the files: {[f for f, exists in files.items() if not exists]!r}"
    else:
        problem = " is not a checkpoint directory"

    # list locally available checkpoints
    available = list(Path("checkpoints").glob("*/*"))
    if available:
        options = "\n".join([""] + [repr(str(p.resolve())) for p in available])
        extra = f"\nYou have downloaded locally:{options}\n"
    else:
        extra = ""

    if verbose:
        error_message = (
            f"checkpoint_dir {str(checkpoint_dir.absolute())!r}{problem}."
            "\nFind download instructions at https://github.com/Lightning-AI/litgpt/blob/main/tutorials\n"
            f"{extra}\nSee all download options by running:\n litgpt download"
        )
        print(error_message, file=sys.stderr)

    if raise_error:
        raise FileNotFoundError(
            f"checkpoint_dir {str(checkpoint_dir.absolute())!r}{problem}."
        )
    else:
        raise SystemExit(1)
140
+
141
+
142
class SavingProxyForStorage:
    """Pickling stand-in for a torch storage whose bytes were already written.

    Used by :class:`incremental_save`: the raw storage bytes are streamed to
    the archive immediately at construction time, and only the small metadata
    tuple (`storage_info`) participates in pickling later.
    """

    def __init__(self, obj, saver, protocol_version=5):
        self.protocol_version = protocol_version
        self.saver = saver
        if not (isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj)):
            raise TypeError(f"expected storage, not {type(obj)}")

        # this logic is taken from PyTorch 2.0+ torch/serialization.py
        if isinstance(obj, torch.storage.TypedStorage):
            # PT upstream wants to deprecate this eventually...
            storage = obj._untyped_storage
            storage_type_str = obj._pickle_storage_type()
            storage_type = getattr(torch, storage_type_str)
            storage_numel = obj._size()
        else:
            storage = obj
            storage_type = normalize_storage_type(type(obj))
            storage_numel = storage.nbytes()

        # Write the bytes out now; only key/location metadata is pickled later.
        storage_key = saver._write_storage_and_return_key(storage)
        location = torch.serialization.location_tag(storage)

        # Same tuple shape torch's pickler produces via persistent_id.
        self.storage_info = (
            "storage",
            storage_type,
            storage_key,
            location,
            storage_numel,
        )

    def __reduce_ex__(self, protocol_version):
        # The pickler must handle this proxy via persistent_id, never reduce.
        assert False, "this should be handled with out of band"
174
+
175
+
176
class SavingProxyForTensor:
    """Pickling stand-in for a tensor whose storage is streamed out early.

    Captures the tensor's ``__reduce_ex__`` result and swaps the embedded
    storage for a :class:`SavingProxyForStorage`, so the heavy bytes are
    written immediately while the rebuild recipe is pickled later.
    """

    def __init__(self, tensor, saver, protocol_version=5):
        self.protocol_version = protocol_version
        self.reduce_ret_fn, reduce_args = tensor.__reduce_ex__(protocol_version)
        if reduce_args[0] == torch._utils._rebuild_tensor_v2:
            # for Tensors with Python attributes
            (a0, a1, (storage, *a2_other), *other_reduce_args) = reduce_args
            assert isinstance(
                storage, torch.storage.TypedStorage
            ), "Please check for updates"
            storage_proxy = SavingProxyForStorage(
                storage, saver, protocol_version=protocol_version
            )
            self.reduce_args = (a0, a1, (storage_proxy, *a2_other), *other_reduce_args)
        else:
            (storage, *other_reduce_args) = reduce_args
            assert isinstance(
                storage, torch.storage.TypedStorage
            ), "Please check for updates"
            storage_proxy = SavingProxyForStorage(
                storage, saver, protocol_version=protocol_version
            )
            self.reduce_args = (storage_proxy, *other_reduce_args)

    def __reduce_ex__(self, protocol_version):
        if protocol_version != self.protocol_version:
            raise RuntimeError(
                f"Unexpected protocol version: expected {self.protocol_version}, got {protocol_version}"
            )
        # Replay the captured rebuild recipe with the storage proxy substituted.
        return self.reduce_ret_fn, self.reduce_args
206
+
207
+
208
class IncrementalPyTorchPickler(pickle.Pickler):
    """Pickler that writes tensor storages to an :class:`incremental_save`
    archive as they are encountered, keeping only metadata in the pickle."""

    def __init__(self, saver, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # data_ptr -> dtype, used to reject aliased storages with mixed dtypes.
        self.storage_dtypes = {}
        self.saver = saver
        # storage._cdata -> archive key, so shared storages are written once.
        self.id_map = {}

    # this logic is taken from PyTorch 2.0+ torch/serialization.py
    def persistent_id(self, obj):
        # FIXME: the docs say that persistent_id should only return a string
        # but torch store returns tuples. This works only in the binary protocol
        # see
        # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
        # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
        if isinstance(obj, SavingProxyForStorage):
            # Already written via store_early; just emit its metadata tuple.
            return obj.storage_info

        if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
            if isinstance(obj, torch.storage.TypedStorage):
                # TODO: Once we decide to break serialization FC, this case
                # can be deleted
                storage = obj._untyped_storage
                storage_dtype = obj.dtype
                storage_type_str = obj._pickle_storage_type()
                storage_type = getattr(torch, storage_type_str)
                storage_numel = obj._size()

            else:
                storage = obj
                storage_dtype = torch.uint8
                storage_type = normalize_storage_type(type(obj))
                storage_numel = storage.nbytes()

            # If storage is allocated, ensure that any other saved storages
            # pointing to the same data all have the same dtype. If storage is
            # not allocated, don't perform this check
            if storage.data_ptr() != 0:
                if storage.data_ptr() in self.storage_dtypes:
                    if storage_dtype != self.storage_dtypes[storage.data_ptr()]:
                        raise RuntimeError(
                            "Cannot save multiple tensors or storages that view the same data as different types"
                        )
                else:
                    self.storage_dtypes[storage.data_ptr()] = storage_dtype

            storage_key = self.id_map.get(storage._cdata)
            if storage_key is None:
                # First time we see this storage: stream its bytes to the archive.
                storage_key = self.saver._write_storage_and_return_key(storage)
                self.id_map[storage._cdata] = storage_key
            location = torch.serialization.location_tag(storage)

            return ("storage", storage_type, storage_key, location, storage_numel)

        # Not a storage: pickle the object normally.
        return None
262
+
263
+
264
class incremental_save:
    """Context manager writing a torch-compatible zip checkpoint incrementally.

    Storages can be streamed to disk one at a time (via `store_early`) instead
    of buffering the whole state dict in memory; `save` then writes the pickled
    metadata as the final ``data.pkl`` record.
    """

    def __init__(self, name):
        self.name = name
        self.zipfile = torch._C.PyTorchFileWriter(str(name))
        # `save` may only be called once; set afterwards to prevent more writes.
        self.has_saved = False
        # Monotonic key used for `data/<key>` records inside the archive.
        self.next_key = 0

    def __enter__(self):
        return self

    def store_early(self, tensor):
        """Write `tensor`'s storage to the archive now; return a pickle proxy."""
        if isinstance(tensor, torch.Tensor):
            return SavingProxyForTensor(tensor, self)
        raise TypeError(f"can only store tensors early, not {type(tensor)}")

    def save(self, obj):
        """Pickle `obj` (with storages externalized) as the final record."""
        if self.has_saved:
            raise RuntimeError("have already saved")
        # Write the pickle data for `obj`
        data_buf = BytesIO()
        pickler = IncrementalPyTorchPickler(self, data_buf, protocol=5)
        pickler.dump(obj)
        data_value = data_buf.getvalue()
        self.zipfile.write_record("data.pkl", data_value, len(data_value))
        self.has_saved = True

    def _write_storage_and_return_key(self, storage):
        # Called by the pickler/proxies: append one storage, return its key.
        if self.has_saved:
            raise RuntimeError("have already saved")
        key = self.next_key
        self.next_key += 1
        name = f"data/{key}"
        # The file writer reads from host memory, so move GPU storages first.
        if storage.device.type != "cpu":
            storage = storage.cpu()
        num_bytes = storage.nbytes()
        self.zipfile.write_record(name, storage.data_ptr(), num_bytes)
        return key

    def __exit__(self, type, value, traceback):
        self.zipfile.write_end_of_file()
304
+
305
+
306
+ T = TypeVar("T")
307
+
308
+
309
def chunked_cross_entropy(
    logits: Union[torch.Tensor, List[torch.Tensor]],
    targets: torch.Tensor,
    chunk_size: int = 128,
    ignore_index: int = -100,
) -> torch.Tensor:
    """Cross entropy computed in chunks to cap the backward pass's peak memory.

    Args:
        logits: a single ``(B, T, V)`` tensor, or a list of chunks produced by
            a chunked ``lm_head`` during fine-tuning.
        targets: ``(B, T)`` integer class targets.
        chunk_size: rows per chunk; ``0`` disables chunking entirely.
        ignore_index: target value excluded from the loss and the mean.

    Returns:
        Scalar mean loss over the non-ignored target elements.
    """
    # with large max_sequence_lengths, the beginning of `backward` allocates a large memory chunk which can dominate
    # the memory usage in fine-tuning settings with low number of parameters.
    # as a workaround hack, the cross entropy computation is chunked to force it to deallocate on the go, reducing
    # the memory spike's magnitude

    # lm_head was chunked (we are fine-tuning)
    if isinstance(logits, list):
        # don't want to chunk cross entropy
        if chunk_size == 0:
            logits = torch.cat(logits, dim=1)
            logits = logits.reshape(-1, logits.size(-1))
            targets = targets.reshape(-1)
            return torch.nn.functional.cross_entropy(
                logits, targets, ignore_index=ignore_index
            )

        # chunk cross entropy
        logit_chunks = [
            logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits
        ]
        # Split targets with the same sequence-chunking as the logits list.
        target_chunks = [
            target_chunk.reshape(-1)
            for target_chunk in targets.split(logits[0].size(1), dim=1)
        ]
        loss_chunks = [
            torch.nn.functional.cross_entropy(
                logit_chunk, target_chunk, ignore_index=ignore_index, reduction="none"
            )
            for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
        ]
        non_masked_elems = (targets != ignore_index).sum()
        # See [non_masked_elems div note]
        return torch.cat(loss_chunks).sum() / non_masked_elems.maximum(
            torch.ones_like(non_masked_elems)
        )

    # no chunking at all
    logits = logits.reshape(-1, logits.size(-1))
    targets = targets.reshape(-1)
    if chunk_size == 0:
        return torch.nn.functional.cross_entropy(
            logits, targets, ignore_index=ignore_index
        )

    # lm_head wasn't chunked, chunk cross entropy
    logit_chunks = logits.split(chunk_size)
    target_chunks = targets.split(chunk_size)
    loss_chunks = [
        torch.nn.functional.cross_entropy(
            logit_chunk, target_chunk, ignore_index=ignore_index, reduction="none"
        )
        for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
    ]
    non_masked_elems = (targets != ignore_index).sum()
    # [non_masked_elems div note]:
    # max(1, non_masked_elems) would be more ergonomic to avoid a division by zero. However that
    # results in a python int which is then passed back to torch division. By using the
    # `x.maximum(torch.ones_like(x))` pattern we avoid a cudaStreamSynchronize.
    return torch.cat(loss_chunks).sum() / non_masked_elems.maximum(
        torch.ones_like(non_masked_elems)
    )
376
+
377
+
378
def map_old_state_dict_weights(state_dict: Dict, mapping: Mapping, prefix: str) -> Dict:
    """Rename legacy checkpoint keys in `state_dict` (in place) per `mapping`."""
    for old_name, new_name in mapping.items():
        old_key = prefix + old_name
        if old_key in state_dict:
            # Pop under the old name, reinsert under the new one.
            state_dict[prefix + new_name] = state_dict.pop(old_key)
    return state_dict
385
+
386
+
387
def get_default_supported_precision(training: bool) -> str:
    """Return default precision that is supported by the hardware: either `bf16` or `16`.

    Args:
        training: `-mixed` or `-true` version of the precision to use

    Returns:
        default precision that is suitable for the task and is supported by the hardware
    """
    from lightning.fabric.accelerators import MPSAccelerator

    # MPS and pre-bf16 CUDA devices must fall back to fp16.
    lacks_bf16 = MPSAccelerator.is_available() or (
        torch.cuda.is_available() and not torch.cuda.is_bf16_supported()
    )
    base = "16" if lacks_bf16 else "bf16"
    return f"{base}-mixed" if training else f"{base}-true"
403
+
404
+
405
def load_checkpoint(
    fabric: L.Fabric, model: nn.Module, checkpoint_path: Path, strict: bool = True
) -> None:
    """Load weights from `checkpoint_path` into `model` using `fabric`."""
    if isinstance(fabric.strategy, FSDPStrategy):
        # FSDP consumes the raw checkpoint file directly.
        fabric.load_raw(checkpoint_path, model, strict=strict)
        return
    state_dict = lazy_load(checkpoint_path)
    # Checkpoints may nest the weights under a top-level "model" key.
    model.load_state_dict(state_dict.get("model", state_dict), strict=strict)
414
+
415
+
416
def flops_per_param(
    max_seq_length: int, n_layer: int, n_embd: int, n_params: int
) -> int:
    """Estimate forward-pass FLOPs for one fixed-length sequence."""
    # Each parameter performs one multiply-accumulate (2 FLOPs) per token.
    per_token = 2 * n_params
    # Assumes every sample is exactly `max_seq_length` tokens, which
    # over-estimates for fine-tuning with shorter samples.
    dense_flops = per_token * max_seq_length
    # Attention score/context matmuls grow quadratically with sequence length.
    attention_flops = n_layer * 2 * 2 * (n_embd * (max_seq_length**2))
    return dense_flops + attention_flops
427
+
428
+
429
def estimate_flops(model: "GPT", training: bool) -> int:
    """Measures estimated FLOPs for MFU.

    Args:
        model: the GPT model whose configuration and parameters are inspected.
        training: whether backward/gradient passes should be counted as well.

    Refs:
        * https://ar5iv.labs.arxiv.org/html/2205.05198#A1
        * https://ar5iv.labs.arxiv.org/html/2204.02311#A2
    """
    # using all parameters for this is a naive over estimation because not all model parameters actually contribute to
    # this FLOP computation (e.g. embedding, norm). For this reason, the result will be higher by a fixed percentage
    # (~10%) compared to the measured FLOPs, making those lower but more realistic.
    # For a proper estimate, this needs a more fine-grained calculation as in Appendix A of the paper.
    n_trainable_params = num_parameters(model, requires_grad=True)
    trainable_flops = flops_per_param(
        model.max_seq_length,
        model.config.n_layer,
        model.config.n_embd,
        n_trainable_params,
    )
    # forward + backward + gradients (assumes no gradient accumulation)
    ops_per_step = 3 if training else 1
    n_frozen_params = num_parameters(model, requires_grad=False)
    frozen_flops = flops_per_param(
        model.max_seq_length, model.config.n_layer, model.config.n_embd, n_frozen_params
    )
    # forward + backward (frozen params need no gradient step)
    frozen_ops_per_step = 2 if training else 1
    return ops_per_step * trainable_flops + frozen_ops_per_step * frozen_flops
456
+
457
+
458
class CycleIterator:
    """An iterator that cycles through an iterable indefinitely.

    Example:
        >>> iterator = CycleIterator([1, 2, 3])
        >>> [next(iterator) for _ in range(5)]
        [1, 2, 3, 1, 2]

    Note:
        Unlike ``itertools.cycle``, this iterator does not cache the values of the iterable.
    """

    def __init__(self, iterable: Iterable) -> None:
        self.iterable = iterable
        self.epoch = 0  # completed passes over the iterable
        self._iterator = None  # created lazily on first __next__

    def __next__(self) -> Any:
        if self._iterator is None:
            self._iterator = iter(self.iterable)
        try:
            return next(self._iterator)
        except StopIteration:
            # Exhausted: count the completed epoch and start a fresh pass.
            self.epoch += 1
            self._iterator = iter(self.iterable)
            return next(self._iterator)

    def __iter__(self) -> "CycleIterator":
        return self
487
+
488
+
489
def copy_config_files(source_dir: Path, out_dir: Path) -> None:
    """Copies the specified configuration and tokenizer files into the output directory."""
    wanted = (
        # model/config files
        "config.json",
        "generation_config.json",
        "model_config.yaml",
        # tokenizer files
        "tokenizer.json",
        "tokenizer.model",
        "tokenizer_config.json",
    )
    for file_name in wanted:
        src_path = source_dir / file_name
        if src_path.exists():
            shutil.copy(src_path, out_dir)
499
+
500
+
501
def CLI(*args: Any, **kwargs: Any) -> Any:
    """Thin wrapper over jsonargparse's CLI with project-wide parse options set."""
    from jsonargparse import CLI as _CLI, set_config_read_mode, set_docstring_parse_options

    # Pick up parameter docs from attribute docstrings and allow URL-based configs.
    set_docstring_parse_options(attribute_docstrings=True)
    set_config_read_mode(urls_enabled=True)

    return _CLI(*args, **kwargs)
508
+
509
+
510
def capture_hparams() -> Dict[str, Any]:
    """Captures the local variables ('hyperparameters') from where this function gets called."""
    # f_back is the immediate caller's frame; this function must stay exactly
    # one frame below the call site, so no helper indirection here.
    caller_locals = inspect.currentframe().f_back.f_locals
    hparams = {}
    for name, value in caller_locals.items():
        if value is None or isinstance(value, (int, float, str, bool, Path)):
            hparams[name] = value
        elif is_dataclass(value):
            hparams[name] = asdict(value)
        else:
            # Anything else is stored by its string representation.
            hparams[name] = str(value)
    return hparams
523
+
524
+
525
def save_hyperparameters(function: callable, checkpoint_dir: Path) -> None:
    """Captures the CLI parameters passed to `function` without running `function` and saves them to the checkpoint."""
    from jsonargparse import capture_parser

    # TODO: Make this more robust
    # This hack strips away the subcommands from the top-level CLI
    # to parse the file as if it was called as a script
    known_commands = [
        ("finetune_full",),  # For subcommands, use `("finetune", "full")` etc
        ("finetune_lora",),
        ("finetune_adapter",),
        ("finetune_adapter_v2",),
        ("finetune",),
        ("pretrain",),
    ]
    for known_command in known_commands:
        unwanted = slice(1, 1 + len(known_command))
        # Remove the subcommand tokens from argv so only plain args remain.
        if tuple(sys.argv[unwanted]) == known_command:
            sys.argv[unwanted] = []

    # Build (but don't run) the parser for `function`, then parse & persist argv.
    parser = capture_parser(lambda: CLI(function))
    config = parser.parse_args()
    parser.save(config, checkpoint_dir / "hyperparameters.yaml", overwrite=True)
548
+
549
+
550
def save_config(config: "Config", checkpoint_dir: Path) -> None:
    """Serialize `config` as YAML to ``checkpoint_dir / "model_config.yaml"``."""
    with open(checkpoint_dir / "model_config.yaml", "w", encoding="utf-8") as fp:
        yaml.dump(asdict(config), fp)
554
+
555
+
556
def parse_devices(devices: Union[str, int]) -> int:
    """Normalize a `--devices` argument to a concrete positive device count."""
    if isinstance(devices, int) and devices > 0:
        return devices
    if devices == -1 or devices == "auto":
        # Use every visible CUDA device; fall back to a single (CPU) device.
        return torch.cuda.device_count() or 1
    raise ValueError(f"Devices must be 'auto' or a positive integer, got: {devices!r}")
562
+
563
+
564
def choose_logger(
    logger_name: Literal["csv", "tensorboard", "wandb"],
    out_dir: Path,
    name: str,
    log_interval: int = 1,
    resume: Optional[bool] = None,
    **kwargs: Any,
):
    """Instantiate the experiment logger selected by `logger_name`."""
    log_root = out_dir / "logs"
    if logger_name == "csv":
        return CSVLogger(
            root_dir=log_root,
            name="csv",
            flush_logs_every_n_steps=log_interval,
            **kwargs,
        )
    if logger_name == "tensorboard":
        return TensorBoardLogger(root_dir=log_root, name="tensorboard", **kwargs)
    if logger_name == "wandb":
        # W&B uses the run `name` as the project and supports resuming.
        return WandbLogger(project=name, resume=resume, **kwargs)
    raise ValueError(
        f"`--logger_name={logger_name}` is not a valid option. Choose from 'csv', 'tensorboard', 'wandb'."
    )
588
+
589
+
590
def get_argument_names(cls):
    """Return the keyword-addressable parameter names of `cls.__init__`."""
    accepted_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    signature = inspect.signature(cls.__init__)
    return {
        name
        for name, param in signature.parameters.items()
        if param.kind in accepted_kinds
    }
598
+
599
+
600
def instantiate_bnb_optimizer(optimizer, model_parameters):
    """Build a bitsandbytes PagedAdamW from a name string or a class-path config dict."""
    # Quantized training only supports AdamW — validate before importing bnb.
    invalid = (isinstance(optimizer, str) and "AdamW" not in optimizer) or (
        isinstance(optimizer, dict) and "AdamW" not in optimizer.get("class_path", "")
    )
    if invalid:
        raise ValueError(
            "The chosen quantization format only supports the AdamW optimizer."
        )

    import bitsandbytes as bnb

    if isinstance(optimizer, str):
        return bnb.optim.PagedAdamW(model_parameters)
    # Keep only the config's init args that PagedAdamW actually accepts.
    optim_args = get_argument_names(bnb.optim.PagedAdamW)
    allowed_kwargs = {
        key: optimizer["init_args"][key]
        for key in optim_args & optimizer["init_args"].keys()
    }
    return bnb.optim.PagedAdamW(model_parameters, **allowed_kwargs)
620
+
621
+
622
def instantiate_torch_optimizer(optimizer, model_parameters, **kwargs):
    """Create a torch optimizer from a name (e.g. "AdamW") or a jsonargparse config dict."""
    if isinstance(optimizer, str):
        # Resolve the class by name on torch.optim and apply overrides directly.
        return getattr(torch.optim, optimizer)(model_parameters, **kwargs)
    config = dict(optimizer)  # copy so the caller's dict is left untouched
    config["init_args"].update(kwargs)
    return instantiate_class(model_parameters, config)
631
+
632
+
633
def extend_checkpoint_dir(checkpoint_dir: Path) -> Path:
    """Prefix a bare relative checkpoint dir with ``checkpoints/`` when that exists."""
    candidate = "checkpoints" / checkpoint_dir
    # Only redirect when the given path is missing, relative, not already
    # under `checkpoints/`, and the prefixed location actually exists.
    use_candidate = (
        not checkpoint_dir.is_dir()
        and checkpoint_dir.parts[0] != "checkpoints"
        and not checkpoint_dir.is_absolute()
        and candidate.exists()
    )
    return candidate if use_candidate else checkpoint_dir
models/README.md ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ pipeline_tag: any-to-any
4
+ library_name: mini-omni2
5
+ ---
6
+
7
+ # Mini-Omni2
8
+
9
+ <!-- <p align="center">
10
+ <img src="./data/figures/title.png" width="100%"/>
11
+ </p> -->
12
+
13
+
14
+ <p align="center">
15
+ 🤗 <a href="https://huggingface.co/gpt-omni/mini-omni2">Hugging Face</a> | 📖 <a href="https://github.com/gpt-omni/mini-omni2">Github</a>
16
+ | 📑 <a href="https://arxiv.org/abs/2410.11190">Technical report</a>
17
+ </p>
18
+
19
+ Mini-Omni2 is an **omni-interactive** model. It can **understand image, audio and text inputs and has end-to-end voice conversations with users**. Featuring **real-time voice output**, **omni-capable multimodal understanding** and flexible interaction **ability with interruption mechanism while speaking**.
20
+
21
+ <p align="center">
22
+ <img src="./data/figures/framework.jpeg" width="100%"/>
23
+ </p>
24
+
25
+
26
+
27
+ ## Updates
28
+
29
+ - **2024.10:** Release the model, technical report, inference and chat demo code.
30
+
31
+ ## Features
32
+ ✅ **Multimodal interaction**: with the ability to understand images, speech and text, just like GPT-4o.
33
+
34
+ ✅ **Real-time speech-to-speech** conversational capabilities. No extra ASR or TTS models required, just like [Mini-Omni](https://github.com/gpt-omni/mini-omni).
35
+
36
+ <!-- ✅ **Streaming audio output**: with first-chunk latency of audio stream less than 0.3s. -->
37
+
38
+ <!-- ✅ **Duplex interaction**: hearing while speaking, it can be interrupted by key words like "stop omni". -->
39
+
40
+
41
+ ## Demo
42
+
43
+ NOTE: need to unmute first.
44
+
45
+ https://github.com/user-attachments/assets/ad97ca7f-f8b4-40c3-a7e8-fa54b4edf155
46
+
47
+
48
+ ## ToDo
49
+ - [ ] update interruption mechanism
50
+
51
+
52
+ ## Install
53
+
54
+ Create a new conda environment and install the required packages:
55
+
56
+ ```sh
57
+ conda create -n omni python=3.10
58
+ conda activate omni
59
+
60
+ git clone https://github.com/gpt-omni/mini-omni2.git
61
+ cd mini-omni2
62
+ pip install -r requirements.txt
63
+ ```
64
+
65
+ ## Quick start
66
+
67
+ **Interactive demo**
68
+
69
+ - start server
70
+
71
+ NOTE: you need to start the server before running the streamlit or gradio demo with API_URL set to the server address.
72
+
73
+ ```sh
74
+ sudo apt-get install ffmpeg
75
+ conda activate omni
76
+ cd mini-omni2
77
+ python3 server.py --ip '0.0.0.0' --port 60808
78
+ ```
79
+
80
+
81
+ - run streamlit demo
82
+
83
+ NOTE: you need to run streamlit **locally** with PyAudio installed.
84
+
85
+ ```sh
86
+ pip install PyAudio==0.2.14
87
+ API_URL=http://0.0.0.0:60808/chat streamlit run webui/omni_streamlit.py
88
+ ```
89
+
90
+
91
+ **Local test**
92
+
93
+ ```sh
94
+ conda activate omni
95
+ cd mini-omni2
96
+ # test run the preset audio samples and questions
97
+ python inference_vision.py
98
+ ```
99
+
100
+ ## Mini-Omni2 Overview
101
+
102
+ **1. Multimodal Modeling**:
103
+ We use multiple sequences as the input and output of the model. In the input part, we will concatenate image, audio and text features to perform a series of comprehensive tasks, as shown in the following figures. In the output part, we use text-guided delayed parallel output to generate real-time speech responses.
104
+ <p align="center">
105
+ <img src="./data/figures/inputids.png" width="100%"/>
106
+ </p>
107
+
108
+ **2. Multi-stage Training**:
109
+ We propose an efficient alignment training method and conduct encoder adaptation, modal alignment, and multimodal fine-tuning respectively in the three-stage training.
110
+ <p align="center">
111
+ <img src="./data/figures/training.jpeg" width="100%"/>
112
+ </p>
113
+
114
+ <!-- **3. Cases**:
115
+ Here are more cases of Mini-Omni2:
116
+ <p align="center">
117
+ <img src="./data/figures/samples.png" width="100%"/>
118
+ </p> -->
119
+
120
+ ## FAQ
121
+
122
+ **1. Does the model support other languages?**
123
+
124
+ No, the model is only trained on English. However, since we use whisper as the audio encoder, the model can understand other languages that are supported by whisper (such as Chinese), but the output is only in English.
125
+
126
+ **2. Error: can not run streamlit in local browser, with remote streamlit server**
127
+
128
+ You need to start streamlit **locally** with PyAudio installed.
129
+
130
+
131
+ ## Acknowledgements
132
+
133
+ - [Qwen2](https://github.com/QwenLM/Qwen2/) as the LLM backbone.
134
+ - [litGPT](https://github.com/Lightning-AI/litgpt/) for training and inference.
135
+ - [whisper](https://github.com/openai/whisper/) for audio encoding.
136
+ - [clip](https://github.com/openai/CLIP) for image encoding.
137
+ - [snac](https://github.com/hubertsiuzdak/snac/) for audio decoding.
138
+ - [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) for generating synthetic speech.
139
+ - [OpenOrca](https://huggingface.co/datasets/Open-Orca/OpenOrca) and [MOSS](https://github.com/OpenMOSS/MOSS/tree/main) for alignment.
140
+
141
+ <!-- ## Star History
142
+
143
+ [![Star History Chart](https://api.star-history.com/svg?repos=gpt-omni/mini-omni2&type=Date)](https://star-history.com/#gpt-omni/mini-omni2&Date)
models/ViT-B-32.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af
3
+ size 353976522
models/data/figures/framework.jpeg ADDED

Git LFS Details

  • SHA256: bc668450030500a62ddbb7cf6ea170f0b53da7e3e5506d01a0dc6f2ec690fd1a
  • Pointer size: 131 Bytes
  • Size of remote file: 406 kB
models/data/figures/inputids.png ADDED

Git LFS Details

  • SHA256: ad4cf663684c53f72952b13f52ea93fcbe19e287301b3decfcd917de9e23f312
  • Pointer size: 131 Bytes
  • Size of remote file: 335 kB
models/data/figures/samples.png ADDED

Git LFS Details

  • SHA256: e63a8cbc2859304cb9c50b831366ac8804ad0326b6ae4897d08f8ab0e1eb63c6
  • Pointer size: 132 Bytes
  • Size of remote file: 2.57 MB
models/data/figures/title.png ADDED

Git LFS Details

  • SHA256: 56194a7fd5cfd29d6e2ce574fc7628315d87220adbc7aa2949e579c1a63ed2a3
  • Pointer size: 132 Bytes
  • Size of remote file: 1.79 MB
models/data/figures/training.jpeg ADDED

Git LFS Details

  • SHA256: fd49f75dbe5838a3e28f02c8f853dec34d0aad8573911d52bd827ab6dae8f9a1
  • Pointer size: 131 Bytes
  • Size of remote file: 353 kB
models/data/omni2-demo.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0c2098124af391dca9c48854f5686143c137cc069f08b5e457675b9ba744bd2f
3
+ size 11784395
models/hub/.locks/models--hubertsiuzdak--snac_24khz/4b8164cc6606bfa627f1a784734c1e539891518f1191ed9194fe1e3b9b4bff40.lock ADDED
File without changes
models/hub/.locks/models--hubertsiuzdak--snac_24khz/a9e7ef62bf7e1eb94d2713721029837aacab3b55.lock ADDED
File without changes
models/hub/models--hubertsiuzdak--snac_24khz/blobs/4b8164cc6606bfa627f1a784734c1e539891518f1191ed9194fe1e3b9b4bff40 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4b8164cc6606bfa627f1a784734c1e539891518f1191ed9194fe1e3b9b4bff40
3
+ size 79488254
models/hub/models--hubertsiuzdak--snac_24khz/blobs/a9e7ef62bf7e1eb94d2713721029837aacab3b55 ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "sampling_rate": 24000,
3
+ "encoder_dim": 48,
4
+ "encoder_rates": [2, 4, 8, 8],
5
+ "decoder_dim": 1024,
6
+ "decoder_rates": [8, 8, 4, 2],
7
+ "attn_window_size": null,
8
+ "codebook_size": 4096,
9
+ "codebook_dim": 8,
10
+ "vq_strides": [4, 2, 1],
11
+ "noise": true,
12
+ "depthwise": true
13
+ }
models/hub/models--hubertsiuzdak--snac_24khz/refs/main ADDED
@@ -0,0 +1 @@
 
 
1
+ d73ad176a12188fcf4f360ba3bf2c2fbbe8f58ec
models/hub/models--hubertsiuzdak--snac_24khz/snapshots/d73ad176a12188fcf4f360ba3bf2c2fbbe8f58ec/config.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "sampling_rate": 24000,
3
+ "encoder_dim": 48,
4
+ "encoder_rates": [2, 4, 8, 8],
5
+ "decoder_dim": 1024,
6
+ "decoder_rates": [8, 8, 4, 2],
7
+ "attn_window_size": null,
8
+ "codebook_size": 4096,
9
+ "codebook_dim": 8,
10
+ "vq_strides": [4, 2, 1],
11
+ "noise": true,
12
+ "depthwise": true
13
+ }
models/hub/models--hubertsiuzdak--snac_24khz/snapshots/d73ad176a12188fcf4f360ba3bf2c2fbbe8f58ec/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4b8164cc6606bfa627f1a784734c1e539891518f1191ed9194fe1e3b9b4bff40
3
+ size 79488254
models/hub/version.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ 1
models/lit_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0f3e53aecb30a107ce52917e4f2c6ccb10e4a1457708b6a94fb43c72961a31c5
3
+ size 2814623738
models/model_config.yaml ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ add_qkv_bias: true
2
+ asr_adapter: llamamlp
3
+ attn_dropout: 0.0
4
+ bias: false
5
+ block_size: 2048
6
+ force_align: false
7
+ gelu_approximate: none
8
+ head_size: 64
9
+ hf_config:
10
+ name: Qwen2-0.5B
11
+ org: Qwen
12
+ intermediate_size: 4864
13
+ lm_head_bias: false
14
+ mlp_class_name: LLaMAMLP
15
+ n_embd: 896
16
+ n_expert: 0
17
+ n_expert_per_token: 0
18
+ n_head: 14
19
+ n_layer: 24
20
+ n_query_groups: 2
21
+ name: Qwen2-0.5B
22
+ norm_class_name: RMSNorm
23
+ norm_eps: 1.0e-06
24
+ padded_vocab_size: 181120
25
+ padding_multiple: 512
26
+ parallel_residual: false
27
+ pos_type: rope
28
+ post_adapter: false
29
+ post_adapter_layers: 6
30
+ prompt_vocab_size: null
31
+ rope_base: 1000000
32
+ rope_condense_ratio: 1
33
+ rotary_percentage: 1
34
+ scale_embeddings: false
35
+ shared_attention_norm: false
36
+ tie_word_embeddings: true
37
+ use_pretrain_phoneme_emb: false
38
+ vocab_size: 50254
39
+ text_vocab_size: 152000
40
+ cat_audio_vocab_size: 29120
41
+ audio_vocab_size: 4160
42
+ whisper_adapter_dim: 768
43
+ vision_adapter_dim: 512
models/small.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2fb26a40bfcfbb3d7e41586205d21c90ffc1de552c15367efb4a723ce11f700f
3
+ size 483586606
models/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
models/tokenizer_config.json ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "151643": {
5
+ "content": "<|endoftext|>",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "151644": {
13
+ "content": "<|im_start|>",
14
+ "lstrip": false,
15
+ "normalized": false,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ },
20
+ "151645": {
21
+ "content": "<|im_end|>",
22
+ "lstrip": false,
23
+ "normalized": false,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": true
27
+ }
28
+ },
29
+ "additional_special_tokens": ["<|im_start|>", "<|im_end|>"],
30
+ "bos_token": null,
31
+ "chat_template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
32
+ "clean_up_tokenization_spaces": false,
33
+ "eos_token": "<|endoftext|>",
34
+ "errors": "replace",
35
+ "model_max_length": 32768,
36
+ "pad_token": "<|endoftext|>",
37
+ "split_special_tokens": false,
38
+ "tokenizer_class": "Qwen2Tokenizer",
39
+ "unk_token": null
40
+ }
requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.3.1
2
+ torchvision==0.18.1
3
+ torchaudio==2.3.1
4
+ litgpt==0.4.3
5
+ snac==1.2.0
6
+ soundfile==0.12.1
7
+ openai-whisper
8
+ tokenizers==0.19.1
9
+ streamlit==1.37.1
10
+ streamlit-webrtc
11
+ pydub==0.25.1
12
+ onnxruntime==1.19.0
13
+ librosa==0.10.2.post1
14
+ flask==3.0.3
15
+ fire
16
+ git+https://github.com/mini-omni/CLIP.git
17
+ gradio_webrtc[vad]==0.0.11
18
+ twilio
server.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from pydantic import BaseModel
4
+ import traceback
5
+ import numpy as np
6
+ import torch
7
+ import base64
8
+ import io
9
+ import os
10
+ import logging
11
+ import whisper # For audio loading/processing
12
+ import soundfile as sf
13
+ from inference import OmniInference
14
+ import tempfile
15
+
16
+
17
# Module-level service setup: logging, the FastAPI app, and CORS.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Allow cross-origin requests from any origin so browser-based demo clients
# hosted elsewhere can reach the API.
# NOTE(review): "*" origins combined with allow_credentials=True is very
# permissive — tighten for production deployments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
30
+
31
class AudioRequest(BaseModel):
    """Request payload for the /api/v1/inference endpoint."""

    # Base64-encoded bytes of a numpy array saved with np.save
    # (decoded via np.load in the inference handler).
    audio_data: str
    # Sampling rate of the submitted audio in Hz; the handler warns
    # (but does not resample) when this is not 16000.
    sample_rate: int
34
+
35
class AudioResponse(BaseModel):
    """Response payload for the /api/v1/inference endpoint."""

    # Base64-encoded bytes of the generated audio as an np.save-serialized array.
    audio_data: str
    # Accumulated transcript of the streamed text output; empty if none produced.
    text: str = ""
38
+
39
# Model initialization status: tracks whether loading finished and, on
# failure, the error text. Read by the health/inference endpoints and
# written by initialize_model().
INITIALIZATION_STATUS = {
    "model_loaded": False,
    "error": None
}

# Global model instance, populated by initialize_model() at startup.
model = None
47
+
48
+
49
+
50
+
51
def initialize_model():
    """Load the global OmniInference model and warm it up.

    Mutates the module-level ``model`` and ``INITIALIZATION_STATUS`` globals.

    Returns:
        bool: True on success; False if loading failed (the error message is
        recorded in ``INITIALIZATION_STATUS["error"]``).
    """
    global model, INITIALIZATION_STATUS
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Initializing OmniInference model on device: {device}")

        # Checkpoints are expected in ./models relative to the working directory.
        ckpt_path = os.path.abspath('models')
        logger.info(f"Loading models from: {ckpt_path}")

        if not os.path.exists(ckpt_path):
            raise RuntimeError(f"Checkpoint path {ckpt_path} does not exist")

        model = OmniInference(ckpt_path, device=device)
        # Presumably primes the model so the first real request is not slow —
        # TODO confirm against OmniInference.warm_up.
        model.warm_up()

        INITIALIZATION_STATUS["model_loaded"] = True
        logger.info("OmniInference model initialized successfully")
        return True
    except Exception as e:
        # Record the failure so /health can report it instead of crashing startup.
        INITIALIZATION_STATUS["error"] = str(e)
        logger.error(f"Failed to initialize model: {e}\n{traceback.format_exc()}")
        return False
74
+
75
@app.on_event("startup")
async def startup_event():
    """Initialize the model at startup so the first request does not pay load time."""
    # NOTE(review): @app.on_event is deprecated in recent FastAPI in favor of
    # lifespan handlers; works today, but worth migrating.
    initialize_model()
79
+
80
@app.get("/api/v1/health")
def health_check():
    """Report service readiness and model initialization state."""
    ready = INITIALIZATION_STATUS["model_loaded"]
    status = {
        "status": "healthy" if ready else "initializing",
        "initialization_status": INITIALIZATION_STATUS,
    }

    # Only expose device/warm-up details once the model object exists.
    if model is not None:
        status["device"] = str(model.device)
        status["model_loaded"] = True
        status["warm_up_complete"] = True

    return status
96
+
97
+
98
@app.post("/api/v1/inference")
async def inference(request: AudioRequest) -> AudioResponse:
    """Run speech-to-speech inference with the OmniInference model.

    The request carries a base64-encoded numpy array (np.save format). The
    audio is written to a temporary WAV, streamed through the model, and the
    generated speech plus accumulated transcript are returned base64-encoded.

    Raises:
        HTTPException: 503 while the model is still loading, 500 on failure.
    """
    if not INITIALIZATION_STATUS["model_loaded"]:
        raise HTTPException(
            status_code=503,
            detail=f"Model not ready. Status: {INITIALIZATION_STATUS}"
        )

    try:
        logger.info(f"Received inference request with sample rate: {request.sample_rate}")

        # Decode audio data from base64 to numpy array
        audio_bytes = base64.b64decode(request.audio_data)
        audio_array = np.load(io.BytesIO(audio_bytes))

        # Save numpy array as temporary WAV file for OmniInference
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as temp_wav:
            # Convert sample rate if needed (OmniInference expects 16kHz)
            if request.sample_rate != 16000:
                # TODO: add real resampling (e.g. librosa.resample) instead of assuming 16kHz
                logger.warning("Sample rate conversion not implemented. Assuming 16kHz.")

            # Pad/trim to whisper's expected window, flatten to mono 1-D, write 16kHz WAV.
            audio_data = whisper.pad_or_trim(audio_array.flatten())
            sf.write(temp_wav.name, audio_data, 16000)

            final_text = ""

            with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as temp_wav_out:
                # Drain the streaming generator: it writes audio to save_path
                # and yields incremental text chunks.
                for audio_stream, text_stream in model.run_AT_batch_stream(
                    temp_wav.name,
                    stream_stride=4,
                    max_returned_tokens=2048,
                    save_path=temp_wav_out.name,
                    sample_rate=request.sample_rate
                ):
                    if text_stream:  # Accumulate non-empty text
                        final_text += text_stream
                final_audio, sample_rate = sf.read(temp_wav_out.name)
                # Was a bare `assert`, which is stripped under `python -O` and
                # produced an empty 500 detail; fail explicitly instead.
                if sample_rate != request.sample_rate:
                    raise HTTPException(
                        status_code=500,
                        detail=(
                            f"Output sample rate {sample_rate} does not match "
                            f"requested rate {request.sample_rate}"
                        ),
                    )

            # Encode output array to base64
            buffer = io.BytesIO()
            np.save(buffer, final_audio)
            audio_b64 = base64.b64encode(buffer.getvalue()).decode()

            return AudioResponse(
                audio_data=audio_b64,
                text=final_text.strip()
            )

    except HTTPException:
        # Don't re-wrap our own HTTP errors as opaque 500s.
        raise
    except Exception as e:
        logger.error(f"Inference failed: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=str(e)
        )
160
+
161
+
162
if __name__ == "__main__":
    import uvicorn
    # Serve the FastAPI app on all interfaces, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)
utils/__init__.py ADDED
File without changes