Upload folder using huggingface_hub
Browse files- Qformer.py +1272 -0
- __init__.py +0 -0
- blip2.py +105 -0
- config.json +62 -0
- configuration_videochat2.py +453 -0
- ltm/basis_functions.py +266 -0
- ltm/long_term_attention_gibbs.py +315 -0
- model-00001-of-00004.safetensors +3 -0
- model-00002-of-00004.safetensors +3 -0
- model-00003-of-00004.safetensors +3 -0
- model-00004-of-00004.safetensors +3 -0
- model.safetensors.index.json +934 -0
- videochat2_it_hd_mistral.py +418 -0
- vit.py +472 -0
Qformer.py
ADDED
|
@@ -0,0 +1,1272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
* Copyright (c) 2023, salesforce.com, inc.
|
| 3 |
+
* All rights reserved.
|
| 4 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
| 6 |
+
* By Junnan Li
|
| 7 |
+
* Based on huggingface code base
|
| 8 |
+
* https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import math
|
| 12 |
+
from typing import Tuple
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
from torch import Tensor, device, nn
|
| 16 |
+
import torch.utils.checkpoint
|
| 17 |
+
from torch import nn
|
| 18 |
+
from torch.nn import CrossEntropyLoss
|
| 19 |
+
|
| 20 |
+
# from timm.layers import drop_path
|
| 21 |
+
from transformers.activations import ACT2FN
|
| 22 |
+
from transformers.modeling_outputs import (
|
| 23 |
+
BaseModelOutputWithPastAndCrossAttentions,
|
| 24 |
+
BaseModelOutputWithPoolingAndCrossAttentions,
|
| 25 |
+
CausalLMOutputWithCrossAttentions,
|
| 26 |
+
MaskedLMOutput,
|
| 27 |
+
)
|
| 28 |
+
from transformers.modeling_utils import (
|
| 29 |
+
PreTrainedModel,
|
| 30 |
+
# apply_chunking_to_forward,
|
| 31 |
+
# find_pruneable_heads_and_indices,
|
| 32 |
+
# prune_linear_layer,
|
| 33 |
+
)
|
| 34 |
+
from transformers.pytorch_utils import (
|
| 35 |
+
# PreTrainedModel,
|
| 36 |
+
apply_chunking_to_forward,
|
| 37 |
+
find_pruneable_heads_and_indices,
|
| 38 |
+
prune_linear_layer,
|
| 39 |
+
)
|
| 40 |
+
from transformers.utils import logging
|
| 41 |
+
from transformers.models.bert.configuration_bert import BertConfig
|
| 42 |
+
|
| 43 |
+
from functools import partial
|
| 44 |
+
from .ltm.long_term_attention_gibbs import LongTermAttention
|
| 45 |
+
|
| 46 |
+
logger = logging.get_logger(__name__)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class BertEmbeddings(nn.Module):
    """Construct the embeddings from word and position embeddings.

    Supports three input modes:
      * ``input_ids`` only  -> word (+ absolute position) embeddings
      * ``query_embeds`` only -> the learned query embeddings are used as-is
      * both -> ``query_embeds`` are prepended to the token embeddings

    The result is layer-normalized and dropped out before being returned.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(
            config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
        )
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size
        )

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))
        )
        self.position_embedding_type = getattr(
            config, "position_embedding_type", "absolute"
        )

        self.config = config

    def forward(
        self,
        input_ids=None,
        position_ids=None,
        query_embeds=None,
        past_key_values_length=0,
    ):
        """Embed token ids and/or prepend query embeddings.

        Args:
            input_ids: optional (batch, seq_len) token ids.
            position_ids: optional explicit position ids; defaults to the
                positions following ``past_key_values_length``.
            query_embeds: optional (batch, n_query, hidden) learned queries.
            past_key_values_length: offset into the position table when a
                decoding cache is in use.

        Returns:
            (batch, n_query + seq_len, hidden) embeddings, normalized and
            dropped out.

        Raises:
            ValueError: if neither ``input_ids`` nor ``query_embeds`` is given.
        """
        # Fail fast with a clear message instead of LayerNorm(None) blowing up
        # with an opaque TypeError further down.
        if input_ids is None and query_embeds is None:
            raise ValueError("Either `input_ids` or `query_embeds` must be provided.")

        seq_length = input_ids.size()[1] if input_ids is not None else 0

        if position_ids is None:
            position_ids = self.position_ids[
                :, past_key_values_length : seq_length + past_key_values_length
            ].clone()

        if input_ids is not None:
            embeddings = self.word_embeddings(input_ids)
            if self.position_embedding_type == "absolute":
                position_embeddings = self.position_embeddings(position_ids)
                embeddings = embeddings + position_embeddings

            if query_embeds is not None:
                # Query tokens always occupy the leading positions.
                embeddings = torch.cat((query_embeds, embeddings), dim=1)
        else:
            embeddings = query_embeds

        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class BertSelfAttention(nn.Module):
    """Multi-head self- or cross-attention for the Q-Former.

    In cross-attention mode, keys/values come from the visual encoder
    (width ``config.encoder_width``) and the standard attention output is
    blended with a project-local continuous long-term memory read
    (``LongTermAttention``) using mixing weight ``config.alpha``:
    ``alpha * attention + (1 - alpha) * long_term``.
    """

    def __init__(self, config, is_cross_attention):
        super().__init__()
        self.config = config
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
            config, "embedding_size"
        ):
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.is_cross_attention = is_cross_attention
        # alpha == 1.0 disables the long-term memory path entirely.
        self.alpha = config.alpha
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        if is_cross_attention:
            # Keys/values are projected from encoder features, whose width may
            # differ from the Q-Former hidden size.
            self.key = nn.Linear(config.encoder_width, self.all_head_size)
            self.value = nn.Linear(config.encoder_width, self.all_head_size)
            # Long-term memory shares the key/value projections above.
            # NOTE(review): length/target_len are initialized to encoder_width
            # here but overwritten per batch in forward() — presumably
            # placeholders; confirm against LongTermAttention's constructor.
            long_term_attn_mechanism = partial(
                LongTermAttention,
                attn_num_basis=config.num_basis,
                head_size=self.attention_head_size,
                length=config.encoder_width,
                target_len=config.encoder_width,
                attn_func="softmax",
                infinite_memory=True,
                n_layers=2,
                attn_drop=0.1,
                n_heads=self.num_attention_heads,
                d_model=self.all_head_size,
                affines=True,
                mask=True,
                mask_type="cnn",
                kl_regularizer=False,
                sigma_0=None,
                mu_0=None,
                sticky_memories=config.sticky,
                continuous=True,
                sigmas=1,
                tau=config.tau,
                proj_key=self.key,
                proj_value=self.value,
            )
            self.long_term_attention = long_term_attn_mechanism()
        if not is_cross_attention:
            self.key = nn.Linear(config.hidden_size, self.all_head_size)
            self.value = nn.Linear(config.hidden_size, self.all_head_size)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(
            config, "position_embedding_type", "absolute"
        )
        if (
            self.position_embedding_type == "relative_key"
            or self.position_embedding_type == "relative_key_query"
        ):
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(
                2 * config.max_position_embeddings - 1, self.attention_head_size
            )
        # When True, forward() stores the attention map and registers a
        # gradient hook (used for attention visualization / Grad-CAM style use).
        self.save_attention = False

    def save_attn_gradients(self, attn_gradients):
        # Hook target: stores gradients flowing through the attention probs.
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        self.attention_map = attention_map

    def get_attention_map(self):
        return self.attention_map

    def transpose_for_scores(self, x):
        # (batch, seq, all_head_size) -> (batch, heads, seq, head_size)
        new_x_shape = x.size()[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        position_embedding_ext,
        layer,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        new_video=False,
    ):
        """Compute (cross-)attention.

        Returns ``(context_layer[, attention_probs], past_key_value)``.
        ``layer`` is the layer index forwarded to the long-term memory;
        ``new_video`` signals the memory to reset for a new input stream.
        ``position_embedding_ext`` is accepted but not used in this module.
        """
        mixed_query_layer = self.query(hidden_states)  # (batch, q_len, all_head_size)
        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = self.is_cross_attention
        if is_cross_attention:
            bsz, p, h = encoder_hidden_states.shape
            # Re-target the long-term memory to this batch's encoder length.
            self.long_term_attention.length = p
            self.long_term_attention.target_len = p
            if self.alpha != 1.0:
                # Long-term memory read; detached so no gradient flows into it.
                a_long_term = self.long_term_attention(
                    encoder_hidden_states, mixed_query_layer, new_doc=new_video, layer_n=layer
                ).detach()
            else:
                a_long_term = 0
        if is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            # Decoding with a cache: append new keys/values after the cached ones.
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)  # (batch, heads, q_len, head_size)
        past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if (
            self.position_embedding_type == "relative_key"
            or self.position_embedding_type == "relative_key_query"
        ):
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(
                seq_length, dtype=torch.long, device=hidden_states.device
            ).view(-1, 1)
            position_ids_r = torch.arange(
                seq_length, dtype=torch.long, device=hidden_states.device
            ).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(
                distance + self.max_position_embeddings - 1
            )
            positional_embedding = positional_embedding.to(
                dtype=query_layer.dtype
            )  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum(
                    "bhld,lrd->bhlr", query_layer, positional_embedding
                )
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum(
                    "bhld,lrd->bhlr", query_layer, positional_embedding
                )
                relative_position_scores_key = torch.einsum(
                    "bhrd,lrd->bhlr", key_layer, positional_embedding
                )
                attention_scores = (
                    attention_scores
                    + relative_position_scores_query
                    + relative_position_scores_key
                )
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        if is_cross_attention and self.save_attention:
            self.save_attention_map(attention_probs)
            attention_probs.register_hook(self.save_attn_gradients)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs_dropped = self.dropout(attention_probs)
        # Mask heads if we want to
        if head_mask is not None:
            attention_probs_dropped = attention_probs_dropped * head_mask

        context_layer = torch.matmul(attention_probs_dropped, value_layer)  # (batch, heads, q_len, head_size)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        if is_cross_attention:
            # Blend standard cross-attention with the (detached) long-term read.
            context_layer = self.alpha * context_layer + (1 - self.alpha) * a_long_term
        outputs = (
            (context_layer, attention_probs) if output_attentions else (context_layer,)
        )

        outputs = outputs + (past_key_value,)
        return outputs
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
class BertSelfOutput(nn.Module):
    """Post-attention projection: Linear -> dropout -> residual add -> LayerNorm."""

    def __init__(self, config):
        super().__init__()
        # Attribute name `LayerNorm` kept capitalized for checkpoint compatibility.
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Project `hidden_states`, then normalize the residual sum with `input_tensor`."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
class BertAttention(nn.Module):
    """Attention sub-layer: BertSelfAttention followed by the residual
    output projection (BertSelfOutput), with head-pruning support."""

    def __init__(self, config, is_cross_attention=False):
        super().__init__()
        self.self = BertSelfAttention(config, is_cross_attention)
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads, shrinking the q/k/v and output
        projections and recording which heads are gone."""
        if not heads:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads,
            self.self.num_attention_heads,
            self.self.attention_head_size,
            self.pruned_heads,
        )

        # Shrink each linear projection, dropping the pruned head slices.
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Book-keeping: fewer heads, smaller combined head size.
        self.self.num_attention_heads -= len(heads)
        self.self.all_head_size = (
            self.self.attention_head_size * self.self.num_attention_heads
        )
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        position_embedding_ext,
        layer,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        new_video=False,
    ):
        """Run attention, then the residual output projection.

        Returns ``(attention_output[, attention_probs...], present_key_value)``.
        """
        self_outputs = self.self(
            hidden_states,
            position_embedding_ext,
            layer,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
            new_video=new_video,
        )
        attention_output = self.output(self_outputs[0], hidden_states)
        # Re-attach attention maps / cache entries after the projected output.
        return (attention_output,) + self_outputs[1:]
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
class BertIntermediate(nn.Module):
    """Feed-forward expansion: Linear to `intermediate_size`, then activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # `hidden_act` may name an activation (looked up in ACT2FN) or be a
        # callable supplied directly on the config.
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states):
        """Expand to the intermediate size and apply the activation."""
        return self.intermediate_act_fn(self.dense(hidden_states))
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
class BertOutput(nn.Module):
    """Feed-forward contraction: project back down, dropout, residual LayerNorm."""

    def __init__(self, config):
        super().__init__()
        # Attribute name `LayerNorm` kept capitalized for checkpoint compatibility.
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Contract `hidden_states`, then normalize the residual sum with `input_tensor`."""
        contracted = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(contracted + input_tensor)
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
class BertLayer(nn.Module):
    """One Q-Former transformer layer.

    Runs self-attention over the full sequence, optionally cross-attention
    (only on layers where ``layer_num % cross_attention_freq == 0``) over the
    leading query tokens, and applies separate feed-forward weights to query
    tokens (``*_query`` modules) vs. text tokens.
    """

    def __init__(self, config, layer_num):
        super().__init__()
        self.config = config
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        # Feed-forward chunking splits along the sequence dimension.
        self.seq_len_dim = 1
        self.attention = BertAttention(config)
        self.layer_num = layer_num
        if (
            self.config.add_cross_attention
            and layer_num % self.config.cross_attention_freq == 0
        ):
            self.crossattention = BertAttention(
                config, is_cross_attention=self.config.add_cross_attention
            )
            self.has_cross_attention = True
        else:
            self.has_cross_attention = False
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

        # Dedicated feed-forward weights for the learned query tokens.
        self.intermediate_query = BertIntermediate(config)
        self.output_query = BertOutput(config)

    def forward(
        self,
        hidden_states,
        position_embedding_ext,
        layer,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        query_length=0,
        new_video=False,
    ):
        """Apply self-attention, optional cross-attention, and the FFN.

        The first ``query_length`` positions of ``hidden_states`` are treated
        as query tokens; the remainder (if any) as text tokens. Returns
        ``(layer_output, [attentions...], present_key_value)``.
        """
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = (
            past_key_value[:2] if past_key_value is not None else None
        )
        self_attention_outputs = self.attention(
            hidden_states,
            position_embedding_ext,
            layer,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
            new_video=new_video,
        )
        attention_output = self_attention_outputs[0]
        # Middle entries are attention maps (if requested); last is the cache.
        outputs = self_attention_outputs[1:-1]

        present_key_value = self_attention_outputs[-1]

        if query_length > 0:
            # Leading positions are the learned query tokens.
            query_attention_output = attention_output[:, :query_length, :]

            if self.has_cross_attention:
                assert (
                    encoder_hidden_states is not None
                ), "encoder_hidden_states must be given for cross-attention layers"
                cross_attention_outputs = self.crossattention(
                    query_attention_output,
                    position_embedding_ext,
                    layer,
                    attention_mask,
                    head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    output_attentions=output_attentions,
                    new_video=new_video,
                )
                query_attention_output = cross_attention_outputs[0]
                outputs = (
                    outputs + cross_attention_outputs[1:-1]
                )  # add cross attentions if we output attention weights

            # Query tokens go through their own FFN weights.
            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk_query,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                query_attention_output,
            )
            if attention_output.shape[1] > query_length:
                # Remaining positions are text tokens; use the shared FFN,
                # then re-join query and text halves along the sequence dim.
                layer_output_text = apply_chunking_to_forward(
                    self.feed_forward_chunk,
                    self.chunk_size_feed_forward,
                    self.seq_len_dim,
                    attention_output[:, query_length:, :],
                )
                layer_output = torch.cat([layer_output, layer_output_text], dim=1)
        else:
            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                attention_output,
            )
        outputs = (layer_output,) + outputs

        outputs = outputs + (present_key_value,)

        return outputs

    def feed_forward_chunk(self, attention_output):
        # FFN for text positions: expand -> contract (+ residual, LayerNorm).
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output

    def feed_forward_chunk_query(self, attention_output):
        # FFN for query positions, using the query-specific weights.
        intermediate_output = self.intermediate_query(attention_output)
        layer_output = self.output_query(intermediate_output, attention_output)
        return layer_output
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
class BertEncoder(nn.Module):
    """Stack of ``BertLayer`` modules forming the Q-Former trunk.

    Each layer receives the external position embedding and its own layer
    index (cross-attention layers are interleaved per the config), and
    gradient checkpointing can be enabled via ``config.gradient_checkpointing``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(
            [BertLayer(config, i) for i in range(config.num_hidden_layers)]
        )

    def forward(
        self,
        hidden_states,
        position_embedding_ext,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        query_length=0,
        new_video=False,
    ):
        """Run every transformer layer over ``hidden_states``.

        Returns a tuple of the non-``None`` collected tensors when
        ``return_dict`` is False, otherwise a
        ``BaseModelOutputWithPastAndCrossAttentions``.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = (
            () if output_attentions and self.config.add_cross_attention else None
        )

        next_decoder_cache = () if use_cache else None

        for i in range(self.config.num_hidden_layers):
            layer_module = self.layer[i]
            if output_hidden_states:
                # Record the input to this layer (so index 0 is the embedding output).
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if getattr(self.config, "gradient_checkpointing", False) and self.training:

                if use_cache:
                    # Caching is incompatible with re-running the forward pass
                    # during checkpoint recomputation.
                    # FIX: logger.warn is deprecated; use logger.warning.
                    logger.warning(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                def create_custom_forward(module):
                    # Bind the non-tensor arguments into a closure; checkpoint
                    # only forwards the positional tensor inputs reliably.
                    def custom_forward(*inputs):
                        return module(
                            *inputs, past_key_value, output_attentions, query_length
                        )

                    return custom_forward

                # NOTE(review): passing the extra `new_video` keyword through
                # torch.utils.checkpoint.checkpoint requires a torch version
                # that forwards **kwargs to the wrapped function — confirm
                # against the pinned torch release.
                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    position_embedding_ext,
                    i,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    new_video=new_video
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    position_embedding_ext,
                    i,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    query_length,
                    new_video=new_video,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                # The layer returns its present key/value cache as the last element.
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
class BertPooler(nn.Module):
    """Summarize a sequence by projecting its first token through a
    tanh-activated linear layer."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # The first token's hidden state serves as the pooled representation.
        cls_state = hidden_states[:, 0]
        return self.activation(self.dense(cls_state))
class BertPredictionHeadTransform(nn.Module):
    """Dense -> activation -> LayerNorm transform applied before the LM
    decoder projection."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # `hidden_act` may be either a registry key (string) or a callable.
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        projected = self.dense(hidden_states)
        activated = self.transform_act_fn(projected)
        return self.LayerNorm(activated)
class BertLMPredictionHead(nn.Module):
    """Language-modeling head: feature transform followed by a projection to
    vocabulary logits."""

    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The decoder weights are tied to the input embeddings elsewhere;
        # only an output-only per-token bias is owned here.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Link the two so the bias is resized together with the decoder by
        # `resize_token_embeddings`.
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        return self.decoder(self.transform(hidden_states))
class BertOnlyMLMHead(nn.Module):
    """Thin wrapper exposing only the masked-LM prediction head."""

    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)

    def forward(self, sequence_output):
        return self.predictions(sequence_output)
class BertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = BertConfig
    base_model_prefix = "bert"
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _init_weights(self, module):
        """Initialize the weights of a single submodule."""
        if isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.0)
            module.bias.data.zero_()
        elif isinstance(module, (nn.Linear, nn.Embedding)):
            # Plain normal init; the TF original uses truncated_normal
            # (cf. https://github.com/pytorch/pytorch/pull/5617).
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
class BertModel(BertPreTrainedModel):
    """
    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in `Attention is
    all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
    argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
    input to the forward pass.
    """

    def __init__(self, config, add_pooling_layer=False):
        super().__init__(config)
        self.config = config

        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        # Pooler is optional; the Q-Former typically runs without it.
        self.pooler = BertPooler(config) if add_pooling_layer else None

        self.init_weights()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer_idx, heads in heads_to_prune.items():
            self.encoder.layer[layer_idx].attention.prune_heads(heads)

    def get_extended_attention_mask(
        self,
        attention_mask: Tensor,
        input_shape: Tuple[int],
        device: device,
        is_decoder: bool,
        has_query: bool = False,
    ) -> Tensor:
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

        Arguments:
            attention_mask (:obj:`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
            input_shape (:obj:`Tuple[int]`):
                The shape of the input to the model.
            device: (:obj:`torch.device`):
                The device of the input to the model.
            is_decoder (:obj:`bool`):
                Whether to apply a causal mask on top of the padding mask.
            has_query (:obj:`bool`, `optional`):
                When True, prepended query tokens get a UniLM-style mask (queries
                cannot be attended to by the text prefix rows that are added).

        Returns:
            :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
        """
        if attention_mask.dim() == 3:
            # Caller already supplied [batch, from_seq, to_seq]; just add a head axis.
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Padding mask [batch, seq]: broadcast, adding a causal mask for decoders.
            if is_decoder:
                batch_size, seq_length = input_shape

                seq_ids = torch.arange(seq_length, device=device)
                causal_mask = (
                    seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
                    <= seq_ids[None, :, None]
                )

                # causal and attention masks must share a dtype (pytorch < 1.3).
                causal_mask = causal_mask.to(attention_mask.dtype)

                if causal_mask.shape[1] < attention_mask.shape[1]:
                    # The attention mask covers extra prefix tokens (e.g. queries).
                    prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
                    if has_query:  # UniLM style attention mask
                        causal_mask = torch.cat(
                            [
                                torch.zeros(
                                    (batch_size, prefix_seq_len, seq_length),
                                    device=device,
                                    dtype=causal_mask.dtype,
                                ),
                                causal_mask,
                            ],
                            axis=1,
                        )
                    causal_mask = torch.cat(
                        [
                            torch.ones(
                                (batch_size, causal_mask.shape[1], prefix_seq_len),
                                device=device,
                                dtype=causal_mask.dtype,
                            ),
                            causal_mask,
                        ],
                        axis=-1,
                    )
                extended_attention_mask = (
                    causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
                )
            else:
                extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
                    input_shape, attention_mask.shape
                )
            )

        # 1.0 -> attend, 0.0 -> masked; converted to additive 0 / -10000 so it can
        # be summed onto the raw attention scores before the softmax.
        extended_attention_mask = extended_attention_mask.to(
            dtype=self.dtype
        )  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask

    def forward(
        self,
        input_ids=None,
        position_embedding_ext=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        query_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        is_decoder=False,
        new_video=False,
    ):
        r"""
        encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
            decoding (see :obj:`past_key_values`).
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # use_cache = use_cache if use_cache is not None else self.config.use_cache

        if input_ids is None:
            assert (
                query_embeds is not None
            ), "You have to specify query_embeds when input_ids is None"

        # Cached length excludes the (always re-supplied) query tokens.
        past_key_values_length = (
            past_key_values[0][0].shape[2] - self.config.query_length
            if past_key_values is not None
            else 0
        )

        query_length = query_embeds.shape[1] if query_embeds is not None else 0

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            query_embeds=query_embeds,
            past_key_values_length=past_key_values_length,
        )

        input_shape = embedding_output.size()[:-1]
        batch_size, seq_length = input_shape
        device = embedding_output.device

        if attention_mask is None:
            attention_mask = torch.ones(
                ((batch_size, seq_length + past_key_values_length)), device=device
            )

        # Make the self-attention mask broadcastable to all heads; decoders
        # also get a causal component (UniLM-style when queries are present).
        if is_decoder:
            extended_attention_mask = self.get_extended_attention_mask(
                attention_mask,
                input_ids.shape,
                device,
                is_decoder,
                has_query=(query_embeds is not None),
            )
        else:
            extended_attention_mask = self.get_extended_attention_mask(
                attention_mask, input_shape, device, is_decoder
            )

        # Build the (possibly per-layer list of) cross-attention masks.
        if encoder_hidden_states is not None:
            if type(encoder_hidden_states) == list:
                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[
                    0
                ].size()
            else:
                (
                    encoder_batch_size,
                    encoder_sequence_length,
                    _,
                ) = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)

            if type(encoder_attention_mask) == list:
                encoder_extended_attention_mask = [
                    self.invert_attention_mask(mask) for mask in encoder_attention_mask
                ]
            elif encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
                encoder_extended_attention_mask = self.invert_attention_mask(
                    encoder_attention_mask
                )
            else:
                encoder_extended_attention_mask = self.invert_attention_mask(
                    encoder_attention_mask
                )
        else:
            encoder_extended_attention_mask = None

        # head_mask: [num_heads] or [num_hidden_layers x num_heads] ->
        # [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        encoder_outputs = self.encoder(
            embedding_output,
            position_embedding_ext,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            query_length=query_length,
            new_video=new_video,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = (
            self.pooler(sequence_output) if self.pooler is not None else None
        )

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )
class BertLMHeadModel(BertPreTrainedModel):
    """BertModel with a causal language-modeling head on top (next-token
    prediction over the text portion that follows the query embeddings)."""

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        query_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        past_key_values=None,
        use_cache=True,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        return_logits=False,
        is_decoder=True,
        reduction="mean",
    ):
        r"""
        encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
            ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
            ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
            decoding (see :obj:`past_key_values`).
        Returns:
        Example::
            >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
            >>> import torch
            >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
            >>> config = BertConfig.from_pretrained("bert-base-cased")
            >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
            >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
            >>> outputs = model(**inputs)
            >>> prediction_logits = outputs.logits
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        # Training (labels supplied) never caches; incremental decoding with a
        # cache re-supplies only the new token, so the queries are dropped.
        if labels is not None:
            use_cache = False
        if past_key_values is not None:
            query_embeds = None

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            query_embeds=query_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            is_decoder=is_decoder,
        )

        sequence_output = outputs[0]
        if query_embeds is not None:
            # Score only the text positions that follow the query embeddings.
            sequence_output = outputs[0][:, query_embeds.shape[1] :, :]

        prediction_scores = self.cls(sequence_output)

        if return_logits:
            return prediction_scores[:, :-1, :].contiguous()

        lm_loss = None
        if labels is not None:
            # Next-token prediction: shift scores left and labels right by one.
            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
            labels = labels[:, 1:].contiguous()
            loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
            lm_loss = loss_fct(
                shifted_prediction_scores.view(-1, self.config.vocab_size),
                labels.view(-1),
            )
            if reduction == "none":
                # Per-sample loss: sum the per-token losses of each sequence.
                lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((lm_loss,) + output) if lm_loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=lm_loss,
            logits=prediction_scores,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs
    ):
        # The decoder attention mask is created on the fly when the model is
        # used as a decoder inside an encoder-decoder setup.
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_ids.shape)
        # Queries are always attended to, so prepend an all-ones mask for them.
        query_mask = input_ids.new_ones(query_embeds.shape[:-1])
        attention_mask = torch.cat([query_mask, attention_mask], dim=-1)

        # With a cache, only the newest token needs to be fed.
        if past is not None:
            input_ids = input_ids[:, -1:]

        return {
            "input_ids": input_ids,
            "query_embeds": query_embeds,
            "attention_mask": attention_mask,
            "past_key_values": past,
            "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
            "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
            "is_decoder": True,
        }

    def _reorder_cache(self, past, beam_idx):
        """Reorder the cached key/value states to follow beam-search hops."""
        reordered_past = ()
        for layer_past in past:
            reordered_past += (
                tuple(
                    past_state.index_select(0, beam_idx) for past_state in layer_past
                ),
            )
        return reordered_past
class BertForMaskedLM(BertPreTrainedModel):
    """BertModel with a masked language-modeling head on top."""

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        query_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        return_logits=False,
        is_decoder=False,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
            config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
            (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
        """

        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            query_embeds=query_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            is_decoder=is_decoder,
        )

        # FIX: previously `sequence_output` was only assigned inside the
        # `query_embeds is not None` branch, so calling without query_embeds
        # raised NameError. Default to the full hidden states, then drop the
        # query positions when queries were prepended.
        sequence_output = outputs[0]
        if query_embeds is not None:
            sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
        prediction_scores = self.cls(sequence_output)

        if return_logits:
            return prediction_scores

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()  # -100 index = padding token
            masked_lm_loss = loss_fct(
                prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
            )

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return (
                ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
            )

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
__init__.py
ADDED
|
File without changes
|
blip2.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Copyright (c) 2023, salesforce.com, inc.
|
| 3 |
+
All rights reserved.
|
| 4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
| 6 |
+
"""
|
| 7 |
+
import contextlib
|
| 8 |
+
import os
|
| 9 |
+
import logging
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
|
| 14 |
+
from .Qformer import BertConfig, BertLMHeadModel
|
| 15 |
+
from .vit import build_vit
|
| 16 |
+
from transformers import BertTokenizer
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
from transformers import PreTrainedModel, PretrainedConfig, AutoConfig
|
| 21 |
+
# class Blip2Base(nn.Module):
|
| 22 |
+
class Blip2Base(PreTrainedModel):
    """Shared base for BLIP-2 style models.

    Provides the common building blocks: a BERT tokenizer with a [DEC]
    decoder BOS token, the Q-Former (a BERT LM head model with
    cross-attention to visual features), and the UMT vision encoder,
    plus device/autocast helpers.
    """

    def __init__(self, config=None):
        # NOTE(fix): the original signature used a mutable default argument
        # (config={}); a None sentinel avoids sharing one dict across calls.
        if config is None:
            config = {}
        cfg = PretrainedConfig()
        # Accept either a HF config object or any plain mapping.
        if isinstance(config, (PretrainedConfig, AutoConfig)):
            cfg.update(config.to_dict())
        else:
            cfg.update(dict(config))
        super().__init__(cfg)

    @classmethod
    def init_tokenizer(cls, truncation_side="right"):
        """Build the bert-base-uncased tokenizer with an added [DEC] BOS token.

        Args:
            truncation_side (str): which side the tokenizer truncates from.

        Returns:
            BertTokenizer: the configured tokenizer.
        """
        tokenizer = BertTokenizer.from_pretrained(
            "bert-base-uncased", truncation_side=truncation_side, local_files_only=True
        )
        tokenizer.add_special_tokens({"bos_token": "[DEC]"})
        return tokenizer

    @property
    def device(self):
        # Device of the first registered parameter.
        # NOTE(fix): next() avoids materializing the full parameter list
        # as the original list(self.parameters())[0] did.
        return next(self.parameters()).device

    def maybe_autocast(self, dtype=torch.float16):
        """Return an autocast context on GPU, or a null context on CPU.

        Args:
            dtype: autocast compute dtype (defaults to fp16).
        """
        # Autocast is only meaningful on CUDA; on CPU run at full precision.
        enable_autocast = self.device != torch.device("cpu")
        if enable_autocast:
            return torch.cuda.amp.autocast(dtype=dtype)
        else:
            return contextlib.nullcontext()

    @classmethod
    def init_Qformer(
        cls,
        num_query_token, vision_width,
        qformer_hidden_dropout_prob=0.1,
        qformer_attention_probs_dropout_prob=0.1,
        qformer_drop_path_rate=0.,
    ):
        """Build the Q-Former and its learnable query tokens.

        Args:
            num_query_token (int): number of learnable query tokens.
            vision_width (int): channel dim of visual features fed to the
                cross-attention layers.
            qformer_hidden_dropout_prob (float): hidden-state dropout.
            qformer_attention_probs_dropout_prob (float): attention dropout.
            qformer_drop_path_rate (float): maximum stochastic-depth rate;
                a linear per-layer schedule is derived from it.

        Returns:
            tuple: (Qformer, query_tokens) where ``query_tokens`` has shape
            (1, num_query_token, hidden_size).
        """
        encoder_config = BertConfig.from_pretrained("bert-base-uncased", local_files_only=True)
        encoder_config.encoder_width = vision_width
        # Insert a cross-attention layer every other block.
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = 2
        encoder_config.query_length = num_query_token
        encoder_config.hidden_dropout_prob = qformer_hidden_dropout_prob
        encoder_config.attention_probs_dropout_prob = qformer_attention_probs_dropout_prob
        # Linearly increasing drop-path schedule across the encoder layers.
        encoder_config.drop_path_list = [
            x.item()
            for x in torch.linspace(0, qformer_drop_path_rate, encoder_config.num_hidden_layers)
        ]
        logger.info(f"Drop_path:{encoder_config.drop_path_list}")
        logger.info(encoder_config)
        Qformer = BertLMHeadModel(config=encoder_config)
        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size)
        )
        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
        return Qformer, query_tokens

    @classmethod
    def init_vision_encoder_umt(cls, config):
        """Build the UMT vision encoder.

        NOTE(fix): the first parameter of this classmethod was named
        ``self`` in the original; renamed to ``cls`` (internal only).

        Returns:
            tuple: (vision_encoder, vision_layernorm). Each is a ``nn.Module``;
            the layernorm is ``nn.Identity`` when ``vit_add_ln`` is False.
        """
        vision_encoder = build_vit(config)

        if config.vision_encoder.vit_add_ln:
            vision_layernorm = nn.LayerNorm(
                config.vision_encoder.encoder_embed_dim, eps=1e-12
            )
        else:
            vision_layernorm = nn.Identity()

        return vision_encoder, vision_layernorm
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def disabled_train(self, mode=True):
    """Replacement for ``nn.Module.train`` that freezes train/eval mode.

    Assign this function as a module's ``train`` method so that subsequent
    ``.train()`` / ``.eval()`` calls become no-ops and the module's mode
    never changes. The ``mode`` argument is accepted for signature
    compatibility and ignored.
    """
    return self
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class LayerNorm(nn.LayerNorm):
    """LayerNorm that computes in float32 and casts back (fp16-safe)."""

    def forward(self, x: torch.Tensor):
        # Normalize in full precision to avoid fp16 numerical issues,
        # then restore the caller's original dtype.
        input_dtype = x.dtype
        normalized = super().forward(x.to(torch.float32))
        return normalized.to(input_dtype)
|
config.json
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_second_msg": true,
|
| 3 |
+
"architectures": [
|
| 4 |
+
"VideoChat2_it_hd_mistral"
|
| 5 |
+
],
|
| 6 |
+
"auto_map": {
|
| 7 |
+
"AutoConfig": "configuration_videochat2.Config",
|
| 8 |
+
"AutoModel": "videochat2_it_hd_mistral.VideoChat2_it_hd_mistral"
|
| 9 |
+
},
|
| 10 |
+
"dynamic_config": {
|
| 11 |
+
"add_global": true,
|
| 12 |
+
"hd_num": 6,
|
| 13 |
+
"local_size": 224,
|
| 14 |
+
"padding": false
|
| 15 |
+
},
|
| 16 |
+
"end_token": "</Video>",
|
| 17 |
+
"extra_num_query_token": 64,
|
| 18 |
+
"freeze_qformer": false,
|
| 19 |
+
"freeze_vit": false,
|
| 20 |
+
"img_end_token": "</Image>",
|
| 21 |
+
"img_start_token": "<Image>",
|
| 22 |
+
"lora_alpha": 32,
|
| 23 |
+
"lora_dropout": 0.1,
|
| 24 |
+
"lora_r": 16,
|
| 25 |
+
"low_resource": false,
|
| 26 |
+
"max_txt_len": 512,
|
| 27 |
+
"mistral_model_path": "mistralai/Mistral-7B-Instruct-v0.2",
|
| 28 |
+
"model_cls": "VideoChat2_it_hd_mistral",
|
| 29 |
+
"num_query_token": 32,
|
| 30 |
+
"qformer_attention_probs_dropout_prob": 0.1,
|
| 31 |
+
"qformer_drop_path_rate": 0.2,
|
| 32 |
+
"qformer_hidden_dropout_prob": 0.1,
|
| 33 |
+
"qformer_text_input": true,
|
| 34 |
+
"random_shuffle": true,
|
| 35 |
+
"return_question_instruction": false,
|
| 36 |
+
"start_token": "<Video>",
|
| 37 |
+
"system": "",
|
| 38 |
+
"torch_dtype": "float32",
|
| 39 |
+
"transformers_version": "4.44.2",
|
| 40 |
+
"use_flash_attention": false,
|
| 41 |
+
"use_lora": false,
|
| 42 |
+
"videochat2_model_path": "",
|
| 43 |
+
"vision_encoder": {
|
| 44 |
+
"checkpoint_num": 18,
|
| 45 |
+
"ckpt_num_frame": 4,
|
| 46 |
+
"d_model": 1024,
|
| 47 |
+
"drop_path_rate": 0.0,
|
| 48 |
+
"encoder_depth": 24,
|
| 49 |
+
"encoder_embed_dim": 1024,
|
| 50 |
+
"encoder_num_heads": 16,
|
| 51 |
+
"img_size": 224,
|
| 52 |
+
"name": "vit_l14",
|
| 53 |
+
"num_frames": 4,
|
| 54 |
+
"patch_size": 16,
|
| 55 |
+
"pretrained": "",
|
| 56 |
+
"return_index": -2,
|
| 57 |
+
"tubelet_size": 1,
|
| 58 |
+
"use_checkpoint": true,
|
| 59 |
+
"vit_add_ln": true
|
| 60 |
+
},
|
| 61 |
+
"vit_blip_model_path": ""
|
| 62 |
+
}
|
configuration_videochat2.py
ADDED
|
@@ -0,0 +1,453 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import ast
|
| 5 |
+
import json
|
| 6 |
+
import os
|
| 7 |
+
import os.path as osp
|
| 8 |
+
import re
|
| 9 |
+
import shutil
|
| 10 |
+
import sys
|
| 11 |
+
import tempfile
|
| 12 |
+
from copy import deepcopy
|
| 13 |
+
from importlib import import_module
|
| 14 |
+
|
| 15 |
+
import yaml
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
__all__ = ["Config", "pretty_text"]
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
BASE_KEY = "_base_"
|
| 22 |
+
# BASE_CONFIG = {"OUTPUT_DIR": "./workspace", "SESSION": "base", "LOG_FILE": "log.txt"}
|
| 23 |
+
BASE_CONFIG = {}
|
| 24 |
+
|
| 25 |
+
cfg = None
|
| 26 |
+
|
| 27 |
+
class EasyDict(dict):
    """A ``dict`` whose items are also accessible (and settable) as attributes.

    Nested dicts — including dicts inside lists/tuples — are converted to
    ``EasyDict`` recursively, both at construction time and on assignment,
    so chained attribute access works throughout. ``update`` and ``pop``
    keep attributes and items in sync.

    >>> d = EasyDict({'foo': 3, 'bar': {'x': 1, 'y': 2}})
    >>> d.foo
    3
    >>> d.bar.x
    1
    >>> d.baz = {'prop': 'value'}
    >>> d.baz.prop
    'value'
    >>> d.pop('foo')
    3
    """

    def __init__(self, d=None, **kwargs):
        if d is None:
            d = {}
        else:
            # NOTE(fix): copy the input first — the original merged **kwargs
            # into the caller's dict in place, mutating the argument.
            d = dict(d)
        if kwargs:
            d.update(**kwargs)
        for k, v in d.items():
            setattr(self, k, v)
        # Promote plain class attributes (defined on subclasses) to instance
        # items so they behave like regular dict entries.
        for k in self.__class__.__dict__.keys():
            if not (k.startswith("__") and k.endswith("__")) and k not in ("update", "pop"):
                setattr(self, k, getattr(self, k))

    def __setattr__(self, name, value):
        # Recursively wrap dict values (and dicts inside sequences) so nested
        # structures also support attribute access.
        if isinstance(value, (list, tuple)):
            value = [self.__class__(x) if isinstance(x, dict) else x for x in value]
        elif isinstance(value, dict) and not isinstance(value, self.__class__):
            value = self.__class__(value)
        # Keep the attribute namespace and the dict items in lockstep.
        super(EasyDict, self).__setattr__(name, value)
        super(EasyDict, self).__setitem__(name, value)

    __setitem__ = __setattr__

    def update(self, e=None, **f):
        """Update items/attributes. Does not mutate the argument ``e``."""
        # NOTE(fix): the original wrote **f into the caller's mapping;
        # merge into a private copy instead.
        d = dict(e) if e else dict()
        d.update(f)
        for k in d:
            setattr(self, k, d[k])

    def pop(self, k, d=None):
        """Remove key ``k`` from both attributes and items; return its value or ``d``."""
        if hasattr(self, k):
            delattr(self, k)
        return super(EasyDict, self).pop(k, d)
|
| 171 |
+
|
| 172 |
+
from transformers import PretrainedConfig
|
| 173 |
+
class Config(PretrainedConfig):
    """Hybrid HF/EasyDict configuration container.

    Wraps a ``transformers.PretrainedConfig`` while also exposing the raw
    keyword arguments as an :class:`EasyDict` under ``self.cfg``. The
    classmethods provide config-file loading (.py/.yaml/.json with
    ``_base_`` inheritance), pretty printing, JSON dumping and
    commandline-style overrides.
    """
    _auto_class = "AutoConfig"  # lets AutoConfig.from_pretrained resolve this class
    """config"""
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Keep an attribute-accessible copy of the raw kwargs.
        self.cfg=EasyDict(kwargs)

    @classmethod
    def pretty_text(cls, cfg: dict, indent=2) -> str:
        """Format a (possibly nested) dict as an indented, brace-delimited string.

        Args:
            cfg (dict): the params to render.
            indent (int): leading spaces for the current nesting level.

        Returns:
            str: human-readable rendering of ``cfg``.
        """
        msg = "{\n"
        for i, (k, v) in enumerate(cfg.items()):
            if isinstance(v, dict):
                # Render nested dicts with a deeper indent.
                v = cls.pretty_text(v, indent + 4)
            spaces = " " * indent
            msg += spaces + "{}: {}".format(k, v)
            if i == len(cfg) - 1:
                msg += " }"
            else:
                msg += "\n"
        return msg

    @classmethod
    def dump(cls, cfg, savepath=None):
        """Dump ``cfg`` to a JSON file.

        Args:
            cfg (dict): the dict to dump (must be JSON-serializable and,
                when ``savepath`` is None, expose ``cfg.WORKSPACE``).
            savepath (str): destination path; defaults to
                ``<cfg.WORKSPACE>/config.json``.
        """
        if savepath is None:
            savepath = osp.join(cfg.WORKSPACE, "config.json")
        json.dump(cfg, open(savepath, "w"), indent=2)

    @classmethod
    def get_config(cls, default_config: dict = None, config_file: str=''):
        """Build (and memoize in the module-global ``cfg``) an EasyDict config.

        Precedence: ``BASE_CONFIG`` < config file < ``--opts`` overrides.
        The result is cached: the first call wins, later calls return it.

        Args:
            default_config (dict): currently unused (see commented-out merge).
            config_file (str): explicit config path; overrides ``--config_file``.

        Returns: an EasyDict.
        """
        global cfg
        # Memoized: return the previously built config if any.
        if cfg is not None:
            return cfg

        # Define the arg parser.
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "--config_file", default='your config file', help="the configuration file to load. support: .yaml, .json, .py"
        )
        parser.add_argument(
            "--opts",
            default=None,
            nargs="*",
            help="overrided configs. List. Format: 'key1 name1 key2 name2'",
        )
        # parse_known_args so stray argv (e.g. inside Jupyter) doesn't crash.
        args = parser.parse_known_args()[0]
        # NOTE(review): hardcoded machine-specific default path; the
        # `config_file` argument below overrides it when provided.
        args.config_file="/mnt/petrelfs/shiyansong/WEIGHT/UMT/l16_25m.py"

        if config_file:
            args.config_file=config_file

        cfg = EasyDict(BASE_CONFIG)
        if osp.isfile(args.config_file):
            cfg_from_file = cls.from_file(args.config_file)
            cfg = merge_a_into_b(cfg_from_file, cfg)
        if args.opts:
            cfg = cls.merge_list(cfg, args.opts)
        # Convert string leaves ('0.2', '[1,2]', 'eval(...)', '${ref}') to values.
        cfg = eval_dict_leaf(cfg)

        # Re-insert BASE_CONFIG keys so they appear last in iteration order.
        for k in BASE_CONFIG:
            cfg[k] = cfg.pop(k)
        return cfg

    @classmethod
    def from_file(cls, filepath: str) -> EasyDict:
        """Build a config dict from a file. Supported: ``.py``, ``.yaml``/``.yml``, ``.json``.

        A ``_base_`` key (``BASE_KEY``) may name one or more parent config
        files, which are loaded recursively and merged underneath this one.

        Args:
            filepath (str): the config file path.

        Returns:
            EasyDict: the loaded (and base-merged) config.

        Raises:
            IOError: if the file does not exist or has an unsupported suffix.
            KeyError: if two base configs define the same key.
        """
        filepath = osp.abspath(osp.expanduser(filepath))
        if not osp.isfile(filepath):
            raise IOError(f"File does not exist: {filepath}")
        if filepath.endswith(".py"):
            # Import the .py config as a module and take its public globals.
            # NOTE(review): the inserted sys.path entry is never removed.
            sys.path.insert(0, osp.dirname(filepath))
            mod = import_module(osp.splitext(osp.basename(filepath))[0])
            cfg_dict = {
                name: value
                for name, value in mod.__dict__.items()
                if not name.startswith("__")
            }
        elif filepath.endswith((".yml", ".yaml")):
            cfg_dict = yaml.load(open(filepath, "r"), Loader=yaml.Loader)
        elif filepath.endswith(".json"):
            cfg_dict = json.load(open(filepath, "r"))
        else:
            raise IOError("Only py/yml/yaml/json type are supported now!")

        # NOTE(review): cfg_text is built but never used or returned.
        cfg_text = filepath + "\n"
        with open(filepath, "r") as f:
            cfg_text += f.read()

        if BASE_KEY in cfg_dict:  # load parent configs listed under `_base_`
            cfg_dir = osp.dirname(filepath)
            base_filename = cfg_dict.pop(BASE_KEY)
            base_filename = (
                base_filename if isinstance(base_filename, list) else [base_filename]
            )

            cfg_dict_list = list()
            for f in base_filename:
                _cfg_dict = Config.from_file(osp.join(cfg_dir, f))
                cfg_dict_list.append(_cfg_dict)

            base_cfg_dict = dict()
            for c in cfg_dict_list:
                # Bases must be disjoint; the child may override them though.
                if len(base_cfg_dict.keys() & c.keys()) > 0:
                    raise KeyError("Duplicate key is not allowed among bases")
                base_cfg_dict.update(c)

            # Child values take precedence over the merged bases.
            cfg_dict = merge_a_into_b(cfg_dict, base_cfg_dict)

        return EasyDict(cfg_dict)

    @classmethod
    def merge_list(cls, cfg, opts: list):
        """Merge commandline override pairs into ``cfg`` in place.

        Args:
            cfg (dict): the config to be merged.
            opts (list): flat list [key1, value1, key2, value2, ...]. Keys
                may be dotted ("a.b") to address nested dicts.

        Returns:
            dict: ``cfg`` with the overrides applied.

        Raises:
            ValueError: if any (sub)key does not already exist in ``cfg``.
        """
        assert len(opts) % 2 == 0, f"length of opts must be even. Got: {opts}"
        for i in range(0, len(opts), 2):
            full_k, v = opts[i], opts[i + 1]
            keys = full_k.split(".")
            sub_d = cfg
            # NOTE(review): the inner loop variable shadows the outer `i`;
            # harmless here because the outer loop reassigns it, but fragile.
            for i, k in enumerate(keys):
                if not hasattr(sub_d, k):
                    raise ValueError(f"The key {k} not exist in the config. Full key:{full_k}")
                if i != len(keys) - 1:
                    sub_d = sub_d[k]
                else:
                    sub_d[k] = v
        return cfg
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
def merge_a_into_b(a, b, inplace=False):
    """Recursively merge dict ``a`` into dict ``b``; values in ``a`` win.

    Args:
        a (dict): source dict whose values take precedence.
        b (dict): target dict receiving the merge.
        inplace (bool): when False (default) ``b`` is deep-copied first,
            leaving the caller's dict untouched.

    Returns:
        dict: ``b`` (or its copy) with ``a`` merged in.
    """
    if not inplace:
        b = deepcopy(b)
    for key, val in a.items():
        if key in b and isinstance(val, dict) and isinstance(b[key], dict):
            # Both sides are dicts: merge recursively. The top-level copy
            # already protected the caller, so recursion works in place.
            b[key] = merge_a_into_b(val, b[key], inplace=True)
        else:
            # New key, or a non-dict value on either side: overwrite.
            b[key] = val
    return b
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
def eval_dict_leaf(d, orig_dict=None):
    """Evaluate every non-dict leaf of ``d`` in place via ``eval_string``.

    Args:
        d (dict): dict whose leaf values may be strings to evaluate.
        orig_dict (dict): root dict used to resolve ``${...}`` references;
            defaults to ``d`` itself at the top-level call.

    Returns:
        dict: ``d``, with leaves converted to their evaluated types.
    """
    root = d if orig_dict is None else orig_dict
    for key, val in d.items():
        if isinstance(val, dict):
            # Recurse into nested dicts, keeping the same root for ${} refs.
            eval_dict_leaf(val, root)
        else:
            d[key] = eval_string(val, root)
    return d
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
def eval_string(string, d):
    """Best-effort conversion of a config string to a Python value.

    Supported forms:
        not a string   -> returned unchanged
        '0'            -> 0
        '0.2'          -> 0.2
        '[0, 1, 2]'    -> [0, 1, 2]
        'eval(1+2)'    -> 3
        '${a}'         -> d.a  (reference resolved against ``d``)
        anything else  -> the original string

    SECURITY NOTE: the 'eval(...)' and '${...}' forms pass config text to
    :func:`eval`; only load config files from trusted sources.

    Args:
        string: the value to evaluate (non-str inputs pass through).
        d: dict/EasyDict used to resolve ``${key}`` references via
           attribute access.

    Returns:
        The evaluated value, or the original input when nothing applies.
    """
    if not isinstance(string, str):
        return string
    if string[0:5] == "eval(":
        # Explicit escape hatch: evaluate the inner expression verbatim.
        return eval(string[5:-1])

    s0 = string
    s1 = re.sub(r"\${(.*)}", r"d.\1", s0)
    if s1 != s0:
        # Substitute ${...} references to fixpoint, then evaluate the
        # resulting attribute-access chain against `d`.
        while s1 != s0:
            s0 = s1
            s1 = re.sub(r"\${(.*)}", r"d.\1", s0)
        return eval(s1)

    try:
        v = ast.literal_eval(string)
    except (ValueError, SyntaxError, TypeError):
        # NOTE(fix): the original bare `except:` also swallowed
        # KeyboardInterrupt/SystemExit; catch only literal_eval failures
        # and fall back to the raw string.
        v = string
    return v
|
| 450 |
+
|
| 451 |
+
if __name__ == "__main__":
    # Minimal smoke test for EasyDict/Config construction.
    d = EasyDict({"1": 2, "2": 3})
    # NOTE(fix): Config.__init__ accepts only keyword arguments; passing a
    # dict positionally raised TypeError in the original. Unpack it instead.
    cfg = Config(**{"1": 2, "2": 3})
|
ltm/basis_functions.py
ADDED
|
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import math
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class BasisFunctions(object):
    """Abstract interface for a family of basis functions psi(t).

    Subclasses must implement ``__len__``, ``evaluate`` and the three
    ``integrate_*`` methods that return closed-form integrals.
    """

    def __init__(self):
        pass

    def __len__(self):
        """Number of basis functions."""
        # NOTE(fix): the original abstract methods were bare `pass`, so a
        # missing override silently returned None; raise instead.
        raise NotImplementedError

    def evaluate(self, t):
        """Evaluate every basis function at ``t``."""
        raise NotImplementedError

    def integrate_t2_times_psi(self, a, b):
        """Compute integral int_a^b (t**2) * psi(t)."""
        raise NotImplementedError

    def integrate_t_times_psi(self, a, b):
        """Compute integral int_a^b t * psi(t)."""
        raise NotImplementedError

    def integrate_psi(self, a, b):
        """Compute integral int_a^b psi(t)."""
        raise NotImplementedError
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class PowerBasisFunctions(BasisFunctions):
    """Monomial basis: psi(t) = t**degree, one function per entry of ``degree``."""

    def __init__(self, degree):
        # Shape (1, N) so evaluation broadcasts over a column of time points.
        self.degree = degree.unsqueeze(0)

    def __len__(self):
        """Number of basis functions."""
        return self.degree.size(1)

    def evaluate(self, t):
        return t ** self.degree

    def integrate_t2_times_psi(self, a, b):
        """Compute integral int_a^b (t**2) * psi(t)."""
        # int t**(d+2) dt = t**(d+3) / (d+3)
        p = self.degree + 3
        return (b ** p - a ** p) / p

    def integrate_t_times_psi(self, a, b):
        """Compute integral int_a^b t * psi(t)."""
        # int t**(d+1) dt = t**(d+2) / (d+2)
        p = self.degree + 2
        return (b ** p - a ** p) / p

    def integrate_psi(self, a, b):
        """Compute integral int_a^b psi(t)."""
        # int t**d dt = t**(d+1) / (d+1)
        p = self.degree + 1
        return (b ** p - a ** p) / p

    def __repr__(self):
        return f"PowerBasisFunction(degree={self.degree})"
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class SineBasisFunctions(BasisFunctions):
    """Sinusoidal basis: psi(t) = sin(omega * t), one function per frequency."""

    def __init__(self, omega):
        # Shape (1, N): broadcast over a column of time points.
        self.omega = omega.unsqueeze(0)

    def __repr__(self):
        return f"SineBasisFunction(omega={self.omega})"

    def __len__(self):
        """Number of basis functions."""
        return self.omega.size(1)

    def evaluate(self, t):
        return torch.sin(self.omega * t)

    def integrate_t2_times_psi(self, a, b):
        """Compute integral int_a^b (t**2) * psi(t).

        Antiderivative of (t**2)*sin(w*t):
        ((2 - t**2 * w**2) * cos(w*t) + 2*w*t * sin(w*t)) / w**3
        """
        w = self.omega

        def antideriv(t):
            return (2 - (t ** 2) * (w ** 2)) * torch.cos(w * t) + 2 * w * t * torch.sin(w * t)

        return (antideriv(b) - antideriv(a)) / (w ** 3)

    def integrate_t_times_psi(self, a, b):
        """Compute integral int_a^b t * psi(t).

        Antiderivative of t*sin(w*t): (sin(w*t) - w*t*cos(w*t)) / w**2
        """
        w = self.omega

        def antideriv(t):
            return torch.sin(w * t) - w * t * torch.cos(w * t)

        return (antideriv(b) - antideriv(a)) / (w ** 2)

    def integrate_psi(self, a, b):
        """Compute integral int_a^b psi(t).

        Antiderivative of sin(w*t): -cos(w*t)/w
        """
        w = self.omega
        return (torch.cos(w * a) - torch.cos(w * b)) / w
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class CosineBasisFunctions(BasisFunctions):
    """Cosinusoidal basis: psi(t) = cos(omega * t), one function per frequency."""

    def __init__(self, omega):
        # Shape (1, N): broadcast over a column of time points.
        self.omega = omega.unsqueeze(0)

    def __repr__(self):
        return f"CosineBasisFunction(omega={self.omega})"

    def __len__(self):
        """Number of basis functions."""
        return self.omega.size(1)

    def evaluate(self, t):
        return torch.cos(self.omega * t)

    def integrate_t2_times_psi(self, a, b):
        """Compute integral int_a^b (t**2) * psi(t).

        Antiderivative of (t**2)*cos(w*t):
        ((t**2 * w**2 - 2) * sin(w*t) + 2*w*t * cos(w*t)) / w**3
        (The original comment mistakenly wrote cos for the first term; the
        computation below — matching the original code — uses sin.)
        """
        w = self.omega

        def antideriv(t):
            return ((t ** 2) * (w ** 2) - 2) * torch.sin(w * t) + 2 * w * t * torch.cos(w * t)

        return (antideriv(b) - antideriv(a)) / (w ** 3)

    def integrate_t_times_psi(self, a, b):
        """Compute integral int_a^b t * psi(t).

        Antiderivative of t*cos(w*t): (cos(w*t) + w*t*sin(w*t)) / w**2
        """
        w = self.omega

        def antideriv(t):
            return torch.cos(w * t) + w * t * torch.sin(w * t)

        return (antideriv(b) - antideriv(a)) / (w ** 2)

    def integrate_psi(self, a, b):
        """Compute integral int_a^b psi(t).

        Antiderivative of cos(w*t): sin(w*t)/w
        """
        w = self.omega
        return (torch.sin(w * b) - torch.sin(w * a)) / w
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class GaussianBasisFunctions(BasisFunctions):
    """Gaussian radial basis: psi_i(t) = N(t; mu_i, sigma_i**2).

    All closed-form integrals below follow from the standard normal pdf/cdf
    (see _phi / _Phi) and Gaussian product identities.
    """
    def __init__(self, mu, sigma):
        # Stored as row vectors [1, N] so they broadcast against time tensors.
        self.mu = mu.unsqueeze(0)
        self.sigma = sigma.unsqueeze(0)

    def __repr__(self):
        return f"GaussianBasisFunction(mu={self.mu}, sigma={self.sigma})"

    def __len__(self):
        """Number of basis functions."""
        return self.mu.size(1)

    def _phi(self, t):
        # Standard normal pdf: exp(-t^2/2) / sqrt(2*pi).
        return 1. / math.sqrt(2 * math.pi) * torch.exp(-.5 * t**2)

    def _Phi(self, t):
        # Standard normal cdf, expressed via the error function.
        return .5 * (1 + torch.erf(t / math.sqrt(2)))

    def _integrate_product_of_gaussians(self, mu, sigma_sq):
        # int N(t; mu, sigma_sq) * N(t; self.mu, self.sigma^2) dt
        # = N(mu; self.mu, self.sigma^2 + sigma_sq)  (Gaussian product identity).
        sigma = torch.sqrt(self.sigma ** 2 + sigma_sq)
        return self._phi((mu - self.mu) / sigma) / sigma

    def evaluate(self, t):
        # Gaussian pdf at t for every basis function (broadcasts over [1, N]).
        return self._phi((t - self.mu) / self.sigma) / self.sigma

    def batch_evaluate(self, t):
        # t: 1-D tensor of P time points.  Builds the [N, P] standardized
        # offsets, then returns pdf values transposed to [P, N].
        t_ = t.repeat(self.mu.size(0),1) - self.mu.repeat(t.size(0),1).transpose(1,0)
        t_ = t_ / self.sigma.repeat((t.size(0),1)).transpose(1,0)
        return (self._phi(t_) / self.sigma.repeat((t.size(0),1)).transpose(1,0)).transpose(0,1)

    def integrate_t2_times_psi(self, a, b):
        """Compute integral int_a^b (t**2) * psi(t).

        Closed form for the truncated second moment of a Gaussian:
        (mu^2 + sigma^2) * [Phi(beta) - Phi(alpha)]
        - sigma*(b + mu)*phi(beta) + sigma*(a + mu)*phi(alpha),
        with alpha = (a - mu)/sigma, beta = (b - mu)/sigma.
        """
        return (self.mu**2 + self.sigma**2) * (
            self._Phi((b - self.mu) / self.sigma) - self._Phi((a - self.mu) / self.sigma)
        ) - (
            self.sigma * (b + self.mu) * self._phi((b - self.mu) / self.sigma)
        ) + (
            self.sigma * (a + self.mu) * self._phi((a - self.mu) / self.sigma)
        )

    def integrate_t_times_psi(self, a, b):
        """Compute integral int_a^b t * psi(t).

        Truncated first moment: mu * [Phi(beta) - Phi(alpha)]
        - sigma * [phi(beta) - phi(alpha)].
        """
        return self.mu * (
            self._Phi((b - self.mu) / self.sigma) - self._Phi((a - self.mu) / self.sigma)
        ) - self.sigma * (
            self._phi((b - self.mu) / self.sigma) - self._phi((a - self.mu) / self.sigma)
        )

    def integrate_psi(self, a, b):
        """Compute integral int_a^b psi(t) (Gaussian mass on [a, b])."""
        return self._Phi((b - self.mu) / self.sigma) - self._Phi((a - self.mu) / self.sigma)

    def integrate_t2_times_psi_gaussian(self, mu, sigma_sq):
        """Compute integral int N(t; mu, sigma_sq) * t**2 * psi(t).

        The product of the two Gaussians is S_tilde * N(t; mu_tilde,
        sigma_sq_tilde); the integral is then S_tilde times the second
        moment (mu_tilde^2 + sigma_sq_tilde) of the product Gaussian.
        """
        S_tilde = self._integrate_product_of_gaussians(mu, sigma_sq)
        # Precision-weighted mean of the product Gaussian.
        mu_tilde = (
            self.mu * sigma_sq + mu * self.sigma ** 2
        ) / (
            self.sigma ** 2 + sigma_sq
        )
        # Variance of the product Gaussian.
        sigma_sq_tilde = ((self.sigma ** 2) * sigma_sq) / (self.sigma ** 2 + sigma_sq)
        return S_tilde * (mu_tilde ** 2 + sigma_sq_tilde)

    def integrate_t_times_psi_gaussian(self, mu, sigma_sq):
        """Compute integral int N(t; mu, sigma_sq) * t * psi(t).

        S_tilde times the mean of the product Gaussian (see above).
        """
        S_tilde = self._integrate_product_of_gaussians(mu, sigma_sq)
        mu_tilde = (
            self.mu * sigma_sq + mu * self.sigma ** 2
        ) / (
            self.sigma ** 2 + sigma_sq
        )
        return S_tilde * mu_tilde

    def integrate_psi_gaussian(self, mu, sigma_sq):
        """Compute integral int N(t; mu, sigma_sq) * psi(t)."""
        return self._integrate_product_of_gaussians(mu, sigma_sq)
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
class RetangularBasisFunctions(BasisFunctions):
    """Rectangular (box-car) basis functions.

    phi_i(t) = 1 if t in [mu_i - width_i/2, mu_i + width_i/2), else 0.

    The class name keeps the original (misspelled) "Retangular" so existing
    imports and callers continue to work.
    """

    def __init__(self, mu, sigma):
        """
        Args:
            mu: Tensor of bin centers, shape (num_basis,).
            sigma: Tensor of bin widths, shape (num_basis,). The parameter
                name is kept as `sigma` for interface compatibility; it is
                stored and used as the rectangle width.
        """
        self.mu = mu.unsqueeze(0)        # [1, num_basis]
        self.width = sigma.unsqueeze(0)  # [1, num_basis]

    def __repr__(self):
        # Bug fix: this previously formatted self.sigma, an attribute that is
        # never set (__init__ stores the width as self.width), so calling
        # repr() raised AttributeError. It also reported the wrong class name.
        return f"RetangularBasisFunctions(mu={self.mu}, width={self.width})"

    def __len__(self):
        """Number of basis functions."""
        return self.mu.size(1)

    def batch_evaluate(self, t):
        """
        Evaluate multiple time points against all rectangular basis functions.
        Args:
            t: Tensor of time values to evaluate, shape (num_points,).
        Returns:
            Tensor of 0/1 evaluations, shape (num_points, num_basis).
        """
        t = t.repeat(self.mu.size(0),1)  # Shape: (1, num_points)
        mu = self.mu.repeat(t.size(0),1).transpose(1,0)  # Shape: (num_basis, 1)
        width = self.width.repeat(t.size(0),1).transpose(1,0)  # Shape: (num_basis, 1)
        # Broadcast to (num_basis, num_points), then transpose.
        return ((t >= (mu - width / 2)) & (t < (mu + width / 2))).float().transpose(0,1)

    def _Phi(self, t):
        """
        Compute the step function for a single value of t.
        Args:
            t: A scalar or tensor of time values.
        Returns:
            Tensor of values indicating presence in each basis function's range.
        """
        lower_bounds = self.mu - self.width / 2
        upper_bounds = self.mu + self.width / 2
        return ((t >= lower_bounds) & (t < upper_bounds)).float()

    def evaluate(self, t):
        """
        Evaluate the rectangular basis functions at a single point or array of points.
        Args:
            t: A scalar or 1D tensor of time values.
        Returns:
            Tensor of shape (1, num_basis) for scalar input; for tensor input,
            the broadcast of t against the (1, num_basis) bounds.
        """
        if t.ndim == 0:  # Scalar input
            return self._Phi(t)
        else:  # Tensor input
            lower_bounds = (self.mu - self.width / 2)  # Shape: (1, num_basis)
            upper_bounds = (self.mu + self.width / 2)  # Shape: (1, num_basis)
            return ((t >= lower_bounds) & (t < upper_bounds)).float()
|
ltm/long_term_attention_gibbs.py
ADDED
|
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding: utf-8
|
| 2 |
+
"""
|
| 3 |
+
Attention modules
|
| 4 |
+
"""
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.distributions as dist
|
| 8 |
+
|
| 9 |
+
from .basis_functions import (
|
| 10 |
+
PowerBasisFunctions,
|
| 11 |
+
SineBasisFunctions,
|
| 12 |
+
CosineBasisFunctions,
|
| 13 |
+
GaussianBasisFunctions,
|
| 14 |
+
RetangularBasisFunctions
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
import numpy as np
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class LongTermAttention(nn.Module):
    """Continuous long-term memory attention over a compressed key history.

    Keys are summarized by a fixed number of basis-function coefficients
    (ridge regression onto rectangular basis functions); queries then attend
    to the continuous signal via numerical integration. Structure follows an
    "infinite memory" design where old coefficients are re-sampled and
    re-fitted together with new inputs. NOTE(review): several attributes
    (self.psis, self.queries, self.keys, self.values, self.batch_size,
    self.qlen, self.d_head) are set as side effects of forward()/
    expected_value() and read by score(); call order matters.
    """
    def __init__(self, head_size:int , length: int, target_len:int, attn_func: str, attn_num_basis: int,
                 continuous: bool, attn_drop: float, infinite_memory: bool, n_layers: int,
                 n_heads: int, affines: bool, mask: bool, mask_type: str, kl_regularizer: bool, proj_key, proj_value, sigma_0, mu_0, sticky_memories, sigmas, tau, **kwargs):

        super(LongTermAttention, self).__init__()

        # Default device; overwritten in forward() from the input tensor.
        self.device = 'cuda'
        self.length = length #memory length
        self.target_len = target_len #target length / transformer length
        self.head_size = head_size
        self.attn_num_basis = attn_num_basis
        self.continuous = continuous # whether attention over memory vectors is continuous
        self.attn_func = attn_func # normalizing function
        self.n_head = n_heads
        self.sigmas = sigmas
        self.kl_regularizer = kl_regularizer
        self.sigma_0 = sigma_0
        self.mu_0 = mu_0
        # Externally supplied projection modules mapping coefficients to K/V.
        self.proj_key = proj_key
        self.proj_value = proj_value

        self.affines=affines # whether mu, sigma should be computed using affine transformations

        self.sticky_memories=sticky_memories

        self.mem_threshold=2048
        self.infinite_memory = infinite_memory # whether the memory is infinite

        self.nb_samples=512 # number of samples used for update
        self.tau = tau #compressing factor
        self.count = 0

        self.x_past=None # previous memory vectors
        self.B_past=None # previous coefficient matrix

        self.ridge_penalty=0.5 # ridge penalty
        self.padding = True

        # Position spacing scheme for the memory signal ('linear' or 'log').
        self.spacing='linear'

    def get_basis(self, length, target_len):
        """Build basis functions and the ridge-regression operator G for
        every admissible memory length, plus the "infinite memory" variant.
        """
        def compute_G(l, psi, positions, padding=True):
            # Design matrix F[n, p] = psi_n(position_p).
            F = torch.zeros(self.attn_num_basis, positions.size(0))

            basis_functions = psi
            F[:, :] = basis_functions.evaluate(positions.unsqueeze(1)).t()

            # Ridge solution: G = F^T (F F^T + lambda I)^{-1}, so that
            # B = x G are the least-squares basis coefficients.
            I = torch.eye(self.attn_num_basis)
            G = F.t().matmul((F.matmul(F.t()) + self.ridge_penalty * I).inverse())

            if padding:
                # Drop the rows that correspond to padding positions
                # outside [0, 1].
                if l % 2:
                    G = G[((l-1)//2):(-(l-1)//2), :]
                else:
                    G = G[(l//2):-(l//2), :]

            return G.to(self.device)
        padding = self.padding
        attn_num_basis = self.attn_num_basis
        if self.continuous:

            self.psi=[None]
            self.Gs=[None for _ in range(length+1)]
            lengths=[]
            # Admissible lengths: multiples of target_len, plus length itself.
            for i in range(length):
                self.psi.append([])
                if (i+1)%target_len==0:
                    lengths.append(i+1)
            if length not in lengths:
                lengths.append(length)
            for l in lengths:
                # get positions for memory vectors
                self.add_retangular_basis_functions(self.psi[l], attn_num_basis, device=self.device)

                if self.spacing=='linear':
                    if padding:
                        # Pad positions to [-0.5, 1.5] so the fit is not
                        # distorted at the interval boundaries.
                        if l % 2:
                            shift = 1 / float(l)
                            positions = torch.linspace(-.5+shift, 1.5-shift, 2*l-1).to(self.device)
                        else:
                            shift = 1 / float(2*l)
                            positions = torch.linspace(-.5+shift, 1.5-shift, 2*l).to(self.device)
                    else:
                        shift = 1 / float(2*l)
                        positions = torch.linspace(shift, 1-shift, l).to(self.device)
                elif self.spacing=='log':
                    if padding:
                        if l % 2:
                            shift = 1 / float(l)
                            positions = torch.linspace(-.5+shift, 1.5-shift, 2*l-1).to(self.device)
                        else:
                            shift = 1 / float(2*l)
                            positions = torch.linspace(-.5+shift, 1.5-shift, 2*l).to(self.device)

                        # Replace the interior with log-spaced positions,
                        # keeping the linear padding at both ends.
                        pos = np.e**(np.log(1+1)*torch.arange(1,length+1)/length)-1
                        positions = torch.cat([positions[:int(l/2)],pos.to(self.device),positions[-int(l/2):]])

                    else:
                        positions = np.e**(np.log(1+1)*torch.arange(1,length+1)/length)-1

                # compute basis functions
                self.Gs[l]=compute_G(l, self.psi[l][0], positions, padding=padding) # [L,N]
                self.positions = positions[int(l/2):-int(l/2)]

                # compute samples for memory update
                if self.infinite_memory:
                    tm_tau = torch.arange(1,self.nb_samples+1).float()
                    tm_l = torch.arange(self.nb_samples+1,length+self.nb_samples+1).float()
                    # Old memory is compressed into [0, tau]; new vectors
                    # occupy (tau, 1].
                    tm_tau = tm_tau*self.tau/self.nb_samples # positions of old vectors
                    tm_l = self.tau + (1-self.tau)*(tm_l-self.nb_samples)/length # positions of new vectors
                    positions_inf = torch.cat([tm_tau, tm_l],0).to(self.device) # positions

                    if padding:
                        if l % 2:
                            shift = 1 / float(length+self.nb_samples)
                            positions_pad = torch.linspace(-.5+shift, 1.5-shift, 2*(length+self.nb_samples)-1).to(self.device)
                        else:
                            shift = 1 / float(2*length+self.nb_samples)
                            positions_pad = torch.linspace(-.5+shift, 1.5-shift, 2*(length+self.nb_samples)).to(self.device)
                        # Keep only padding positions strictly outside [0, 1].
                        positions_pad_ = torch.FloatTensor([i for i in positions_pad if i<0]).to(self.device)
                        positions_pad__ = torch.FloatTensor([i for i in positions_pad if i>1]).to(self.device)
                        positions_inf = torch.cat([positions_pad_,positions_inf,positions_pad__], dim=0)

                    # Evaluate every basis at the contracted old positions
                    # t/tau; used by update_inf() to re-sample old memory.
                    self.samples=None
                    for t in tm_tau:
                        if self.samples is None:
                            self.samples = self.psi[l][0].evaluate(t/self.tau)
                        else:
                            self.samples = torch.cat([self.samples,self.psi[l][0].evaluate(t/self.tau)], dim=0)

                    # compute G for the infinite case
                    self.G_inf = compute_G(self.nb_samples+length, self.psi[l][0], positions_inf, padding=padding) #[L+nb_samples,N]

                    if self.sticky_memories:
                        self.bins = torch.linspace(0,1,129).to(device=self.device) #self.positions
                        self.nb_bins_cat=1
                        self.bins_cat = dist.Categorical(torch.ones(self.nb_bins_cat))

    def add_gaussian_basis_functions(self, psi, nb_basis, sigmas, device):
        """Append a GaussianBasisFunctions object covering [0, 1] with the
        cross product of nb_basis//len(sigmas) means and the given sigmas."""
        mu, sigma = torch.meshgrid(torch.linspace(0, 1, nb_basis // len(sigmas)), torch.Tensor(sigmas))
        mu = mu.flatten().to(device)
        sigma = sigma.flatten().to(device)
        self.basis_mu=mu
        self.basis_sigma=sigma
        assert mu.size(0) == nb_basis
        psi.append(GaussianBasisFunctions(mu=mu, sigma=sigma))

    def add_retangular_basis_functions(self, psi, nb_basis, device):
        """Append a RetangularBasisFunctions object: nb_basis equal-width
        bins partitioning [0, 1]."""
        width = torch.ones(nb_basis, device=device) / nb_basis

        # Compute the centers (midpoints) of each bin
        edges = torch.linspace(0, 1, nb_basis + 1, device=device)
        mu = (edges[:-1] + edges[1:]) / 2
        psi.append(RetangularBasisFunctions(mu=mu, sigma=width))

    def value_function(self, x, inf=False):
        """Project the memory signal x onto the basis coefficients.

        Args:
            x: [B, e, L] memory vectors laid out along the last axis.
            inf: use the infinite-memory operator G_inf instead of Gs[L].
        Returns:
            B: [B, N, e] coefficient matrix.
        """
        if inf:
            G = self.G_inf # [nb_sample+L,N]
        else:
            G = self.Gs[x.size(-1)] # [L,N]
        B = torch.matmul(x, G) # [B,e,N]
        B = B.permute(0,2,1) # [B,N,e]

        return B

    def update_inf(self, x):
        """Fold the previous coefficient matrix into the new fit.

        Old memory is reconstructed at nb_samples contracted positions
        (uniformly via self.samples, or attention-weighted when
        sticky_memories is on), concatenated before the new input, and
        re-fitted with G_inf. B_past is detached to stop gradients flowing
        through the unbounded history.
        """
        if self.B_past is not None:
            if self.sticky_memories:
                # Sample past positions proportionally to how much attention
                # they received (density from the previous step's score).
                bins = self.bins.clone()
                # Nudge the outer edges so boundary points fall inside a bin.
                bins[0]=-.000001
                bins[-1]=1.000001
                prob_density = self.compute_probability(self.score, t=bins)
                cum_prob = torch.cumulative_trapezoid(prob_density, bins, dim=-1).to(self.device)
                p = (cum_prob[..., 1:] - cum_prob[..., :-1]).sum(dim=(1, 2))
                p = p / p.sum(-1, keepdim=True) # Normalize over the last dimension (bins)
                p = dist.Categorical(p)
                b = p.sample((self.nb_samples,))
                t = self.bins_cat.sample((self.nb_samples, 1)).to(device=self.device)
                # Place each sample inside its chosen bin.
                ts = (t*(self.bins[b+1]-self.bins[b])/self.nb_bins_cat +self.bins[b]).transpose(1,0)
                samples = self.psi[self.length][0].batch_evaluate(ts[0]).contiguous()

                xm_tau = self.B_past.transpose(-1,-2).matmul(samples.transpose(-1,-2)) # [B,e,nb_samples]
            else:
                # Uniform re-sampling of the old continuous signal.
                xm_tau = self.B_past.transpose(-1,-2).matmul(self.samples.transpose(-1,-2)) # [B,e,nb_samples]

            x = torch.cat([xm_tau,x], dim=2) # [B,e,nb_samples+L]
            B = self.value_function(x, inf=True) # [B,N,e]
        else:
            # First step: plain fit of the new input only.
            B = self.value_function(x)

        self.B_past=B.detach()
        self.x_past=x
        return B

    def score(self, t):
        """Unnormalized attention score z(t) for time points t.

        Relies on self.psis / self.queries / self.keys / self.d_head being
        set beforehand (in forward() and expected_value()).
        """
        psis = self.psis[0].batch_evaluate(t)
        query = self.queries/ (self.d_head ** 0.5) # divide by sqrt(d_head) [B,h,q,d]
        keys = self.keys.transpose(-1, -2)
        keys = torch.matmul(keys, psis.T) #[B,h,d,1]
        scores = torch.matmul(query, keys) #[B,h,q,1]
        return scores

    def compute_probability(self, score_fn, num_points=1000, t=None):
        """
        Compute probability distribution p(t) = exp(z(t)) / int exp(z(s)) ds
        (a continuous softmax over [0, 1], normalized by trapezoidal rule).

        Args:
            score_fn (callable): Function that computes z(t)
            num_points (int): Number of points for numerical integration
            t: Optional explicit integration grid; defaults to a uniform
                grid on [0, 1].

        Returns:
            Tensor of probability-density values at the grid points.
        """
        if t is None:
            # Create integration points
            t = torch.linspace(0, 1, num_points).to(self.device)

        scores = score_fn(t)
        prob = torch.exp(scores) / torch.trapz(torch.exp(scores), t, dim=-1).unsqueeze(-1)
        return prob

    def expected_value(self, score_fn, num_points=1000):
        """
        Compute expected value E_p[V(t)] using nested integration:
        int p(t) psi(t) dt gives per-basis attention weights, which are then
        matrix-multiplied with the value coefficients.

        Args:
            score_fn (callable): Function that computes z(t)
            num_points (int): Number of points for numerical integration

        Returns:
            torch.Tensor: Expected value, shape [B, h, q, d]
        """
        # Create integration points
        t = torch.linspace(0, 1, num_points).to(self.device)

        # Compute basis functions (also consumed by score_fn via self.psis)
        self.psis = []
        self.add_retangular_basis_functions(self.psis, self.attn_num_basis, self.device)
        psi = self.psis[0].batch_evaluate(t)
        # Compute probability distribution
        prob = self.compute_probability(score_fn, num_points)
        # Compute values at integration points
        values = self.values
        # Compute p(t) * psi(t)
        # Reshape psi for broadcasting to match the shape of prob
        psi_broadcasted = psi.unsqueeze(1).unsqueeze(2).unsqueeze(3)

        # Expand psi to match the dimensions of prob (num_points, batch_size, n_head, qlen, num_basis)
        psi_broadcasted = psi_broadcasted.expand(num_points, self.batch_size, self.n_head, self.qlen, self.attn_num_basis)
        # Outer product of p(t) with psi(t) at each grid point, rearranged so
        # the integration axis is last.
        integrand = torch.matmul(prob.permute(3,0,1,2).unsqueeze(-1).unsqueeze(-1), psi_broadcasted.unsqueeze(-2)).permute(1, 2, 3, 4, 5, 0).squeeze(-3)

        integral = torch.trapz(integrand, t, dim=-1)
        # Matrix multiply with values
        expected_value = torch.matmul(integral, values) # [B, h, q, d]

        return expected_value

    def forward(self, k, q, new_doc, layer_n):
        """Attend queries q to the (compressed) key memory k.

        Args:
            k: Key/memory features; reshaped below assuming a 14x14 spatial
               grid with 1024 channels per frame — TODO confirm against caller.
            q: Query tensor, [B, qlen, n_head * head_size].
            new_doc: If truthy, reset the stored memory (new document).
            layer_n: Layer index (currently unused here).
        NOTE(review): `context` is only assigned inside the
        `if self.continuous:` branch; with continuous=False the final return
        would raise NameError — confirm continuous is always True here.
        """
        self.device = k.device
        if self.continuous:
            # Number of frames: tokens per frame is 14*14.
            klen = int(k.size(1)/(14*14))
            self.length = klen
            batch_size = k.size(0) #batch size
            qlen = q.size(1) #query length
            self.qlen = qlen
            self.batch_size = batch_size
            self.d_head = self.head_size #head size
            self.get_basis(klen, klen)
            # clean memory if going through different document
            if new_doc:
                self.B_past=None
                self.x_past=None

            # Spatial average pooling over the 14x14 grid, then [B, e, klen].
            k = k.reshape(batch_size, klen, 14, 14, 1024).mean(dim=(2, 3))
            k = k.transpose(1,2)
            # perform memory update
            if self.infinite_memory:
                B = self.update_inf(k)
            else: # compute input continuous approximation
                B = self.value_function(k) # [B,N,e]
            keys = self.proj_key(B)
            values = self.proj_value(B)
            query = q
            # Stash multi-head views as attributes; score() reads them.
            self.queries = query.view(batch_size,qlen,self.n_head,self.d_head).transpose(1,2) # [B,h,q,d]
            self.keys = keys.view(batch_size,self.attn_num_basis,self.n_head,self.d_head).transpose(1,2) # [B,h,N,d]
            self.values = values.view(batch_size,self.attn_num_basis,self.n_head,self.d_head).transpose(1,2) # [B,h,N,d]
            context = self.expected_value(self.score) # Shape [1, 32, 768]

        return context.contiguous().transpose(1,2).reshape(1, qlen, -1)
|
| 315 |
+
|
model-00001-of-00004.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ecf6a804c5af89465362453e591d8c3358cd97ad48247baabfc5b070edad2e07
|
| 3 |
+
size 4971600800
|
model-00002-of-00004.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bf5745ae7b321d884e62f74589758abee57e79c6ae138e1b1f6877b5cad20565
|
| 3 |
+
size 4915917440
|
model-00003-of-00004.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:475cbd791fe87314409771c7f9651e5f7237c43e8eb5d9662714ff1d3d4fbc04
|
| 3 |
+
size 4999820720
|
model-00004-of-00004.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7648e67deaaa08ce6f73df0f96963c62dba9702927390a73c69bdc328d6f5d27
|
| 3 |
+
size 1499540784
|
model.safetensors.index.json
ADDED
|
@@ -0,0 +1,934 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"metadata": {
|
| 3 |
+
"total_size": 16386763776
|
| 4 |
+
},
|
| 5 |
+
"weight_map": {
|
| 6 |
+
"extra_query_tokens": "model-00001-of-00004.safetensors",
|
| 7 |
+
"mistral_model.lm_head.weight": "model-00004-of-00004.safetensors",
|
| 8 |
+
"mistral_model.model.embed_tokens.weight": "model-00001-of-00004.safetensors",
|
| 9 |
+
"mistral_model.model.layers.0.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 10 |
+
"mistral_model.model.layers.0.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 11 |
+
"mistral_model.model.layers.0.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 12 |
+
"mistral_model.model.layers.0.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 13 |
+
"mistral_model.model.layers.0.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 14 |
+
"mistral_model.model.layers.0.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 15 |
+
"mistral_model.model.layers.0.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 16 |
+
"mistral_model.model.layers.0.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 17 |
+
"mistral_model.model.layers.0.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 18 |
+
"mistral_model.model.layers.1.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 19 |
+
"mistral_model.model.layers.1.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 20 |
+
"mistral_model.model.layers.1.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 21 |
+
"mistral_model.model.layers.1.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 22 |
+
"mistral_model.model.layers.1.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 23 |
+
"mistral_model.model.layers.1.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 24 |
+
"mistral_model.model.layers.1.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 25 |
+
"mistral_model.model.layers.1.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 26 |
+
"mistral_model.model.layers.1.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 27 |
+
"mistral_model.model.layers.10.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 28 |
+
"mistral_model.model.layers.10.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 29 |
+
"mistral_model.model.layers.10.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 30 |
+
"mistral_model.model.layers.10.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 31 |
+
"mistral_model.model.layers.10.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 32 |
+
"mistral_model.model.layers.10.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 33 |
+
"mistral_model.model.layers.10.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 34 |
+
"mistral_model.model.layers.10.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 35 |
+
"mistral_model.model.layers.10.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 36 |
+
"mistral_model.model.layers.11.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 37 |
+
"mistral_model.model.layers.11.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 38 |
+
"mistral_model.model.layers.11.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 39 |
+
"mistral_model.model.layers.11.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 40 |
+
"mistral_model.model.layers.11.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 41 |
+
"mistral_model.model.layers.11.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 42 |
+
"mistral_model.model.layers.11.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 43 |
+
"mistral_model.model.layers.11.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 44 |
+
"mistral_model.model.layers.11.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 45 |
+
"mistral_model.model.layers.12.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 46 |
+
"mistral_model.model.layers.12.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 47 |
+
"mistral_model.model.layers.12.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 48 |
+
"mistral_model.model.layers.12.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 49 |
+
"mistral_model.model.layers.12.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 50 |
+
"mistral_model.model.layers.12.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 51 |
+
"mistral_model.model.layers.12.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 52 |
+
"mistral_model.model.layers.12.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 53 |
+
"mistral_model.model.layers.12.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 54 |
+
"mistral_model.model.layers.13.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 55 |
+
"mistral_model.model.layers.13.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 56 |
+
"mistral_model.model.layers.13.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 57 |
+
"mistral_model.model.layers.13.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 58 |
+
"mistral_model.model.layers.13.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 59 |
+
"mistral_model.model.layers.13.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 60 |
+
"mistral_model.model.layers.13.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 61 |
+
"mistral_model.model.layers.13.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 62 |
+
"mistral_model.model.layers.13.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 63 |
+
"mistral_model.model.layers.14.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 64 |
+
"mistral_model.model.layers.14.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 65 |
+
"mistral_model.model.layers.14.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 66 |
+
"mistral_model.model.layers.14.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 67 |
+
"mistral_model.model.layers.14.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 68 |
+
"mistral_model.model.layers.14.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 69 |
+
"mistral_model.model.layers.14.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 70 |
+
"mistral_model.model.layers.14.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 71 |
+
"mistral_model.model.layers.14.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 72 |
+
"mistral_model.model.layers.15.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 73 |
+
"mistral_model.model.layers.15.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 74 |
+
"mistral_model.model.layers.15.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 75 |
+
"mistral_model.model.layers.15.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 76 |
+
"mistral_model.model.layers.15.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 77 |
+
"mistral_model.model.layers.15.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 78 |
+
"mistral_model.model.layers.15.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 79 |
+
"mistral_model.model.layers.15.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 80 |
+
"mistral_model.model.layers.15.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 81 |
+
"mistral_model.model.layers.16.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 82 |
+
"mistral_model.model.layers.16.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 83 |
+
"mistral_model.model.layers.16.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 84 |
+
"mistral_model.model.layers.16.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 85 |
+
"mistral_model.model.layers.16.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 86 |
+
"mistral_model.model.layers.16.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 87 |
+
"mistral_model.model.layers.16.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 88 |
+
"mistral_model.model.layers.16.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 89 |
+
"mistral_model.model.layers.16.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 90 |
+
"mistral_model.model.layers.17.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 91 |
+
"mistral_model.model.layers.17.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 92 |
+
"mistral_model.model.layers.17.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 93 |
+
"mistral_model.model.layers.17.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 94 |
+
"mistral_model.model.layers.17.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 95 |
+
"mistral_model.model.layers.17.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 96 |
+
"mistral_model.model.layers.17.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 97 |
+
"mistral_model.model.layers.17.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 98 |
+
"mistral_model.model.layers.17.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 99 |
+
"mistral_model.model.layers.18.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 100 |
+
"mistral_model.model.layers.18.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 101 |
+
"mistral_model.model.layers.18.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 102 |
+
"mistral_model.model.layers.18.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 103 |
+
"mistral_model.model.layers.18.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 104 |
+
"mistral_model.model.layers.18.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 105 |
+
"mistral_model.model.layers.18.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 106 |
+
"mistral_model.model.layers.18.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 107 |
+
"mistral_model.model.layers.18.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 108 |
+
"mistral_model.model.layers.19.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 109 |
+
"mistral_model.model.layers.19.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 110 |
+
"mistral_model.model.layers.19.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 111 |
+
"mistral_model.model.layers.19.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 112 |
+
"mistral_model.model.layers.19.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 113 |
+
"mistral_model.model.layers.19.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 114 |
+
"mistral_model.model.layers.19.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 115 |
+
"mistral_model.model.layers.19.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 116 |
+
"mistral_model.model.layers.19.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 117 |
+
"mistral_model.model.layers.2.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 118 |
+
"mistral_model.model.layers.2.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 119 |
+
"mistral_model.model.layers.2.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 120 |
+
"mistral_model.model.layers.2.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 121 |
+
"mistral_model.model.layers.2.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 122 |
+
"mistral_model.model.layers.2.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 123 |
+
"mistral_model.model.layers.2.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 124 |
+
"mistral_model.model.layers.2.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 125 |
+
"mistral_model.model.layers.2.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 126 |
+
"mistral_model.model.layers.20.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 127 |
+
"mistral_model.model.layers.20.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 128 |
+
"mistral_model.model.layers.20.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 129 |
+
"mistral_model.model.layers.20.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 130 |
+
"mistral_model.model.layers.20.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 131 |
+
"mistral_model.model.layers.20.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 132 |
+
"mistral_model.model.layers.20.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 133 |
+
"mistral_model.model.layers.20.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 134 |
+
"mistral_model.model.layers.20.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 135 |
+
"mistral_model.model.layers.21.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 136 |
+
"mistral_model.model.layers.21.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 137 |
+
"mistral_model.model.layers.21.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 138 |
+
"mistral_model.model.layers.21.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 139 |
+
"mistral_model.model.layers.21.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 140 |
+
"mistral_model.model.layers.21.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 141 |
+
"mistral_model.model.layers.21.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 142 |
+
"mistral_model.model.layers.21.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 143 |
+
"mistral_model.model.layers.21.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 144 |
+
"mistral_model.model.layers.22.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 145 |
+
"mistral_model.model.layers.22.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 146 |
+
"mistral_model.model.layers.22.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 147 |
+
"mistral_model.model.layers.22.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 148 |
+
"mistral_model.model.layers.22.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 149 |
+
"mistral_model.model.layers.22.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 150 |
+
"mistral_model.model.layers.22.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 151 |
+
"mistral_model.model.layers.22.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 152 |
+
"mistral_model.model.layers.22.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 153 |
+
"mistral_model.model.layers.23.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 154 |
+
"mistral_model.model.layers.23.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 155 |
+
"mistral_model.model.layers.23.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 156 |
+
"mistral_model.model.layers.23.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 157 |
+
"mistral_model.model.layers.23.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 158 |
+
"mistral_model.model.layers.23.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 159 |
+
"mistral_model.model.layers.23.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 160 |
+
"mistral_model.model.layers.23.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 161 |
+
"mistral_model.model.layers.23.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 162 |
+
"mistral_model.model.layers.24.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 163 |
+
"mistral_model.model.layers.24.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 164 |
+
"mistral_model.model.layers.24.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 165 |
+
"mistral_model.model.layers.24.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 166 |
+
"mistral_model.model.layers.24.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 167 |
+
"mistral_model.model.layers.24.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 168 |
+
"mistral_model.model.layers.24.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 169 |
+
"mistral_model.model.layers.24.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 170 |
+
"mistral_model.model.layers.24.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 171 |
+
"mistral_model.model.layers.25.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 172 |
+
"mistral_model.model.layers.25.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 173 |
+
"mistral_model.model.layers.25.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 174 |
+
"mistral_model.model.layers.25.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 175 |
+
"mistral_model.model.layers.25.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 176 |
+
"mistral_model.model.layers.25.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 177 |
+
"mistral_model.model.layers.25.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 178 |
+
"mistral_model.model.layers.25.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 179 |
+
"mistral_model.model.layers.25.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 180 |
+
"mistral_model.model.layers.26.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 181 |
+
"mistral_model.model.layers.26.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 182 |
+
"mistral_model.model.layers.26.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 183 |
+
"mistral_model.model.layers.26.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 184 |
+
"mistral_model.model.layers.26.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 185 |
+
"mistral_model.model.layers.26.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 186 |
+
"mistral_model.model.layers.26.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 187 |
+
"mistral_model.model.layers.26.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 188 |
+
"mistral_model.model.layers.26.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 189 |
+
"mistral_model.model.layers.27.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 190 |
+
"mistral_model.model.layers.27.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 191 |
+
"mistral_model.model.layers.27.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 192 |
+
"mistral_model.model.layers.27.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 193 |
+
"mistral_model.model.layers.27.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 194 |
+
"mistral_model.model.layers.27.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 195 |
+
"mistral_model.model.layers.27.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 196 |
+
"mistral_model.model.layers.27.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 197 |
+
"mistral_model.model.layers.27.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 198 |
+
"mistral_model.model.layers.28.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 199 |
+
"mistral_model.model.layers.28.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 200 |
+
"mistral_model.model.layers.28.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 201 |
+
"mistral_model.model.layers.28.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 202 |
+
"mistral_model.model.layers.28.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 203 |
+
"mistral_model.model.layers.28.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 204 |
+
"mistral_model.model.layers.28.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 205 |
+
"mistral_model.model.layers.28.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 206 |
+
"mistral_model.model.layers.28.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 207 |
+
"mistral_model.model.layers.29.input_layernorm.weight": "model-00004-of-00004.safetensors",
|
| 208 |
+
"mistral_model.model.layers.29.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
|
| 209 |
+
"mistral_model.model.layers.29.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
|
| 210 |
+
"mistral_model.model.layers.29.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
|
| 211 |
+
"mistral_model.model.layers.29.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
|
| 212 |
+
"mistral_model.model.layers.29.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 213 |
+
"mistral_model.model.layers.29.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 214 |
+
"mistral_model.model.layers.29.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 215 |
+
"mistral_model.model.layers.29.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 216 |
+
"mistral_model.model.layers.3.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 217 |
+
"mistral_model.model.layers.3.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 218 |
+
"mistral_model.model.layers.3.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 219 |
+
"mistral_model.model.layers.3.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 220 |
+
"mistral_model.model.layers.3.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 221 |
+
"mistral_model.model.layers.3.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 222 |
+
"mistral_model.model.layers.3.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 223 |
+
"mistral_model.model.layers.3.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 224 |
+
"mistral_model.model.layers.3.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 225 |
+
"mistral_model.model.layers.30.input_layernorm.weight": "model-00004-of-00004.safetensors",
|
| 226 |
+
"mistral_model.model.layers.30.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
|
| 227 |
+
"mistral_model.model.layers.30.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
|
| 228 |
+
"mistral_model.model.layers.30.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
|
| 229 |
+
"mistral_model.model.layers.30.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
|
| 230 |
+
"mistral_model.model.layers.30.self_attn.k_proj.weight": "model-00004-of-00004.safetensors",
|
| 231 |
+
"mistral_model.model.layers.30.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
|
| 232 |
+
"mistral_model.model.layers.30.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
|
| 233 |
+
"mistral_model.model.layers.30.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
|
| 234 |
+
"mistral_model.model.layers.31.input_layernorm.weight": "model-00004-of-00004.safetensors",
|
| 235 |
+
"mistral_model.model.layers.31.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
|
| 236 |
+
"mistral_model.model.layers.31.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
|
| 237 |
+
"mistral_model.model.layers.31.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
|
| 238 |
+
"mistral_model.model.layers.31.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
|
| 239 |
+
"mistral_model.model.layers.31.self_attn.k_proj.weight": "model-00004-of-00004.safetensors",
|
| 240 |
+
"mistral_model.model.layers.31.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
|
| 241 |
+
"mistral_model.model.layers.31.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
|
| 242 |
+
"mistral_model.model.layers.31.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
|
| 243 |
+
"mistral_model.model.layers.4.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 244 |
+
"mistral_model.model.layers.4.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 245 |
+
"mistral_model.model.layers.4.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 246 |
+
"mistral_model.model.layers.4.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 247 |
+
"mistral_model.model.layers.4.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 248 |
+
"mistral_model.model.layers.4.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 249 |
+
"mistral_model.model.layers.4.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 250 |
+
"mistral_model.model.layers.4.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 251 |
+
"mistral_model.model.layers.4.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 252 |
+
"mistral_model.model.layers.5.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 253 |
+
"mistral_model.model.layers.5.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 254 |
+
"mistral_model.model.layers.5.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 255 |
+
"mistral_model.model.layers.5.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 256 |
+
"mistral_model.model.layers.5.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 257 |
+
"mistral_model.model.layers.5.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 258 |
+
"mistral_model.model.layers.5.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 259 |
+
"mistral_model.model.layers.5.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 260 |
+
"mistral_model.model.layers.5.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 261 |
+
"mistral_model.model.layers.6.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 262 |
+
"mistral_model.model.layers.6.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 263 |
+
"mistral_model.model.layers.6.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 264 |
+
"mistral_model.model.layers.6.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 265 |
+
"mistral_model.model.layers.6.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 266 |
+
"mistral_model.model.layers.6.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 267 |
+
"mistral_model.model.layers.6.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 268 |
+
"mistral_model.model.layers.6.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 269 |
+
"mistral_model.model.layers.6.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 270 |
+
"mistral_model.model.layers.7.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 271 |
+
"mistral_model.model.layers.7.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 272 |
+
"mistral_model.model.layers.7.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 273 |
+
"mistral_model.model.layers.7.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 274 |
+
"mistral_model.model.layers.7.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 275 |
+
"mistral_model.model.layers.7.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 276 |
+
"mistral_model.model.layers.7.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 277 |
+
"mistral_model.model.layers.7.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 278 |
+
"mistral_model.model.layers.7.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 279 |
+
"mistral_model.model.layers.8.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 280 |
+
"mistral_model.model.layers.8.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 281 |
+
"mistral_model.model.layers.8.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 282 |
+
"mistral_model.model.layers.8.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 283 |
+
"mistral_model.model.layers.8.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 284 |
+
"mistral_model.model.layers.8.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 285 |
+
"mistral_model.model.layers.8.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 286 |
+
"mistral_model.model.layers.8.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 287 |
+
"mistral_model.model.layers.8.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 288 |
+
"mistral_model.model.layers.9.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 289 |
+
"mistral_model.model.layers.9.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 290 |
+
"mistral_model.model.layers.9.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 291 |
+
"mistral_model.model.layers.9.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 292 |
+
"mistral_model.model.layers.9.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 293 |
+
"mistral_model.model.layers.9.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 294 |
+
"mistral_model.model.layers.9.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 295 |
+
"mistral_model.model.layers.9.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 296 |
+
"mistral_model.model.layers.9.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 297 |
+
"mistral_model.model.norm.weight": "model-00004-of-00004.safetensors",
|
| 298 |
+
"mistral_proj.bias": "model-00004-of-00004.safetensors",
|
| 299 |
+
"mistral_proj.weight": "model-00004-of-00004.safetensors",
|
| 300 |
+
"qformer.bert.embeddings.LayerNorm.bias": "model-00001-of-00004.safetensors",
|
| 301 |
+
"qformer.bert.embeddings.LayerNorm.weight": "model-00001-of-00004.safetensors",
|
| 302 |
+
"qformer.bert.embeddings.position_embeddings.weight": "model-00001-of-00004.safetensors",
|
| 303 |
+
"qformer.bert.embeddings.position_ids": "model-00001-of-00004.safetensors",
|
| 304 |
+
"qformer.bert.embeddings.word_embeddings.weight": "model-00001-of-00004.safetensors",
|
| 305 |
+
"qformer.bert.encoder.layer.0.attention.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
|
| 306 |
+
"qformer.bert.encoder.layer.0.attention.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
|
| 307 |
+
"qformer.bert.encoder.layer.0.attention.output.dense.bias": "model-00001-of-00004.safetensors",
|
| 308 |
+
"qformer.bert.encoder.layer.0.attention.output.dense.weight": "model-00001-of-00004.safetensors",
|
| 309 |
+
"qformer.bert.encoder.layer.0.attention.self.key.bias": "model-00001-of-00004.safetensors",
|
| 310 |
+
"qformer.bert.encoder.layer.0.attention.self.key.weight": "model-00001-of-00004.safetensors",
|
| 311 |
+
"qformer.bert.encoder.layer.0.attention.self.query.bias": "model-00001-of-00004.safetensors",
|
| 312 |
+
"qformer.bert.encoder.layer.0.attention.self.query.weight": "model-00001-of-00004.safetensors",
|
| 313 |
+
"qformer.bert.encoder.layer.0.attention.self.value.bias": "model-00001-of-00004.safetensors",
|
| 314 |
+
"qformer.bert.encoder.layer.0.attention.self.value.weight": "model-00001-of-00004.safetensors",
|
| 315 |
+
"qformer.bert.encoder.layer.0.crossattention.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
|
| 316 |
+
"qformer.bert.encoder.layer.0.crossattention.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
|
| 317 |
+
"qformer.bert.encoder.layer.0.crossattention.output.dense.bias": "model-00001-of-00004.safetensors",
|
| 318 |
+
"qformer.bert.encoder.layer.0.crossattention.output.dense.weight": "model-00001-of-00004.safetensors",
|
| 319 |
+
"qformer.bert.encoder.layer.0.crossattention.self.key.bias": "model-00001-of-00004.safetensors",
|
| 320 |
+
"qformer.bert.encoder.layer.0.crossattention.self.key.weight": "model-00001-of-00004.safetensors",
|
| 321 |
+
"qformer.bert.encoder.layer.0.crossattention.self.query.bias": "model-00001-of-00004.safetensors",
|
| 322 |
+
"qformer.bert.encoder.layer.0.crossattention.self.query.weight": "model-00001-of-00004.safetensors",
|
| 323 |
+
"qformer.bert.encoder.layer.0.crossattention.self.value.bias": "model-00001-of-00004.safetensors",
|
| 324 |
+
"qformer.bert.encoder.layer.0.crossattention.self.value.weight": "model-00001-of-00004.safetensors",
|
| 325 |
+
"qformer.bert.encoder.layer.0.intermediate.dense.bias": "model-00001-of-00004.safetensors",
|
| 326 |
+
"qformer.bert.encoder.layer.0.intermediate.dense.weight": "model-00001-of-00004.safetensors",
|
| 327 |
+
"qformer.bert.encoder.layer.0.intermediate_query.dense.bias": "model-00001-of-00004.safetensors",
|
| 328 |
+
"qformer.bert.encoder.layer.0.intermediate_query.dense.weight": "model-00001-of-00004.safetensors",
|
| 329 |
+
"qformer.bert.encoder.layer.0.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
|
| 330 |
+
"qformer.bert.encoder.layer.0.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
|
| 331 |
+
"qformer.bert.encoder.layer.0.output.dense.bias": "model-00001-of-00004.safetensors",
|
| 332 |
+
"qformer.bert.encoder.layer.0.output.dense.weight": "model-00001-of-00004.safetensors",
|
| 333 |
+
"qformer.bert.encoder.layer.0.output_query.LayerNorm.bias": "model-00001-of-00004.safetensors",
|
| 334 |
+
"qformer.bert.encoder.layer.0.output_query.LayerNorm.weight": "model-00001-of-00004.safetensors",
|
| 335 |
+
"qformer.bert.encoder.layer.0.output_query.dense.bias": "model-00001-of-00004.safetensors",
|
| 336 |
+
"qformer.bert.encoder.layer.0.output_query.dense.weight": "model-00001-of-00004.safetensors",
|
| 337 |
+
"qformer.bert.encoder.layer.1.attention.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
|
| 338 |
+
"qformer.bert.encoder.layer.1.attention.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
|
| 339 |
+
"qformer.bert.encoder.layer.1.attention.output.dense.bias": "model-00001-of-00004.safetensors",
|
| 340 |
+
"qformer.bert.encoder.layer.1.attention.output.dense.weight": "model-00001-of-00004.safetensors",
|
| 341 |
+
"qformer.bert.encoder.layer.1.attention.self.key.bias": "model-00001-of-00004.safetensors",
|
| 342 |
+
"qformer.bert.encoder.layer.1.attention.self.key.weight": "model-00001-of-00004.safetensors",
|
| 343 |
+
"qformer.bert.encoder.layer.1.attention.self.query.bias": "model-00001-of-00004.safetensors",
|
| 344 |
+
"qformer.bert.encoder.layer.1.attention.self.query.weight": "model-00001-of-00004.safetensors",
|
| 345 |
+
"qformer.bert.encoder.layer.1.attention.self.value.bias": "model-00001-of-00004.safetensors",
|
| 346 |
+
"qformer.bert.encoder.layer.1.attention.self.value.weight": "model-00001-of-00004.safetensors",
|
| 347 |
+
"qformer.bert.encoder.layer.1.intermediate.dense.bias": "model-00001-of-00004.safetensors",
|
| 348 |
+
"qformer.bert.encoder.layer.1.intermediate.dense.weight": "model-00001-of-00004.safetensors",
|
| 349 |
+
"qformer.bert.encoder.layer.1.intermediate_query.dense.bias": "model-00001-of-00004.safetensors",
|
| 350 |
+
"qformer.bert.encoder.layer.1.intermediate_query.dense.weight": "model-00001-of-00004.safetensors",
|
| 351 |
+
"qformer.bert.encoder.layer.1.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
|
| 352 |
+
"qformer.bert.encoder.layer.1.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
|
| 353 |
+
"qformer.bert.encoder.layer.1.output.dense.bias": "model-00001-of-00004.safetensors",
|
| 354 |
+
"qformer.bert.encoder.layer.1.output.dense.weight": "model-00001-of-00004.safetensors",
|
| 355 |
+
"qformer.bert.encoder.layer.1.output_query.LayerNorm.bias": "model-00001-of-00004.safetensors",
|
| 356 |
+
"qformer.bert.encoder.layer.1.output_query.LayerNorm.weight": "model-00001-of-00004.safetensors",
|
| 357 |
+
"qformer.bert.encoder.layer.1.output_query.dense.bias": "model-00001-of-00004.safetensors",
|
| 358 |
+
"qformer.bert.encoder.layer.1.output_query.dense.weight": "model-00001-of-00004.safetensors",
|
| 359 |
+
"qformer.bert.encoder.layer.10.attention.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
|
| 360 |
+
"qformer.bert.encoder.layer.10.attention.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
|
| 361 |
+
"qformer.bert.encoder.layer.10.attention.output.dense.bias": "model-00001-of-00004.safetensors",
|
| 362 |
+
"qformer.bert.encoder.layer.10.attention.output.dense.weight": "model-00001-of-00004.safetensors",
|
| 363 |
+
"qformer.bert.encoder.layer.10.attention.self.key.bias": "model-00001-of-00004.safetensors",
|
| 364 |
+
"qformer.bert.encoder.layer.10.attention.self.key.weight": "model-00001-of-00004.safetensors",
|
| 365 |
+
"qformer.bert.encoder.layer.10.attention.self.query.bias": "model-00001-of-00004.safetensors",
|
| 366 |
+
"qformer.bert.encoder.layer.10.attention.self.query.weight": "model-00001-of-00004.safetensors",
|
| 367 |
+
"qformer.bert.encoder.layer.10.attention.self.value.bias": "model-00001-of-00004.safetensors",
|
| 368 |
+
"qformer.bert.encoder.layer.10.attention.self.value.weight": "model-00001-of-00004.safetensors",
|
| 369 |
+
"qformer.bert.encoder.layer.10.crossattention.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
|
| 370 |
+
"qformer.bert.encoder.layer.10.crossattention.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
|
| 371 |
+
"qformer.bert.encoder.layer.10.crossattention.output.dense.bias": "model-00001-of-00004.safetensors",
|
| 372 |
+
"qformer.bert.encoder.layer.10.crossattention.output.dense.weight": "model-00001-of-00004.safetensors",
|
| 373 |
+
"qformer.bert.encoder.layer.10.crossattention.self.key.bias": "model-00001-of-00004.safetensors",
|
| 374 |
+
"qformer.bert.encoder.layer.10.crossattention.self.key.weight": "model-00001-of-00004.safetensors",
|
| 375 |
+
"qformer.bert.encoder.layer.10.crossattention.self.query.bias": "model-00001-of-00004.safetensors",
|
| 376 |
+
"qformer.bert.encoder.layer.10.crossattention.self.query.weight": "model-00001-of-00004.safetensors",
|
| 377 |
+
"qformer.bert.encoder.layer.10.crossattention.self.value.bias": "model-00001-of-00004.safetensors",
|
| 378 |
+
"qformer.bert.encoder.layer.10.crossattention.self.value.weight": "model-00001-of-00004.safetensors",
|
| 379 |
+
"qformer.bert.encoder.layer.10.intermediate.dense.bias": "model-00001-of-00004.safetensors",
|
| 380 |
+
"qformer.bert.encoder.layer.10.intermediate.dense.weight": "model-00001-of-00004.safetensors",
|
| 381 |
+
"qformer.bert.encoder.layer.10.intermediate_query.dense.bias": "model-00001-of-00004.safetensors",
|
| 382 |
+
"qformer.bert.encoder.layer.10.intermediate_query.dense.weight": "model-00001-of-00004.safetensors",
|
| 383 |
+
"qformer.bert.encoder.layer.10.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
|
| 384 |
+
"qformer.bert.encoder.layer.10.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
|
| 385 |
+
"qformer.bert.encoder.layer.10.output.dense.bias": "model-00001-of-00004.safetensors",
|
| 386 |
+
"qformer.bert.encoder.layer.10.output.dense.weight": "model-00001-of-00004.safetensors",
|
| 387 |
+
"qformer.bert.encoder.layer.10.output_query.LayerNorm.bias": "model-00001-of-00004.safetensors",
|
| 388 |
+
"qformer.bert.encoder.layer.10.output_query.LayerNorm.weight": "model-00001-of-00004.safetensors",
|
| 389 |
+
"qformer.bert.encoder.layer.10.output_query.dense.bias": "model-00001-of-00004.safetensors",
|
| 390 |
+
"qformer.bert.encoder.layer.10.output_query.dense.weight": "model-00001-of-00004.safetensors",
|
| 391 |
+
"qformer.bert.encoder.layer.11.attention.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
|
| 392 |
+
"qformer.bert.encoder.layer.11.attention.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
|
| 393 |
+
"qformer.bert.encoder.layer.11.attention.output.dense.bias": "model-00001-of-00004.safetensors",
|
| 394 |
+
"qformer.bert.encoder.layer.11.attention.output.dense.weight": "model-00001-of-00004.safetensors",
|
| 395 |
+
"qformer.bert.encoder.layer.11.attention.self.key.bias": "model-00001-of-00004.safetensors",
|
| 396 |
+
"qformer.bert.encoder.layer.11.attention.self.key.weight": "model-00001-of-00004.safetensors",
|
| 397 |
+
"qformer.bert.encoder.layer.11.attention.self.query.bias": "model-00001-of-00004.safetensors",
|
| 398 |
+
"qformer.bert.encoder.layer.11.attention.self.query.weight": "model-00001-of-00004.safetensors",
|
| 399 |
+
"qformer.bert.encoder.layer.11.attention.self.value.bias": "model-00001-of-00004.safetensors",
|
| 400 |
+
"qformer.bert.encoder.layer.11.attention.self.value.weight": "model-00001-of-00004.safetensors",
|
| 401 |
+
"qformer.bert.encoder.layer.11.intermediate.dense.bias": "model-00001-of-00004.safetensors",
|
| 402 |
+
"qformer.bert.encoder.layer.11.intermediate.dense.weight": "model-00001-of-00004.safetensors",
|
| 403 |
+
"qformer.bert.encoder.layer.11.intermediate_query.dense.bias": "model-00001-of-00004.safetensors",
|
| 404 |
+
"qformer.bert.encoder.layer.11.intermediate_query.dense.weight": "model-00001-of-00004.safetensors",
|
| 405 |
+
"qformer.bert.encoder.layer.11.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
|
| 406 |
+
"qformer.bert.encoder.layer.11.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
|
| 407 |
+
"qformer.bert.encoder.layer.11.output.dense.bias": "model-00001-of-00004.safetensors",
|
| 408 |
+
"qformer.bert.encoder.layer.11.output.dense.weight": "model-00001-of-00004.safetensors",
|
| 409 |
+
"qformer.bert.encoder.layer.11.output_query.LayerNorm.bias": "model-00001-of-00004.safetensors",
|
| 410 |
+
"qformer.bert.encoder.layer.11.output_query.LayerNorm.weight": "model-00001-of-00004.safetensors",
|
| 411 |
+
"qformer.bert.encoder.layer.11.output_query.dense.bias": "model-00001-of-00004.safetensors",
|
| 412 |
+
"qformer.bert.encoder.layer.11.output_query.dense.weight": "model-00001-of-00004.safetensors",
|
| 413 |
+
"qformer.bert.encoder.layer.2.attention.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
|
| 414 |
+
"qformer.bert.encoder.layer.2.attention.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
|
| 415 |
+
"qformer.bert.encoder.layer.2.attention.output.dense.bias": "model-00001-of-00004.safetensors",
|
| 416 |
+
"qformer.bert.encoder.layer.2.attention.output.dense.weight": "model-00001-of-00004.safetensors",
|
| 417 |
+
"qformer.bert.encoder.layer.2.attention.self.key.bias": "model-00001-of-00004.safetensors",
|
| 418 |
+
"qformer.bert.encoder.layer.2.attention.self.key.weight": "model-00001-of-00004.safetensors",
|
| 419 |
+
"qformer.bert.encoder.layer.2.attention.self.query.bias": "model-00001-of-00004.safetensors",
|
| 420 |
+
"qformer.bert.encoder.layer.2.attention.self.query.weight": "model-00001-of-00004.safetensors",
|
| 421 |
+
"qformer.bert.encoder.layer.2.attention.self.value.bias": "model-00001-of-00004.safetensors",
|
| 422 |
+
"qformer.bert.encoder.layer.2.attention.self.value.weight": "model-00001-of-00004.safetensors",
|
| 423 |
+
"qformer.bert.encoder.layer.2.crossattention.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
|
| 424 |
+
"qformer.bert.encoder.layer.2.crossattention.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
|
| 425 |
+
"qformer.bert.encoder.layer.2.crossattention.output.dense.bias": "model-00001-of-00004.safetensors",
|
| 426 |
+
"qformer.bert.encoder.layer.2.crossattention.output.dense.weight": "model-00001-of-00004.safetensors",
|
| 427 |
+
"qformer.bert.encoder.layer.2.crossattention.self.key.bias": "model-00001-of-00004.safetensors",
|
| 428 |
+
"qformer.bert.encoder.layer.2.crossattention.self.key.weight": "model-00001-of-00004.safetensors",
|
| 429 |
+
"qformer.bert.encoder.layer.2.crossattention.self.query.bias": "model-00001-of-00004.safetensors",
|
| 430 |
+
"qformer.bert.encoder.layer.2.crossattention.self.query.weight": "model-00001-of-00004.safetensors",
|
| 431 |
+
"qformer.bert.encoder.layer.2.crossattention.self.value.bias": "model-00001-of-00004.safetensors",
|
| 432 |
+
"qformer.bert.encoder.layer.2.crossattention.self.value.weight": "model-00001-of-00004.safetensors",
|
| 433 |
+
"qformer.bert.encoder.layer.2.intermediate.dense.bias": "model-00001-of-00004.safetensors",
|
| 434 |
+
"qformer.bert.encoder.layer.2.intermediate.dense.weight": "model-00001-of-00004.safetensors",
|
| 435 |
+
"qformer.bert.encoder.layer.2.intermediate_query.dense.bias": "model-00001-of-00004.safetensors",
|
| 436 |
+
"qformer.bert.encoder.layer.2.intermediate_query.dense.weight": "model-00001-of-00004.safetensors",
|
| 437 |
+
"qformer.bert.encoder.layer.2.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
|
| 438 |
+
"qformer.bert.encoder.layer.2.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
|
| 439 |
+
"qformer.bert.encoder.layer.2.output.dense.bias": "model-00001-of-00004.safetensors",
|
| 440 |
+
"qformer.bert.encoder.layer.2.output.dense.weight": "model-00001-of-00004.safetensors",
|
| 441 |
+
"qformer.bert.encoder.layer.2.output_query.LayerNorm.bias": "model-00001-of-00004.safetensors",
|
| 442 |
+
"qformer.bert.encoder.layer.2.output_query.LayerNorm.weight": "model-00001-of-00004.safetensors",
|
| 443 |
+
"qformer.bert.encoder.layer.2.output_query.dense.bias": "model-00001-of-00004.safetensors",
|
| 444 |
+
"qformer.bert.encoder.layer.2.output_query.dense.weight": "model-00001-of-00004.safetensors",
|
| 445 |
+
"qformer.bert.encoder.layer.3.attention.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
|
| 446 |
+
"qformer.bert.encoder.layer.3.attention.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
|
| 447 |
+
"qformer.bert.encoder.layer.3.attention.output.dense.bias": "model-00001-of-00004.safetensors",
|
| 448 |
+
"qformer.bert.encoder.layer.3.attention.output.dense.weight": "model-00001-of-00004.safetensors",
|
| 449 |
+
"qformer.bert.encoder.layer.3.attention.self.key.bias": "model-00001-of-00004.safetensors",
|
| 450 |
+
"qformer.bert.encoder.layer.3.attention.self.key.weight": "model-00001-of-00004.safetensors",
|
| 451 |
+
"qformer.bert.encoder.layer.3.attention.self.query.bias": "model-00001-of-00004.safetensors",
|
| 452 |
+
"qformer.bert.encoder.layer.3.attention.self.query.weight": "model-00001-of-00004.safetensors",
|
| 453 |
+
"qformer.bert.encoder.layer.3.attention.self.value.bias": "model-00001-of-00004.safetensors",
|
| 454 |
+
"qformer.bert.encoder.layer.3.attention.self.value.weight": "model-00001-of-00004.safetensors",
|
| 455 |
+
"qformer.bert.encoder.layer.3.intermediate.dense.bias": "model-00001-of-00004.safetensors",
|
| 456 |
+
"qformer.bert.encoder.layer.3.intermediate.dense.weight": "model-00001-of-00004.safetensors",
|
| 457 |
+
"qformer.bert.encoder.layer.3.intermediate_query.dense.bias": "model-00001-of-00004.safetensors",
|
| 458 |
+
"qformer.bert.encoder.layer.3.intermediate_query.dense.weight": "model-00001-of-00004.safetensors",
|
| 459 |
+
"qformer.bert.encoder.layer.3.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
|
| 460 |
+
"qformer.bert.encoder.layer.3.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
|
| 461 |
+
"qformer.bert.encoder.layer.3.output.dense.bias": "model-00001-of-00004.safetensors",
|
| 462 |
+
"qformer.bert.encoder.layer.3.output.dense.weight": "model-00001-of-00004.safetensors",
|
| 463 |
+
"qformer.bert.encoder.layer.3.output_query.LayerNorm.bias": "model-00001-of-00004.safetensors",
|
| 464 |
+
"qformer.bert.encoder.layer.3.output_query.LayerNorm.weight": "model-00001-of-00004.safetensors",
|
| 465 |
+
"qformer.bert.encoder.layer.3.output_query.dense.bias": "model-00001-of-00004.safetensors",
|
| 466 |
+
"qformer.bert.encoder.layer.3.output_query.dense.weight": "model-00001-of-00004.safetensors",
|
| 467 |
+
"qformer.bert.encoder.layer.4.attention.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
|
| 468 |
+
"qformer.bert.encoder.layer.4.attention.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
|
| 469 |
+
"qformer.bert.encoder.layer.4.attention.output.dense.bias": "model-00001-of-00004.safetensors",
|
| 470 |
+
"qformer.bert.encoder.layer.4.attention.output.dense.weight": "model-00001-of-00004.safetensors",
|
| 471 |
+
"qformer.bert.encoder.layer.4.attention.self.key.bias": "model-00001-of-00004.safetensors",
|
| 472 |
+
"qformer.bert.encoder.layer.4.attention.self.key.weight": "model-00001-of-00004.safetensors",
|
| 473 |
+
"qformer.bert.encoder.layer.4.attention.self.query.bias": "model-00001-of-00004.safetensors",
|
| 474 |
+
"qformer.bert.encoder.layer.4.attention.self.query.weight": "model-00001-of-00004.safetensors",
|
| 475 |
+
"qformer.bert.encoder.layer.4.attention.self.value.bias": "model-00001-of-00004.safetensors",
|
| 476 |
+
"qformer.bert.encoder.layer.4.attention.self.value.weight": "model-00001-of-00004.safetensors",
|
| 477 |
+
"qformer.bert.encoder.layer.4.crossattention.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
|
| 478 |
+
"qformer.bert.encoder.layer.4.crossattention.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
|
| 479 |
+
"qformer.bert.encoder.layer.4.crossattention.output.dense.bias": "model-00001-of-00004.safetensors",
|
| 480 |
+
"qformer.bert.encoder.layer.4.crossattention.output.dense.weight": "model-00001-of-00004.safetensors",
|
| 481 |
+
"qformer.bert.encoder.layer.4.crossattention.self.key.bias": "model-00001-of-00004.safetensors",
|
| 482 |
+
"qformer.bert.encoder.layer.4.crossattention.self.key.weight": "model-00001-of-00004.safetensors",
|
| 483 |
+
"qformer.bert.encoder.layer.4.crossattention.self.query.bias": "model-00001-of-00004.safetensors",
|
| 484 |
+
"qformer.bert.encoder.layer.4.crossattention.self.query.weight": "model-00001-of-00004.safetensors",
|
| 485 |
+
"qformer.bert.encoder.layer.4.crossattention.self.value.bias": "model-00001-of-00004.safetensors",
|
| 486 |
+
"qformer.bert.encoder.layer.4.crossattention.self.value.weight": "model-00001-of-00004.safetensors",
|
| 487 |
+
"qformer.bert.encoder.layer.4.intermediate.dense.bias": "model-00001-of-00004.safetensors",
|
| 488 |
+
"qformer.bert.encoder.layer.4.intermediate.dense.weight": "model-00001-of-00004.safetensors",
|
| 489 |
+
"qformer.bert.encoder.layer.4.intermediate_query.dense.bias": "model-00001-of-00004.safetensors",
|
| 490 |
+
"qformer.bert.encoder.layer.4.intermediate_query.dense.weight": "model-00001-of-00004.safetensors",
|
| 491 |
+
"qformer.bert.encoder.layer.4.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
|
| 492 |
+
"qformer.bert.encoder.layer.4.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
|
| 493 |
+
"qformer.bert.encoder.layer.4.output.dense.bias": "model-00001-of-00004.safetensors",
|
| 494 |
+
"qformer.bert.encoder.layer.4.output.dense.weight": "model-00001-of-00004.safetensors",
|
| 495 |
+
"qformer.bert.encoder.layer.4.output_query.LayerNorm.bias": "model-00001-of-00004.safetensors",
|
| 496 |
+
"qformer.bert.encoder.layer.4.output_query.LayerNorm.weight": "model-00001-of-00004.safetensors",
|
| 497 |
+
"qformer.bert.encoder.layer.4.output_query.dense.bias": "model-00001-of-00004.safetensors",
|
| 498 |
+
"qformer.bert.encoder.layer.4.output_query.dense.weight": "model-00001-of-00004.safetensors",
|
| 499 |
+
"qformer.bert.encoder.layer.5.attention.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
|
| 500 |
+
"qformer.bert.encoder.layer.5.attention.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
|
| 501 |
+
"qformer.bert.encoder.layer.5.attention.output.dense.bias": "model-00001-of-00004.safetensors",
|
| 502 |
+
"qformer.bert.encoder.layer.5.attention.output.dense.weight": "model-00001-of-00004.safetensors",
|
| 503 |
+
"qformer.bert.encoder.layer.5.attention.self.key.bias": "model-00001-of-00004.safetensors",
|
| 504 |
+
"qformer.bert.encoder.layer.5.attention.self.key.weight": "model-00001-of-00004.safetensors",
|
| 505 |
+
"qformer.bert.encoder.layer.5.attention.self.query.bias": "model-00001-of-00004.safetensors",
|
| 506 |
+
"qformer.bert.encoder.layer.5.attention.self.query.weight": "model-00001-of-00004.safetensors",
|
| 507 |
+
"qformer.bert.encoder.layer.5.attention.self.value.bias": "model-00001-of-00004.safetensors",
|
| 508 |
+
"qformer.bert.encoder.layer.5.attention.self.value.weight": "model-00001-of-00004.safetensors",
|
| 509 |
+
"qformer.bert.encoder.layer.5.intermediate.dense.bias": "model-00001-of-00004.safetensors",
|
| 510 |
+
"qformer.bert.encoder.layer.5.intermediate.dense.weight": "model-00001-of-00004.safetensors",
|
| 511 |
+
"qformer.bert.encoder.layer.5.intermediate_query.dense.bias": "model-00001-of-00004.safetensors",
|
| 512 |
+
"qformer.bert.encoder.layer.5.intermediate_query.dense.weight": "model-00001-of-00004.safetensors",
|
| 513 |
+
"qformer.bert.encoder.layer.5.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
|
| 514 |
+
"qformer.bert.encoder.layer.5.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
|
| 515 |
+
"qformer.bert.encoder.layer.5.output.dense.bias": "model-00001-of-00004.safetensors",
|
| 516 |
+
"qformer.bert.encoder.layer.5.output.dense.weight": "model-00001-of-00004.safetensors",
|
| 517 |
+
"qformer.bert.encoder.layer.5.output_query.LayerNorm.bias": "model-00001-of-00004.safetensors",
|
| 518 |
+
"qformer.bert.encoder.layer.5.output_query.LayerNorm.weight": "model-00001-of-00004.safetensors",
|
| 519 |
+
"qformer.bert.encoder.layer.5.output_query.dense.bias": "model-00001-of-00004.safetensors",
|
| 520 |
+
"qformer.bert.encoder.layer.5.output_query.dense.weight": "model-00001-of-00004.safetensors",
|
| 521 |
+
"qformer.bert.encoder.layer.6.attention.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
|
| 522 |
+
"qformer.bert.encoder.layer.6.attention.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
|
| 523 |
+
"qformer.bert.encoder.layer.6.attention.output.dense.bias": "model-00001-of-00004.safetensors",
|
| 524 |
+
"qformer.bert.encoder.layer.6.attention.output.dense.weight": "model-00001-of-00004.safetensors",
|
| 525 |
+
"qformer.bert.encoder.layer.6.attention.self.key.bias": "model-00001-of-00004.safetensors",
|
| 526 |
+
"qformer.bert.encoder.layer.6.attention.self.key.weight": "model-00001-of-00004.safetensors",
|
| 527 |
+
"qformer.bert.encoder.layer.6.attention.self.query.bias": "model-00001-of-00004.safetensors",
|
| 528 |
+
"qformer.bert.encoder.layer.6.attention.self.query.weight": "model-00001-of-00004.safetensors",
|
| 529 |
+
"qformer.bert.encoder.layer.6.attention.self.value.bias": "model-00001-of-00004.safetensors",
|
| 530 |
+
"qformer.bert.encoder.layer.6.attention.self.value.weight": "model-00001-of-00004.safetensors",
|
| 531 |
+
"qformer.bert.encoder.layer.6.crossattention.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
|
| 532 |
+
"qformer.bert.encoder.layer.6.crossattention.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
|
| 533 |
+
"qformer.bert.encoder.layer.6.crossattention.output.dense.bias": "model-00001-of-00004.safetensors",
|
| 534 |
+
"qformer.bert.encoder.layer.6.crossattention.output.dense.weight": "model-00001-of-00004.safetensors",
|
| 535 |
+
"qformer.bert.encoder.layer.6.crossattention.self.key.bias": "model-00001-of-00004.safetensors",
|
| 536 |
+
"qformer.bert.encoder.layer.6.crossattention.self.key.weight": "model-00001-of-00004.safetensors",
|
| 537 |
+
"qformer.bert.encoder.layer.6.crossattention.self.query.bias": "model-00001-of-00004.safetensors",
|
| 538 |
+
"qformer.bert.encoder.layer.6.crossattention.self.query.weight": "model-00001-of-00004.safetensors",
|
| 539 |
+
"qformer.bert.encoder.layer.6.crossattention.self.value.bias": "model-00001-of-00004.safetensors",
|
| 540 |
+
"qformer.bert.encoder.layer.6.crossattention.self.value.weight": "model-00001-of-00004.safetensors",
|
| 541 |
+
"qformer.bert.encoder.layer.6.intermediate.dense.bias": "model-00001-of-00004.safetensors",
|
| 542 |
+
"qformer.bert.encoder.layer.6.intermediate.dense.weight": "model-00001-of-00004.safetensors",
|
| 543 |
+
"qformer.bert.encoder.layer.6.intermediate_query.dense.bias": "model-00001-of-00004.safetensors",
|
| 544 |
+
"qformer.bert.encoder.layer.6.intermediate_query.dense.weight": "model-00001-of-00004.safetensors",
|
| 545 |
+
"qformer.bert.encoder.layer.6.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
|
| 546 |
+
"qformer.bert.encoder.layer.6.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
|
| 547 |
+
"qformer.bert.encoder.layer.6.output.dense.bias": "model-00001-of-00004.safetensors",
|
| 548 |
+
"qformer.bert.encoder.layer.6.output.dense.weight": "model-00001-of-00004.safetensors",
|
| 549 |
+
"qformer.bert.encoder.layer.6.output_query.LayerNorm.bias": "model-00001-of-00004.safetensors",
|
| 550 |
+
"qformer.bert.encoder.layer.6.output_query.LayerNorm.weight": "model-00001-of-00004.safetensors",
|
| 551 |
+
"qformer.bert.encoder.layer.6.output_query.dense.bias": "model-00001-of-00004.safetensors",
|
| 552 |
+
"qformer.bert.encoder.layer.6.output_query.dense.weight": "model-00001-of-00004.safetensors",
|
| 553 |
+
"qformer.bert.encoder.layer.7.attention.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
|
| 554 |
+
"qformer.bert.encoder.layer.7.attention.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
|
| 555 |
+
"qformer.bert.encoder.layer.7.attention.output.dense.bias": "model-00001-of-00004.safetensors",
|
| 556 |
+
"qformer.bert.encoder.layer.7.attention.output.dense.weight": "model-00001-of-00004.safetensors",
|
| 557 |
+
"qformer.bert.encoder.layer.7.attention.self.key.bias": "model-00001-of-00004.safetensors",
|
| 558 |
+
"qformer.bert.encoder.layer.7.attention.self.key.weight": "model-00001-of-00004.safetensors",
|
| 559 |
+
"qformer.bert.encoder.layer.7.attention.self.query.bias": "model-00001-of-00004.safetensors",
|
| 560 |
+
"qformer.bert.encoder.layer.7.attention.self.query.weight": "model-00001-of-00004.safetensors",
|
| 561 |
+
"qformer.bert.encoder.layer.7.attention.self.value.bias": "model-00001-of-00004.safetensors",
|
| 562 |
+
"qformer.bert.encoder.layer.7.attention.self.value.weight": "model-00001-of-00004.safetensors",
|
| 563 |
+
"qformer.bert.encoder.layer.7.intermediate.dense.bias": "model-00001-of-00004.safetensors",
|
| 564 |
+
"qformer.bert.encoder.layer.7.intermediate.dense.weight": "model-00001-of-00004.safetensors",
|
| 565 |
+
"qformer.bert.encoder.layer.7.intermediate_query.dense.bias": "model-00001-of-00004.safetensors",
|
| 566 |
+
"qformer.bert.encoder.layer.7.intermediate_query.dense.weight": "model-00001-of-00004.safetensors",
|
| 567 |
+
"qformer.bert.encoder.layer.7.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
|
| 568 |
+
"qformer.bert.encoder.layer.7.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
|
| 569 |
+
"qformer.bert.encoder.layer.7.output.dense.bias": "model-00001-of-00004.safetensors",
|
| 570 |
+
"qformer.bert.encoder.layer.7.output.dense.weight": "model-00001-of-00004.safetensors",
|
| 571 |
+
"qformer.bert.encoder.layer.7.output_query.LayerNorm.bias": "model-00001-of-00004.safetensors",
|
| 572 |
+
"qformer.bert.encoder.layer.7.output_query.LayerNorm.weight": "model-00001-of-00004.safetensors",
|
| 573 |
+
"qformer.bert.encoder.layer.7.output_query.dense.bias": "model-00001-of-00004.safetensors",
|
| 574 |
+
"qformer.bert.encoder.layer.7.output_query.dense.weight": "model-00001-of-00004.safetensors",
|
| 575 |
+
"qformer.bert.encoder.layer.8.attention.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
|
| 576 |
+
"qformer.bert.encoder.layer.8.attention.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
|
| 577 |
+
"qformer.bert.encoder.layer.8.attention.output.dense.bias": "model-00001-of-00004.safetensors",
|
| 578 |
+
"qformer.bert.encoder.layer.8.attention.output.dense.weight": "model-00001-of-00004.safetensors",
|
| 579 |
+
"qformer.bert.encoder.layer.8.attention.self.key.bias": "model-00001-of-00004.safetensors",
|
| 580 |
+
"qformer.bert.encoder.layer.8.attention.self.key.weight": "model-00001-of-00004.safetensors",
|
| 581 |
+
"qformer.bert.encoder.layer.8.attention.self.query.bias": "model-00001-of-00004.safetensors",
|
| 582 |
+
"qformer.bert.encoder.layer.8.attention.self.query.weight": "model-00001-of-00004.safetensors",
|
| 583 |
+
"qformer.bert.encoder.layer.8.attention.self.value.bias": "model-00001-of-00004.safetensors",
|
| 584 |
+
"qformer.bert.encoder.layer.8.attention.self.value.weight": "model-00001-of-00004.safetensors",
|
| 585 |
+
"qformer.bert.encoder.layer.8.crossattention.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
|
| 586 |
+
"qformer.bert.encoder.layer.8.crossattention.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
|
| 587 |
+
"qformer.bert.encoder.layer.8.crossattention.output.dense.bias": "model-00001-of-00004.safetensors",
|
| 588 |
+
"qformer.bert.encoder.layer.8.crossattention.output.dense.weight": "model-00001-of-00004.safetensors",
|
| 589 |
+
"qformer.bert.encoder.layer.8.crossattention.self.key.bias": "model-00001-of-00004.safetensors",
|
| 590 |
+
"qformer.bert.encoder.layer.8.crossattention.self.key.weight": "model-00001-of-00004.safetensors",
|
| 591 |
+
"qformer.bert.encoder.layer.8.crossattention.self.query.bias": "model-00001-of-00004.safetensors",
|
| 592 |
+
"qformer.bert.encoder.layer.8.crossattention.self.query.weight": "model-00001-of-00004.safetensors",
|
| 593 |
+
"qformer.bert.encoder.layer.8.crossattention.self.value.bias": "model-00001-of-00004.safetensors",
|
| 594 |
+
"qformer.bert.encoder.layer.8.crossattention.self.value.weight": "model-00001-of-00004.safetensors",
|
| 595 |
+
"qformer.bert.encoder.layer.8.intermediate.dense.bias": "model-00001-of-00004.safetensors",
|
| 596 |
+
"qformer.bert.encoder.layer.8.intermediate.dense.weight": "model-00001-of-00004.safetensors",
|
| 597 |
+
"qformer.bert.encoder.layer.8.intermediate_query.dense.bias": "model-00001-of-00004.safetensors",
|
| 598 |
+
"qformer.bert.encoder.layer.8.intermediate_query.dense.weight": "model-00001-of-00004.safetensors",
|
| 599 |
+
"qformer.bert.encoder.layer.8.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
|
| 600 |
+
"qformer.bert.encoder.layer.8.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
|
| 601 |
+
"qformer.bert.encoder.layer.8.output.dense.bias": "model-00001-of-00004.safetensors",
|
| 602 |
+
"qformer.bert.encoder.layer.8.output.dense.weight": "model-00001-of-00004.safetensors",
|
| 603 |
+
"qformer.bert.encoder.layer.8.output_query.LayerNorm.bias": "model-00001-of-00004.safetensors",
|
| 604 |
+
"qformer.bert.encoder.layer.8.output_query.LayerNorm.weight": "model-00001-of-00004.safetensors",
|
| 605 |
+
"qformer.bert.encoder.layer.8.output_query.dense.bias": "model-00001-of-00004.safetensors",
|
| 606 |
+
"qformer.bert.encoder.layer.8.output_query.dense.weight": "model-00001-of-00004.safetensors",
|
| 607 |
+
"qformer.bert.encoder.layer.9.attention.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
|
| 608 |
+
"qformer.bert.encoder.layer.9.attention.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
|
| 609 |
+
"qformer.bert.encoder.layer.9.attention.output.dense.bias": "model-00001-of-00004.safetensors",
|
| 610 |
+
"qformer.bert.encoder.layer.9.attention.output.dense.weight": "model-00001-of-00004.safetensors",
|
| 611 |
+
"qformer.bert.encoder.layer.9.attention.self.key.bias": "model-00001-of-00004.safetensors",
|
| 612 |
+
"qformer.bert.encoder.layer.9.attention.self.key.weight": "model-00001-of-00004.safetensors",
|
| 613 |
+
"qformer.bert.encoder.layer.9.attention.self.query.bias": "model-00001-of-00004.safetensors",
|
| 614 |
+
"qformer.bert.encoder.layer.9.attention.self.query.weight": "model-00001-of-00004.safetensors",
|
| 615 |
+
"qformer.bert.encoder.layer.9.attention.self.value.bias": "model-00001-of-00004.safetensors",
|
| 616 |
+
"qformer.bert.encoder.layer.9.attention.self.value.weight": "model-00001-of-00004.safetensors",
|
| 617 |
+
"qformer.bert.encoder.layer.9.intermediate.dense.bias": "model-00001-of-00004.safetensors",
|
| 618 |
+
"qformer.bert.encoder.layer.9.intermediate.dense.weight": "model-00001-of-00004.safetensors",
|
| 619 |
+
"qformer.bert.encoder.layer.9.intermediate_query.dense.bias": "model-00001-of-00004.safetensors",
|
| 620 |
+
"qformer.bert.encoder.layer.9.intermediate_query.dense.weight": "model-00001-of-00004.safetensors",
|
| 621 |
+
"qformer.bert.encoder.layer.9.output.LayerNorm.bias": "model-00001-of-00004.safetensors",
|
| 622 |
+
"qformer.bert.encoder.layer.9.output.LayerNorm.weight": "model-00001-of-00004.safetensors",
|
| 623 |
+
"qformer.bert.encoder.layer.9.output.dense.bias": "model-00001-of-00004.safetensors",
|
| 624 |
+
"qformer.bert.encoder.layer.9.output.dense.weight": "model-00001-of-00004.safetensors",
|
| 625 |
+
"qformer.bert.encoder.layer.9.output_query.LayerNorm.bias": "model-00001-of-00004.safetensors",
|
| 626 |
+
"qformer.bert.encoder.layer.9.output_query.LayerNorm.weight": "model-00001-of-00004.safetensors",
|
| 627 |
+
"qformer.bert.encoder.layer.9.output_query.dense.bias": "model-00001-of-00004.safetensors",
|
| 628 |
+
"qformer.bert.encoder.layer.9.output_query.dense.weight": "model-00001-of-00004.safetensors",
|
| 629 |
+
"query_tokens": "model-00001-of-00004.safetensors",
|
| 630 |
+
"vision_encoder.encoder.blocks.0.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 631 |
+
"vision_encoder.encoder.blocks.0.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 632 |
+
"vision_encoder.encoder.blocks.0.attn.q_bias": "model-00001-of-00004.safetensors",
|
| 633 |
+
"vision_encoder.encoder.blocks.0.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 634 |
+
"vision_encoder.encoder.blocks.0.attn.v_bias": "model-00001-of-00004.safetensors",
|
| 635 |
+
"vision_encoder.encoder.blocks.0.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 636 |
+
"vision_encoder.encoder.blocks.0.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 637 |
+
"vision_encoder.encoder.blocks.0.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 638 |
+
"vision_encoder.encoder.blocks.0.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 639 |
+
"vision_encoder.encoder.blocks.0.norm1.bias": "model-00001-of-00004.safetensors",
|
| 640 |
+
"vision_encoder.encoder.blocks.0.norm1.weight": "model-00001-of-00004.safetensors",
|
| 641 |
+
"vision_encoder.encoder.blocks.0.norm2.bias": "model-00001-of-00004.safetensors",
|
| 642 |
+
"vision_encoder.encoder.blocks.0.norm2.weight": "model-00001-of-00004.safetensors",
|
| 643 |
+
"vision_encoder.encoder.blocks.1.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 644 |
+
"vision_encoder.encoder.blocks.1.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 645 |
+
"vision_encoder.encoder.blocks.1.attn.q_bias": "model-00001-of-00004.safetensors",
|
| 646 |
+
"vision_encoder.encoder.blocks.1.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 647 |
+
"vision_encoder.encoder.blocks.1.attn.v_bias": "model-00001-of-00004.safetensors",
|
| 648 |
+
"vision_encoder.encoder.blocks.1.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 649 |
+
"vision_encoder.encoder.blocks.1.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 650 |
+
"vision_encoder.encoder.blocks.1.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 651 |
+
"vision_encoder.encoder.blocks.1.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 652 |
+
"vision_encoder.encoder.blocks.1.norm1.bias": "model-00001-of-00004.safetensors",
|
| 653 |
+
"vision_encoder.encoder.blocks.1.norm1.weight": "model-00001-of-00004.safetensors",
|
| 654 |
+
"vision_encoder.encoder.blocks.1.norm2.bias": "model-00001-of-00004.safetensors",
|
| 655 |
+
"vision_encoder.encoder.blocks.1.norm2.weight": "model-00001-of-00004.safetensors",
|
| 656 |
+
"vision_encoder.encoder.blocks.10.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 657 |
+
"vision_encoder.encoder.blocks.10.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 658 |
+
"vision_encoder.encoder.blocks.10.attn.q_bias": "model-00001-of-00004.safetensors",
|
| 659 |
+
"vision_encoder.encoder.blocks.10.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 660 |
+
"vision_encoder.encoder.blocks.10.attn.v_bias": "model-00001-of-00004.safetensors",
|
| 661 |
+
"vision_encoder.encoder.blocks.10.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 662 |
+
"vision_encoder.encoder.blocks.10.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 663 |
+
"vision_encoder.encoder.blocks.10.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 664 |
+
"vision_encoder.encoder.blocks.10.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 665 |
+
"vision_encoder.encoder.blocks.10.norm1.bias": "model-00001-of-00004.safetensors",
|
| 666 |
+
"vision_encoder.encoder.blocks.10.norm1.weight": "model-00001-of-00004.safetensors",
|
| 667 |
+
"vision_encoder.encoder.blocks.10.norm2.bias": "model-00001-of-00004.safetensors",
|
| 668 |
+
"vision_encoder.encoder.blocks.10.norm2.weight": "model-00001-of-00004.safetensors",
|
| 669 |
+
"vision_encoder.encoder.blocks.11.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 670 |
+
"vision_encoder.encoder.blocks.11.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 671 |
+
"vision_encoder.encoder.blocks.11.attn.q_bias": "model-00001-of-00004.safetensors",
|
| 672 |
+
"vision_encoder.encoder.blocks.11.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 673 |
+
"vision_encoder.encoder.blocks.11.attn.v_bias": "model-00001-of-00004.safetensors",
|
| 674 |
+
"vision_encoder.encoder.blocks.11.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 675 |
+
"vision_encoder.encoder.blocks.11.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 676 |
+
"vision_encoder.encoder.blocks.11.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 677 |
+
"vision_encoder.encoder.blocks.11.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 678 |
+
"vision_encoder.encoder.blocks.11.norm1.bias": "model-00001-of-00004.safetensors",
|
| 679 |
+
"vision_encoder.encoder.blocks.11.norm1.weight": "model-00001-of-00004.safetensors",
|
| 680 |
+
"vision_encoder.encoder.blocks.11.norm2.bias": "model-00001-of-00004.safetensors",
|
| 681 |
+
"vision_encoder.encoder.blocks.11.norm2.weight": "model-00001-of-00004.safetensors",
|
| 682 |
+
"vision_encoder.encoder.blocks.12.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 683 |
+
"vision_encoder.encoder.blocks.12.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 684 |
+
"vision_encoder.encoder.blocks.12.attn.q_bias": "model-00001-of-00004.safetensors",
|
| 685 |
+
"vision_encoder.encoder.blocks.12.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 686 |
+
"vision_encoder.encoder.blocks.12.attn.v_bias": "model-00001-of-00004.safetensors",
|
| 687 |
+
"vision_encoder.encoder.blocks.12.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 688 |
+
"vision_encoder.encoder.blocks.12.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 689 |
+
"vision_encoder.encoder.blocks.12.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 690 |
+
"vision_encoder.encoder.blocks.12.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 691 |
+
"vision_encoder.encoder.blocks.12.norm1.bias": "model-00001-of-00004.safetensors",
|
| 692 |
+
"vision_encoder.encoder.blocks.12.norm1.weight": "model-00001-of-00004.safetensors",
|
| 693 |
+
"vision_encoder.encoder.blocks.12.norm2.bias": "model-00001-of-00004.safetensors",
|
| 694 |
+
"vision_encoder.encoder.blocks.12.norm2.weight": "model-00001-of-00004.safetensors",
|
| 695 |
+
"vision_encoder.encoder.blocks.13.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 696 |
+
"vision_encoder.encoder.blocks.13.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 697 |
+
"vision_encoder.encoder.blocks.13.attn.q_bias": "model-00001-of-00004.safetensors",
|
| 698 |
+
"vision_encoder.encoder.blocks.13.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 699 |
+
"vision_encoder.encoder.blocks.13.attn.v_bias": "model-00001-of-00004.safetensors",
|
| 700 |
+
"vision_encoder.encoder.blocks.13.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 701 |
+
"vision_encoder.encoder.blocks.13.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 702 |
+
"vision_encoder.encoder.blocks.13.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 703 |
+
"vision_encoder.encoder.blocks.13.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 704 |
+
"vision_encoder.encoder.blocks.13.norm1.bias": "model-00001-of-00004.safetensors",
|
| 705 |
+
"vision_encoder.encoder.blocks.13.norm1.weight": "model-00001-of-00004.safetensors",
|
| 706 |
+
"vision_encoder.encoder.blocks.13.norm2.bias": "model-00001-of-00004.safetensors",
|
| 707 |
+
"vision_encoder.encoder.blocks.13.norm2.weight": "model-00001-of-00004.safetensors",
|
| 708 |
+
"vision_encoder.encoder.blocks.14.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 709 |
+
"vision_encoder.encoder.blocks.14.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 710 |
+
"vision_encoder.encoder.blocks.14.attn.q_bias": "model-00001-of-00004.safetensors",
|
| 711 |
+
"vision_encoder.encoder.blocks.14.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 712 |
+
"vision_encoder.encoder.blocks.14.attn.v_bias": "model-00001-of-00004.safetensors",
|
| 713 |
+
"vision_encoder.encoder.blocks.14.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 714 |
+
"vision_encoder.encoder.blocks.14.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 715 |
+
"vision_encoder.encoder.blocks.14.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 716 |
+
"vision_encoder.encoder.blocks.14.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 717 |
+
"vision_encoder.encoder.blocks.14.norm1.bias": "model-00001-of-00004.safetensors",
|
| 718 |
+
"vision_encoder.encoder.blocks.14.norm1.weight": "model-00001-of-00004.safetensors",
|
| 719 |
+
"vision_encoder.encoder.blocks.14.norm2.bias": "model-00001-of-00004.safetensors",
|
| 720 |
+
"vision_encoder.encoder.blocks.14.norm2.weight": "model-00001-of-00004.safetensors",
|
| 721 |
+
"vision_encoder.encoder.blocks.15.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 722 |
+
"vision_encoder.encoder.blocks.15.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 723 |
+
"vision_encoder.encoder.blocks.15.attn.q_bias": "model-00001-of-00004.safetensors",
|
| 724 |
+
"vision_encoder.encoder.blocks.15.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 725 |
+
"vision_encoder.encoder.blocks.15.attn.v_bias": "model-00001-of-00004.safetensors",
|
| 726 |
+
"vision_encoder.encoder.blocks.15.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 727 |
+
"vision_encoder.encoder.blocks.15.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 728 |
+
"vision_encoder.encoder.blocks.15.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 729 |
+
"vision_encoder.encoder.blocks.15.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 730 |
+
"vision_encoder.encoder.blocks.15.norm1.bias": "model-00001-of-00004.safetensors",
|
| 731 |
+
"vision_encoder.encoder.blocks.15.norm1.weight": "model-00001-of-00004.safetensors",
|
| 732 |
+
"vision_encoder.encoder.blocks.15.norm2.bias": "model-00001-of-00004.safetensors",
|
| 733 |
+
"vision_encoder.encoder.blocks.15.norm2.weight": "model-00001-of-00004.safetensors",
|
| 734 |
+
"vision_encoder.encoder.blocks.16.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 735 |
+
"vision_encoder.encoder.blocks.16.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 736 |
+
"vision_encoder.encoder.blocks.16.attn.q_bias": "model-00001-of-00004.safetensors",
|
| 737 |
+
"vision_encoder.encoder.blocks.16.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 738 |
+
"vision_encoder.encoder.blocks.16.attn.v_bias": "model-00001-of-00004.safetensors",
|
| 739 |
+
"vision_encoder.encoder.blocks.16.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 740 |
+
"vision_encoder.encoder.blocks.16.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 741 |
+
"vision_encoder.encoder.blocks.16.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 742 |
+
"vision_encoder.encoder.blocks.16.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 743 |
+
"vision_encoder.encoder.blocks.16.norm1.bias": "model-00001-of-00004.safetensors",
|
| 744 |
+
"vision_encoder.encoder.blocks.16.norm1.weight": "model-00001-of-00004.safetensors",
|
| 745 |
+
"vision_encoder.encoder.blocks.16.norm2.bias": "model-00001-of-00004.safetensors",
|
| 746 |
+
"vision_encoder.encoder.blocks.16.norm2.weight": "model-00001-of-00004.safetensors",
|
| 747 |
+
"vision_encoder.encoder.blocks.17.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 748 |
+
"vision_encoder.encoder.blocks.17.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 749 |
+
"vision_encoder.encoder.blocks.17.attn.q_bias": "model-00001-of-00004.safetensors",
|
| 750 |
+
"vision_encoder.encoder.blocks.17.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 751 |
+
"vision_encoder.encoder.blocks.17.attn.v_bias": "model-00001-of-00004.safetensors",
|
| 752 |
+
"vision_encoder.encoder.blocks.17.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 753 |
+
"vision_encoder.encoder.blocks.17.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 754 |
+
"vision_encoder.encoder.blocks.17.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 755 |
+
"vision_encoder.encoder.blocks.17.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 756 |
+
"vision_encoder.encoder.blocks.17.norm1.bias": "model-00001-of-00004.safetensors",
|
| 757 |
+
"vision_encoder.encoder.blocks.17.norm1.weight": "model-00001-of-00004.safetensors",
|
| 758 |
+
"vision_encoder.encoder.blocks.17.norm2.bias": "model-00001-of-00004.safetensors",
|
| 759 |
+
"vision_encoder.encoder.blocks.17.norm2.weight": "model-00001-of-00004.safetensors",
|
| 760 |
+
"vision_encoder.encoder.blocks.18.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 761 |
+
"vision_encoder.encoder.blocks.18.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 762 |
+
"vision_encoder.encoder.blocks.18.attn.q_bias": "model-00001-of-00004.safetensors",
|
| 763 |
+
"vision_encoder.encoder.blocks.18.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 764 |
+
"vision_encoder.encoder.blocks.18.attn.v_bias": "model-00001-of-00004.safetensors",
|
| 765 |
+
"vision_encoder.encoder.blocks.18.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 766 |
+
"vision_encoder.encoder.blocks.18.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 767 |
+
"vision_encoder.encoder.blocks.18.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 768 |
+
"vision_encoder.encoder.blocks.18.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 769 |
+
"vision_encoder.encoder.blocks.18.norm1.bias": "model-00001-of-00004.safetensors",
|
| 770 |
+
"vision_encoder.encoder.blocks.18.norm1.weight": "model-00001-of-00004.safetensors",
|
| 771 |
+
"vision_encoder.encoder.blocks.18.norm2.bias": "model-00001-of-00004.safetensors",
|
| 772 |
+
"vision_encoder.encoder.blocks.18.norm2.weight": "model-00001-of-00004.safetensors",
|
| 773 |
+
"vision_encoder.encoder.blocks.19.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 774 |
+
"vision_encoder.encoder.blocks.19.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 775 |
+
"vision_encoder.encoder.blocks.19.attn.q_bias": "model-00001-of-00004.safetensors",
|
| 776 |
+
"vision_encoder.encoder.blocks.19.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 777 |
+
"vision_encoder.encoder.blocks.19.attn.v_bias": "model-00001-of-00004.safetensors",
|
| 778 |
+
"vision_encoder.encoder.blocks.19.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 779 |
+
"vision_encoder.encoder.blocks.19.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 780 |
+
"vision_encoder.encoder.blocks.19.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 781 |
+
"vision_encoder.encoder.blocks.19.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 782 |
+
"vision_encoder.encoder.blocks.19.norm1.bias": "model-00001-of-00004.safetensors",
|
| 783 |
+
"vision_encoder.encoder.blocks.19.norm1.weight": "model-00001-of-00004.safetensors",
|
| 784 |
+
"vision_encoder.encoder.blocks.19.norm2.bias": "model-00001-of-00004.safetensors",
|
| 785 |
+
"vision_encoder.encoder.blocks.19.norm2.weight": "model-00001-of-00004.safetensors",
|
| 786 |
+
"vision_encoder.encoder.blocks.2.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 787 |
+
"vision_encoder.encoder.blocks.2.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 788 |
+
"vision_encoder.encoder.blocks.2.attn.q_bias": "model-00001-of-00004.safetensors",
|
| 789 |
+
"vision_encoder.encoder.blocks.2.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 790 |
+
"vision_encoder.encoder.blocks.2.attn.v_bias": "model-00001-of-00004.safetensors",
|
| 791 |
+
"vision_encoder.encoder.blocks.2.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 792 |
+
"vision_encoder.encoder.blocks.2.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 793 |
+
"vision_encoder.encoder.blocks.2.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 794 |
+
"vision_encoder.encoder.blocks.2.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 795 |
+
"vision_encoder.encoder.blocks.2.norm1.bias": "model-00001-of-00004.safetensors",
|
| 796 |
+
"vision_encoder.encoder.blocks.2.norm1.weight": "model-00001-of-00004.safetensors",
|
| 797 |
+
"vision_encoder.encoder.blocks.2.norm2.bias": "model-00001-of-00004.safetensors",
|
| 798 |
+
"vision_encoder.encoder.blocks.2.norm2.weight": "model-00001-of-00004.safetensors",
|
| 799 |
+
"vision_encoder.encoder.blocks.20.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 800 |
+
"vision_encoder.encoder.blocks.20.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 801 |
+
"vision_encoder.encoder.blocks.20.attn.q_bias": "model-00001-of-00004.safetensors",
|
| 802 |
+
"vision_encoder.encoder.blocks.20.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 803 |
+
"vision_encoder.encoder.blocks.20.attn.v_bias": "model-00001-of-00004.safetensors",
|
| 804 |
+
"vision_encoder.encoder.blocks.20.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 805 |
+
"vision_encoder.encoder.blocks.20.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 806 |
+
"vision_encoder.encoder.blocks.20.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 807 |
+
"vision_encoder.encoder.blocks.20.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 808 |
+
"vision_encoder.encoder.blocks.20.norm1.bias": "model-00001-of-00004.safetensors",
|
| 809 |
+
"vision_encoder.encoder.blocks.20.norm1.weight": "model-00001-of-00004.safetensors",
|
| 810 |
+
"vision_encoder.encoder.blocks.20.norm2.bias": "model-00001-of-00004.safetensors",
|
| 811 |
+
"vision_encoder.encoder.blocks.20.norm2.weight": "model-00001-of-00004.safetensors",
|
| 812 |
+
"vision_encoder.encoder.blocks.21.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 813 |
+
"vision_encoder.encoder.blocks.21.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 814 |
+
"vision_encoder.encoder.blocks.21.attn.q_bias": "model-00001-of-00004.safetensors",
|
| 815 |
+
"vision_encoder.encoder.blocks.21.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 816 |
+
"vision_encoder.encoder.blocks.21.attn.v_bias": "model-00001-of-00004.safetensors",
|
| 817 |
+
"vision_encoder.encoder.blocks.21.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 818 |
+
"vision_encoder.encoder.blocks.21.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 819 |
+
"vision_encoder.encoder.blocks.21.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 820 |
+
"vision_encoder.encoder.blocks.21.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 821 |
+
"vision_encoder.encoder.blocks.21.norm1.bias": "model-00001-of-00004.safetensors",
|
| 822 |
+
"vision_encoder.encoder.blocks.21.norm1.weight": "model-00001-of-00004.safetensors",
|
| 823 |
+
"vision_encoder.encoder.blocks.21.norm2.bias": "model-00001-of-00004.safetensors",
|
| 824 |
+
"vision_encoder.encoder.blocks.21.norm2.weight": "model-00001-of-00004.safetensors",
|
| 825 |
+
"vision_encoder.encoder.blocks.22.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 826 |
+
"vision_encoder.encoder.blocks.22.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 827 |
+
"vision_encoder.encoder.blocks.22.attn.q_bias": "model-00001-of-00004.safetensors",
|
| 828 |
+
"vision_encoder.encoder.blocks.22.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 829 |
+
"vision_encoder.encoder.blocks.22.attn.v_bias": "model-00001-of-00004.safetensors",
|
| 830 |
+
"vision_encoder.encoder.blocks.22.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 831 |
+
"vision_encoder.encoder.blocks.22.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 832 |
+
"vision_encoder.encoder.blocks.22.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 833 |
+
"vision_encoder.encoder.blocks.22.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 834 |
+
"vision_encoder.encoder.blocks.22.norm1.bias": "model-00001-of-00004.safetensors",
|
| 835 |
+
"vision_encoder.encoder.blocks.22.norm1.weight": "model-00001-of-00004.safetensors",
|
| 836 |
+
"vision_encoder.encoder.blocks.22.norm2.bias": "model-00001-of-00004.safetensors",
|
| 837 |
+
"vision_encoder.encoder.blocks.22.norm2.weight": "model-00001-of-00004.safetensors",
|
| 838 |
+
"vision_encoder.encoder.blocks.3.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 839 |
+
"vision_encoder.encoder.blocks.3.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 840 |
+
"vision_encoder.encoder.blocks.3.attn.q_bias": "model-00001-of-00004.safetensors",
|
| 841 |
+
"vision_encoder.encoder.blocks.3.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 842 |
+
"vision_encoder.encoder.blocks.3.attn.v_bias": "model-00001-of-00004.safetensors",
|
| 843 |
+
"vision_encoder.encoder.blocks.3.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 844 |
+
"vision_encoder.encoder.blocks.3.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 845 |
+
"vision_encoder.encoder.blocks.3.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 846 |
+
"vision_encoder.encoder.blocks.3.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 847 |
+
"vision_encoder.encoder.blocks.3.norm1.bias": "model-00001-of-00004.safetensors",
|
| 848 |
+
"vision_encoder.encoder.blocks.3.norm1.weight": "model-00001-of-00004.safetensors",
|
| 849 |
+
"vision_encoder.encoder.blocks.3.norm2.bias": "model-00001-of-00004.safetensors",
|
| 850 |
+
"vision_encoder.encoder.blocks.3.norm2.weight": "model-00001-of-00004.safetensors",
|
| 851 |
+
"vision_encoder.encoder.blocks.4.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 852 |
+
"vision_encoder.encoder.blocks.4.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 853 |
+
"vision_encoder.encoder.blocks.4.attn.q_bias": "model-00001-of-00004.safetensors",
|
| 854 |
+
"vision_encoder.encoder.blocks.4.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 855 |
+
"vision_encoder.encoder.blocks.4.attn.v_bias": "model-00001-of-00004.safetensors",
|
| 856 |
+
"vision_encoder.encoder.blocks.4.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 857 |
+
"vision_encoder.encoder.blocks.4.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 858 |
+
"vision_encoder.encoder.blocks.4.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 859 |
+
"vision_encoder.encoder.blocks.4.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 860 |
+
"vision_encoder.encoder.blocks.4.norm1.bias": "model-00001-of-00004.safetensors",
|
| 861 |
+
"vision_encoder.encoder.blocks.4.norm1.weight": "model-00001-of-00004.safetensors",
|
| 862 |
+
"vision_encoder.encoder.blocks.4.norm2.bias": "model-00001-of-00004.safetensors",
|
| 863 |
+
"vision_encoder.encoder.blocks.4.norm2.weight": "model-00001-of-00004.safetensors",
|
| 864 |
+
"vision_encoder.encoder.blocks.5.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 865 |
+
"vision_encoder.encoder.blocks.5.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 866 |
+
"vision_encoder.encoder.blocks.5.attn.q_bias": "model-00001-of-00004.safetensors",
|
| 867 |
+
"vision_encoder.encoder.blocks.5.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 868 |
+
"vision_encoder.encoder.blocks.5.attn.v_bias": "model-00001-of-00004.safetensors",
|
| 869 |
+
"vision_encoder.encoder.blocks.5.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 870 |
+
"vision_encoder.encoder.blocks.5.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 871 |
+
"vision_encoder.encoder.blocks.5.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 872 |
+
"vision_encoder.encoder.blocks.5.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 873 |
+
"vision_encoder.encoder.blocks.5.norm1.bias": "model-00001-of-00004.safetensors",
|
| 874 |
+
"vision_encoder.encoder.blocks.5.norm1.weight": "model-00001-of-00004.safetensors",
|
| 875 |
+
"vision_encoder.encoder.blocks.5.norm2.bias": "model-00001-of-00004.safetensors",
|
| 876 |
+
"vision_encoder.encoder.blocks.5.norm2.weight": "model-00001-of-00004.safetensors",
|
| 877 |
+
"vision_encoder.encoder.blocks.6.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 878 |
+
"vision_encoder.encoder.blocks.6.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 879 |
+
"vision_encoder.encoder.blocks.6.attn.q_bias": "model-00001-of-00004.safetensors",
|
| 880 |
+
"vision_encoder.encoder.blocks.6.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 881 |
+
"vision_encoder.encoder.blocks.6.attn.v_bias": "model-00001-of-00004.safetensors",
|
| 882 |
+
"vision_encoder.encoder.blocks.6.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 883 |
+
"vision_encoder.encoder.blocks.6.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 884 |
+
"vision_encoder.encoder.blocks.6.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 885 |
+
"vision_encoder.encoder.blocks.6.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 886 |
+
"vision_encoder.encoder.blocks.6.norm1.bias": "model-00001-of-00004.safetensors",
|
| 887 |
+
"vision_encoder.encoder.blocks.6.norm1.weight": "model-00001-of-00004.safetensors",
|
| 888 |
+
"vision_encoder.encoder.blocks.6.norm2.bias": "model-00001-of-00004.safetensors",
|
| 889 |
+
"vision_encoder.encoder.blocks.6.norm2.weight": "model-00001-of-00004.safetensors",
|
| 890 |
+
"vision_encoder.encoder.blocks.7.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 891 |
+
"vision_encoder.encoder.blocks.7.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 892 |
+
"vision_encoder.encoder.blocks.7.attn.q_bias": "model-00001-of-00004.safetensors",
|
| 893 |
+
"vision_encoder.encoder.blocks.7.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 894 |
+
"vision_encoder.encoder.blocks.7.attn.v_bias": "model-00001-of-00004.safetensors",
|
| 895 |
+
"vision_encoder.encoder.blocks.7.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 896 |
+
"vision_encoder.encoder.blocks.7.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 897 |
+
"vision_encoder.encoder.blocks.7.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 898 |
+
"vision_encoder.encoder.blocks.7.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 899 |
+
"vision_encoder.encoder.blocks.7.norm1.bias": "model-00001-of-00004.safetensors",
|
| 900 |
+
"vision_encoder.encoder.blocks.7.norm1.weight": "model-00001-of-00004.safetensors",
|
| 901 |
+
"vision_encoder.encoder.blocks.7.norm2.bias": "model-00001-of-00004.safetensors",
|
| 902 |
+
"vision_encoder.encoder.blocks.7.norm2.weight": "model-00001-of-00004.safetensors",
|
| 903 |
+
"vision_encoder.encoder.blocks.8.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 904 |
+
"vision_encoder.encoder.blocks.8.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 905 |
+
"vision_encoder.encoder.blocks.8.attn.q_bias": "model-00001-of-00004.safetensors",
|
| 906 |
+
"vision_encoder.encoder.blocks.8.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 907 |
+
"vision_encoder.encoder.blocks.8.attn.v_bias": "model-00001-of-00004.safetensors",
|
| 908 |
+
"vision_encoder.encoder.blocks.8.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 909 |
+
"vision_encoder.encoder.blocks.8.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 910 |
+
"vision_encoder.encoder.blocks.8.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 911 |
+
"vision_encoder.encoder.blocks.8.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 912 |
+
"vision_encoder.encoder.blocks.8.norm1.bias": "model-00001-of-00004.safetensors",
|
| 913 |
+
"vision_encoder.encoder.blocks.8.norm1.weight": "model-00001-of-00004.safetensors",
|
| 914 |
+
"vision_encoder.encoder.blocks.8.norm2.bias": "model-00001-of-00004.safetensors",
|
| 915 |
+
"vision_encoder.encoder.blocks.8.norm2.weight": "model-00001-of-00004.safetensors",
|
| 916 |
+
"vision_encoder.encoder.blocks.9.attn.proj.bias": "model-00001-of-00004.safetensors",
|
| 917 |
+
"vision_encoder.encoder.blocks.9.attn.proj.weight": "model-00001-of-00004.safetensors",
|
| 918 |
+
"vision_encoder.encoder.blocks.9.attn.q_bias": "model-00001-of-00004.safetensors",
|
| 919 |
+
"vision_encoder.encoder.blocks.9.attn.qkv.weight": "model-00001-of-00004.safetensors",
|
| 920 |
+
"vision_encoder.encoder.blocks.9.attn.v_bias": "model-00001-of-00004.safetensors",
|
| 921 |
+
"vision_encoder.encoder.blocks.9.mlp.fc1.bias": "model-00001-of-00004.safetensors",
|
| 922 |
+
"vision_encoder.encoder.blocks.9.mlp.fc1.weight": "model-00001-of-00004.safetensors",
|
| 923 |
+
"vision_encoder.encoder.blocks.9.mlp.fc2.bias": "model-00001-of-00004.safetensors",
|
| 924 |
+
"vision_encoder.encoder.blocks.9.mlp.fc2.weight": "model-00001-of-00004.safetensors",
|
| 925 |
+
"vision_encoder.encoder.blocks.9.norm1.bias": "model-00001-of-00004.safetensors",
|
| 926 |
+
"vision_encoder.encoder.blocks.9.norm1.weight": "model-00001-of-00004.safetensors",
|
| 927 |
+
"vision_encoder.encoder.blocks.9.norm2.bias": "model-00001-of-00004.safetensors",
|
| 928 |
+
"vision_encoder.encoder.blocks.9.norm2.weight": "model-00001-of-00004.safetensors",
|
| 929 |
+
"vision_encoder.encoder.patch_embed.proj.bias": "model-00001-of-00004.safetensors",
|
| 930 |
+
"vision_encoder.encoder.patch_embed.proj.weight": "model-00001-of-00004.safetensors",
|
| 931 |
+
"vision_layernorm.bias": "model-00001-of-00004.safetensors",
|
| 932 |
+
"vision_layernorm.weight": "model-00001-of-00004.safetensors"
|
| 933 |
+
}
|
| 934 |
+
}
|
videochat2_it_hd_mistral.py
ADDED
|
@@ -0,0 +1,418 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import logging
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch.cuda.amp import autocast as autocast
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from peft import get_peft_model, LoraConfig, TaskType
|
| 9 |
+
|
| 10 |
+
from .blip2 import Blip2Base, disabled_train
|
| 11 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, PretrainedConfig
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
from easydict import EasyDict
|
| 16 |
+
from .configuration_videochat2 import Config
|
| 17 |
+
|
| 18 |
+
class VideoChat2_it_hd_mistral(Blip2Base):
    _auto_class='AutoModel'
    config_class=Config
    # NOTE: this string is not a real docstring (it is preceded by class
    # attributes, so Python does not bind it to __doc__); kept as-is.
    """
    VideoChat2 model.
    """
    def __init__(self, config):
        """Build the full VideoChat2 HD (Mistral) stack from `config`.

        Assembles: UMT vision encoder + layernorm, a Q-Former with optional
        extra query tokens, a Mistral causal LM (optionally LoRA-wrapped),
        and a linear projection from Q-Former hidden size to Mistral hidden
        size. Optionally loads ViT/Q-Former and VideoChat2 checkpoints and
        freezes the requested sub-modules.
        """
        super().__init__()
        # pretrained_path
        self.config=config
        # Normalize the incoming config: a HF PretrainedConfig may carry the
        # raw dict either under `.cfg` (custom wrapper) or as its own fields.
        if isinstance(config,(PretrainedConfig,AutoConfig)):
            if hasattr(config,'cfg'): # my own cfg
                config=EasyDict(config.cfg)
            else:
                config=EasyDict(config.to_dict())
        # `pc` is stashed and re-assigned to self.config at the end of
        # __init__ so the saved config is a proper PretrainedConfig.
        pc=PretrainedConfig()
        pc.update(config)
        vit_blip_model_path = config.get("vit_blip_model_path", None)
        mistral_model_path = config.get("mistral_model_path")
        videochat2_model_path = config.get("videochat2_model_path", "")
        freeze_vit = config.get("freeze_vit", True)
        freeze_qformer = config.get("freeze_qformer", True)
        freeze_llm = config.get("freeze_llm", True)
        # vit
        low_resource = config.get("low_resource", False) # use 8 bit and put vit in cpu
        # qformer
        num_query_token = config.get("num_query_token")
        qformer_hidden_dropout_prob = config.get("qformer_hidden_dropout_prob", 0.1)
        qformer_attention_probs_dropout_prob = config.get("qformer_attention_probs_dropout_prob", 0.1)
        qformer_drop_path_rate = config.get("qformer_drop_path_rate", 0.1)
        extra_num_query_token = config.get("extra_num_query_token", 32)
        self.qformer_text_input = config.get("qformer_text_input", False)

        # Infinite-Video related hyperparameters
        num_basis = config.get("num_basis", 256)
        sticky = config.get("sticky", True)
        tau = config.get("tau", 0.75)
        alpha = config.get("alpha", 0.75)

        # prompt — Mistral instruction-format delimiters and media wrappers.
        max_txt_len = config.get("max_txt_len", 32)
        self.human_start = "[INST]"
        self.human_end = "[/INST]"
        self.assist_end = "</s>"
        self.start_token = config.get("start_token", "<Video>")
        self.end_token = config.get("end_token", "</Video>")
        self.img_start_token = config.get("img_start_token", "<Image>")
        self.img_end_token = config.get("img_end_token", "</Image>")
        logger.info(f"Add instruction in qformer: {self.qformer_text_input}")
        # debug
        self.debug = config.get("debug", False)
        self.llm_bf16 = config.get("llm_bf16", False)
        use_flash_attention = config.get("use_flash_attention", False)
        self.use_lora = config.get("use_lora", False)
        lora_r = config.get("lora_r", 8)
        lora_alpha = config.get("lora_alpha", 32)
        lora_dropout = config.get("lora_dropout", 0.05)
        # dynamic resolution: input frames are tiled into local_size crops.
        self.local_size = config.dynamic_config.get("local_size", 224)
        self.add_global = config.dynamic_config.get("add_global", True)

        self.tokenizer = self.init_tokenizer(truncation_side="left")
        self.tokenizer.padding_side = "left"
        self.low_resource = low_resource
        self.vision_encoder, self.vision_layernorm = self.init_vision_encoder_umt(config)
        self.qformer, self.query_tokens = self.init_Qformer(
            num_query_token, config.vision_encoder.encoder_embed_dim,
            qformer_hidden_dropout_prob=qformer_hidden_dropout_prob,
            qformer_attention_probs_dropout_prob=qformer_attention_probs_dropout_prob,
            qformer_drop_path_rate=qformer_drop_path_rate,
            num_basis=num_basis, alpha=alpha, tau=tau, sticky=sticky,
        )

        if not self.qformer_text_input:
            # Text path unused: drop word/position embeddings and the
            # text-side FFN of every BERT layer to save memory.
            self.qformer.bert.embeddings.word_embeddings = None
            self.qformer.bert.embeddings.position_embeddings = None
            for layer in self.qformer.bert.encoder.layer:
                layer.output = None
                layer.intermediate = None
        else:
            self.qformer.resize_token_embeddings(len(self.tokenizer))
        self.qformer.cls = None

        if vit_blip_model_path:
            logger.info(f"Load ViT and QFormer from {vit_blip_model_path}")
            state_dict = torch.load(vit_blip_model_path, map_location="cpu")
            # strict=False: the checkpoint covers only the ViT/Q-Former subset.
            msg = self.load_state_dict(state_dict, strict=False)
            logger.info(msg)
            logger.info('Loading ViT and Q-Former Done')

        self.extra_num_query_token = extra_num_query_token
        if extra_num_query_token > 0:
            logger.info(f"Add extra {extra_num_query_token} tokens in QFormer")
            self.extra_query_tokens = nn.Parameter(
                torch.zeros(1, extra_num_query_token, self.query_tokens.shape[-1])
            )

        if freeze_vit:
            logger.info("freeze vision encoder")
            for _, param in self.vision_encoder.named_parameters():
                param.requires_grad = False
            self.vision_encoder = self.vision_encoder.eval()
            # disabled_train keeps the module in eval mode even if .train()
            # is called on the parent model.
            self.vision_encoder.train = disabled_train
            for _, param in self.vision_layernorm.named_parameters():
                param.requires_grad = False
            self.vision_layernorm = self.vision_layernorm.eval()
            self.vision_layernorm.train = disabled_train

        if freeze_qformer:
            logger.info("freeze Qformer")
            for _, param in self.qformer.named_parameters():
                param.requires_grad = False
            self.qformer = self.qformer.eval()
            self.qformer.train = disabled_train
            self.query_tokens.requires_grad = False

        logger.info('Loading Mistral')
        self.mistral_tokenizer = AutoTokenizer.from_pretrained(mistral_model_path)
        self.mistral_tokenizer.padding_side = "left"
        if not self.mistral_tokenizer.pad_token:
            logger.info("Set pad_token")
            self.mistral_tokenizer.pad_token = self.mistral_tokenizer.eos_token

        if self.debug:
            # Debug mode: build a tiny randomly initialized Mistral so the
            # pipeline can be exercised without the full checkpoint.
            logger.info("Debug mode, build small Mistral")
            mistral_config = AutoConfig.from_pretrained(mistral_model_path)
            mistral_config.hidden_size = 512
            mistral_config.intermediate_size = 2048
            mistral_config.num_attention_heads = 8
            mistral_config.num_hidden_layers = 12
            mistral_config.torch_dtype = torch.float16
            self.mistral_model = AutoModelForCausalLM.from_config(mistral_config)
        else:
            if use_flash_attention:
                self.mistral_model = AutoModelForCausalLM.from_pretrained(
                    mistral_model_path,
                    torch_dtype=torch.bfloat16 if self.llm_bf16 else torch.float16,
                    # use_flash_attention_2=True,
                    attn_implementation="flash_attention_2",
                )
            else:
                self.mistral_model = AutoModelForCausalLM.from_pretrained(
                    mistral_model_path,
                    torch_dtype=torch.bfloat16 if self.llm_bf16 else torch.float16,
                )

        if freeze_llm:
            logger.info("freeze Mistral")
            for _, param in self.mistral_model.named_parameters():
                param.requires_grad = False
        logger.info('Loading Mistral Done')

        if self.use_lora:
            logger.info("Use lora")
            peft_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM, inference_mode=False,
                r=lora_r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                                "gate_proj", "up_proj", "down_proj", "lm_head"]
            )
            self.mistral_model = get_peft_model(self.mistral_model, peft_config)
            if not freeze_llm:
                # Unfreeze the full base model underneath the LoRA wrapper.
                logger.info("Unfreeze Mistral")
                for _, param in self.mistral_model.base_model.named_parameters():
                    param.requires_grad = True
            self.mistral_model.print_trainable_parameters()

        # Project Q-Former outputs into the LLM embedding space.
        self.mistral_proj = nn.Linear(
            self.qformer.config.hidden_size, self.mistral_model.config.hidden_size
        )
        self.max_txt_len = max_txt_len

        # load weights of VideoChat2
        if videochat2_model_path:
            logger.info(f"Load VideoChat2 from: {videochat2_model_path}")
            ckpt = torch.load(videochat2_model_path, map_location="cpu")
            if 'model' in ckpt.keys():
                msg = self.load_state_dict(ckpt['model'], strict=False)
            else:
                msg = self.load_state_dict(ckpt, strict=False)
            logger.info(msg)
        self.config=pc

    def vit_to_cpu(self):
        """Move the vision tower to CPU in float32 (low-resource mode)."""
        self.vision_layernorm.to("cpu")
        self.vision_layernorm.float()
        self.vision_encoder.to("cpu")
        self.vision_encoder.float()

    def encode_img(self, image, instruction, new_video=False):
        """Encode a batch of images/videos into LLM-space feature sequences.

        Each input is tiled into local_size x local_size crops (plus an
        optionally appended global resized view), run through the vision
        encoder and Q-Former, then projected to the Mistral hidden size.

        Returns:
            output_imgs: list of [1, N_i, llm_dim] tensors, one per input.
            output_len: list of N_i values.
            use_image: True when inputs are single-frame (images).
        """
        device = image[0].device
        if self.low_resource:
            self.vit_to_cpu()
            image = [img.to("cpu") for img in image]

        with self.maybe_autocast():
            # split the image or video according to the shape
            shapes = []
            input_imgs = []
            input_instructions = []
            for idx, img in enumerate(image):
                # logger.info(f"Input shape: {img.shape}")
                # assumes H and W are exact multiples of self.local_size —
                # TODO confirm the preprocessing guarantees this.
                T, C, H, W = img.shape
                shapes.append([H//self.local_size, W//self.local_size])
                # Tile into (H/ls)*(W/ls) local crops, each [T, 3, ls, ls].
                sub_img = img.reshape(
                    1, T, 3, H//self.local_size, self.local_size, W//self.local_size, self.local_size
                ).permute(0, 3, 5, 1, 2, 4, 6).reshape(-1, T, 3, self.local_size, self.local_size).contiguous()
                input_imgs.append(sub_img)
                input_instructions.extend([instruction[idx]] * len(sub_img))
                if self.add_global:
                    # Append one globally resized view after the local crops.
                    glb_img = F.interpolate(
                        img.float(), size=(self.local_size, self.local_size), mode='bicubic', align_corners=False
                    ).to(sub_img.dtype)
                    input_imgs.append(glb_img.unsqueeze(0))
                    input_instructions.append(instruction[idx])
            input_imgs = torch.cat(input_imgs, dim=0)

            T = input_imgs.shape[1]
            use_image = True if T == 1 else False
            input_imgs = input_imgs.permute(0, 2, 1, 3, 4) # [B,T,C,H,W] -> [B,C,T,H,W]

            image_embeds = self.vision_encoder(input_imgs, use_image)
            B, T, L, C = image_embeds.shape
            image_embeds = image_embeds.reshape(B, -1, C)
            image_embeds = self.vision_layernorm(image_embeds).to(device) # [B, T*L, C]

            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)

            if self.extra_num_query_token > 0:
                query_tokens = torch.cat([self.query_tokens, self.extra_query_tokens], dim=1)
            else:
                query_tokens = self.query_tokens
            query_tokens = query_tokens.expand(image_embeds.shape[0], -1, -1)
            if self.qformer_text_input:
                # Condition the Q-Former on the instruction text as well.
                text_Qformer = self.tokenizer(
                    input_instructions,
                    padding='longest',
                    truncation=True,
                    max_length=self.max_txt_len,
                    return_tensors="pt",
                ).to(image_embeds.device)
                query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image_embeds.device)
                Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask], dim=1)

                query_output = self.qformer.bert(
                    text_Qformer.input_ids,
                    attention_mask=Qformer_atts,
                    query_embeds=query_tokens,
                    encoder_hidden_states=image_embeds,
                    encoder_attention_mask=image_atts,
                    return_dict=True,
                    new_video=new_video,
                )
            else:
                query_output = self.qformer.bert(
                    query_embeds=query_tokens,
                    encoder_hidden_states=image_embeds,
                    encoder_attention_mask=image_atts,
                    return_dict=True,
                    new_video=new_video
                )

            # Keep only the query positions (instruction tokens follow them).
            qformer_features = self.mistral_proj(query_output.last_hidden_state[:, :query_tokens.size(1), :])
            q_C = qformer_features.shape[-1]

            # merge the features from different split
            # stolen from https://huggingface.co/internlm/internlm-xcomposer2-4khd-7b/blob/main/build_mlp.py#L97-L115
            output_imgs = []
            output_len = []
            for [h, w] in shapes:
                B_ = h * w
                if self.add_global:
                    # h*w local crops plus the one global view.
                    output_imgs.append(qformer_features[:B_+1].view(1, -1, q_C))
                    qformer_features = qformer_features[B_+1:]
                else:
                    output_imgs.append(qformer_features[:B_].view(1, -1, q_C))
                    qformer_features = qformer_features[B_:]
                # logger.info(f"Features shape: {output_imgs[-1].shape}")
                output_len.append(output_imgs[-1].shape[1])

        return output_imgs, output_len, use_image

    def _get_text_len(self, text):
        """Number of Mistral tokens in `text` (no special tokens added)."""
        return self.mistral_tokenizer(text, return_tensors="pt", add_special_tokens=False).input_ids.shape[1]

    def forward(self, image, text_input, instruction):
        """Compute the LM loss for a batch of (media, prompt) pairs.

        A 1-D "image" entry signals a text-only sample. Visual features are
        spliced between the prompt's prefix and suffix, answer spans are kept
        as labels, and everything else is masked to -100.
        """
        if len(image[0].shape) == 1:
            use_text = True
            device = image[0].device
            batch_size = len(image)
            img_lens = [0] * batch_size
        else:
            use_text = False
            img_embeds, img_lens, use_image = self.encode_img(image, instruction)
            device = img_embeds[0].device
            batch_size = len(img_embeds)

        # mark the largest length
        # when padding, the attention mask will be 0
        max_len = 0
        input_embed_list = []
        p_before_len_list = []
        target_list = []
        # handle each prompt individually
        for idx, prompt in enumerate(text_input):
            if use_text:
                p_after = prompt
                p_after_tokens = self.mistral_tokenizer(p_after, return_tensors="pt", add_special_tokens=False).to(device)
                if self.use_lora:
                    # PEFT wraps the model: base_model.model is the HF model.
                    p_after_embeds = self.mistral_model.base_model.model.model.embed_tokens(p_after_tokens.input_ids)
                else:
                    p_after_embeds = self.mistral_model.model.embed_tokens(p_after_tokens.input_ids)
                input_embeds = p_after_embeds
            else:
                tmp_img_embeds = img_embeds[idx]
                # split the prompt via END_TOKEN
                end_token = self.img_end_token if use_image else self.end_token
                p_before, p_after = prompt.split(end_token)
                p_after = end_token + p_after
                p_before_tokens = self.mistral_tokenizer(p_before, return_tensors="pt", add_special_tokens=False).to(tmp_img_embeds.device)
                p_after_tokens = self.mistral_tokenizer(p_after, return_tensors="pt", add_special_tokens=False).to(tmp_img_embeds.device)
                if self.use_lora:
                    p_before_embeds = self.mistral_model.base_model.model.model.embed_tokens(p_before_tokens.input_ids)
                    p_after_embeds = self.mistral_model.base_model.model.model.embed_tokens(p_after_tokens.input_ids)
                else:
                    p_before_embeds = self.mistral_model.model.embed_tokens(p_before_tokens.input_ids)
                    p_after_embeds = self.mistral_model.model.embed_tokens(p_after_tokens.input_ids)
                input_embeds = torch.cat([p_before_embeds, tmp_img_embeds, p_after_embeds], dim=1)

            # extract the answers and mask the target
            # the answers are only in the p_after
            sep1 = self.human_start + " "
            sep2 = " " + self.human_end + " "
            raw_text = p_after.split(sep2)
            # NOTE(review): `idx` here shadows the outer enumerate index; the
            # outer `idx` is not read again before being reassigned, so this
            # is harmless, but a different name would be safer.
            for idx in range(0, len(raw_text) - 1):
                raw_text[idx] = raw_text[idx] + sep2
            # the first raw_text contains system and question
            # the last raw_text only contains answer
            # rstrip() for the extra " "
            answer_targets = p_after_tokens.input_ids.clone()
            # [target] "xxxxx. </s>"
            cur_len = self._get_text_len(raw_text[0].rstrip())
            answer_targets[:, :cur_len] = -100
            for text in raw_text[1:-1]:
                total_len = self._get_text_len(text.rstrip())
                ans_len = self._get_text_len((text.split(sep1)[0]).rstrip())
                # Mask the human-turn tokens that follow each answer.
                answer_targets[:, (cur_len+ans_len):(cur_len+total_len)] = -100
                cur_len += total_len
            cur_len += self._get_text_len(raw_text[-1].rstrip())

            if self.debug:  # Inspect and check the correctness of masking
                z = answer_targets[0].clone()
                z = torch.where(z == -100, self.mistral_tokenizer.unk_token_id, z)
                logger.info(self.mistral_tokenizer.decode(z))

            assert cur_len == answer_targets.shape[1], f"The final length ({cur_len}) is not equal to the original prompt ({answer_targets.shape[1]}): {prompt}"

            max_len = max(max_len, input_embeds.shape[1])
            input_embed_list.append(input_embeds)
            if use_text:
                p_before_len_list.append(0)
            else:
                p_before_len_list.append(p_before_tokens.input_ids.shape[1])
            target_list.append(answer_targets)

        # plus one for bos
        # max_txt_len plus num_query_token is the max len
        txt_len = min(max_len + 1, self.max_txt_len + max(img_lens))
        # Start from pad-token ids, then embed them to get padding embeddings.
        inputs_embeds = torch.ones([batch_size, txt_len], dtype=torch.long).to(device) * self.mistral_tokenizer.pad_token_id
        if self.use_lora:
            inputs_embeds = self.mistral_model.base_model.model.model.embed_tokens(inputs_embeds)
        else:
            inputs_embeds = self.mistral_model.model.embed_tokens(inputs_embeds)
        attention_mask = torch.zeros([batch_size, txt_len], dtype=torch.long).to(device)
        targets = torch.ones([batch_size, txt_len], dtype=torch.long).to(device).fill_(-100)
        # set bos_token
        # NOTE(review): this writes the scalar bos_token_id into every channel
        # of the first *embedding* slot, not the bos embedding vector; this
        # matches upstream VideoChat2, but verify it is intentional.
        inputs_embeds[:, :1] = self.mistral_tokenizer.bos_token_id

        for idx in range(batch_size):
            input_len = min(input_embed_list[idx].shape[1], txt_len - 1)
            # if less than txt_len, the input will be padding
            # if more than txt_len, the input will be truncated
            inputs_embeds[idx, 1:(input_len+1)] = input_embed_list[idx][:, :input_len]
            # the attention_mask is 0 when padding
            attention_mask[idx, :(input_len+1)] = 1
            # the target is -100 when padding
            p_before_len = p_before_len_list[idx]
            targets[idx, (p_before_len+img_lens[idx]+1):(input_len+1)] = target_list[idx][0, :(input_len-p_before_len-img_lens[idx])]

        with self.maybe_autocast():
            outputs = self.mistral_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                return_dict=True,
                labels=targets,
                use_cache=False,  # current flash_attn2 does not support padding=right for mistral
            )

        return dict(
            loss=outputs.loss,
        )
|
vit.py
ADDED
|
@@ -0,0 +1,472 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
import torch.utils.checkpoint as checkpoint
|
| 7 |
+
from functools import partial
|
| 8 |
+
|
| 9 |
+
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _cfg(url='', **kwargs):
|
| 15 |
+
return {
|
| 16 |
+
'url': url,
|
| 17 |
+
'num_classes': 400, 'input_size': (3, 224, 224), 'pool_size': None,
|
| 18 |
+
'crop_pct': .9, 'interpolation': 'bicubic',
|
| 19 |
+
'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
|
| 20 |
+
**kwargs
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample, applied in the main path
    of residual blocks. Delegates to timm's functional `drop_path`, which
    is active only while the module is in training mode."""

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return f'p={self.drop_prob}'
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class Mlp(nn.Module):
    """Transformer feed-forward block: fc1 -> activation -> dropout -> fc2 -> dropout.

    Hidden and output widths default to `in_features` when not given.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class Attention(nn.Module):
    """Multi-head self-attention with a fused qkv projection.

    When `qkv_bias` is set, q and v get learnable biases while k's bias is
    held at zero — hence the manual `F.linear` with a concatenated bias
    instead of a biased `nn.Linear`.
    """

    def __init__(
            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
            proj_drop=0., attn_head_dim=None):
        super().__init__()
        self.num_heads = num_heads
        # Per-head width defaults to dim // num_heads unless overridden.
        head_dim = attn_head_dim if attn_head_dim is not None else dim // num_heads
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        batch, seq_len, _ = x.shape
        fused_bias = None
        if self.q_bias is not None:
            # Assemble (q_bias, zero k bias, v_bias) for the fused projection.
            fused_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
        fused = F.linear(input=x, weight=self.qkv.weight, bias=fused_bias)
        fused = fused.reshape(batch, seq_len, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = fused[0], fused[1], fused[2]  # make torchscript happy (cannot use tensor as tuple)

        scores = (q * self.scale) @ k.transpose(-2, -1)
        scores = self.attn_drop(scores.softmax(dim=-1))

        out = (scores @ v).transpose(1, 2).reshape(batch, seq_len, -1)
        return self.proj_drop(self.proj(out))
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class Block(nn.Module):
    """Pre-norm transformer block: self-attention + MLP, each with a residual
    connection, optional LayerScale (gamma_1/gamma_2) and stochastic depth.

    Args:
        dim: embedding dimension.
        num_heads: attention head count.
        mlp_ratio: MLP hidden width as a multiple of `dim`.
        init_values: LayerScale init value; None or 0 disables LayerScale.
        drop_path: stochastic-depth rate; 0 disables DropPath.
    """
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 attn_head_dim=None):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        # FIX: the default init_values is None, and `None > 0` raises
        # TypeError on Python 3 — guard against None before comparing.
        if init_values is not None and init_values > 0:
            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
        else:
            self.gamma_1, self.gamma_2 = None, None

    def forward(self, x):
        if self.gamma_1 is None:
            x = x + self.drop_path(self.attn(self.norm1(x)))
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        else:
            # LayerScale: per-channel learnable scaling of each residual branch.
            x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        return x
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    Video variant: a Conv3d with kernel == stride cuts the clip into
    non-overlapping (tubelet_size x patch x patch) tubes and projects each
    to `embed_dim`.
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, num_frames=16, tubelet_size=2):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.tubelet_size = int(tubelet_size)
        spatial_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        num_patches = spatial_patches * (num_frames // self.tubelet_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        kernel = (self.tubelet_size, patch_size[0], patch_size[1])
        self.proj = nn.Conv3d(
            in_channels=in_chans, out_channels=embed_dim,
            kernel_size=kernel,
            stride=kernel
        )
        logger.info(f'Num of patches: {num_patches}')

    def forward(self, x, **kwargs):
        # x: (B, C, T, H, W) -> token sequence (B, N, embed_dim).
        # FIXME look at relaxing size constraints
        # assert H == self.img_size[0] and W == self.img_size[1], \
        #     f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        embedded = self.proj(x)
        return embedded.flatten(2).transpose(1, 2)
|
| 160 |
+
|
| 161 |
+
# sin-cos position encoding
|
| 162 |
+
# https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Models.py#L31
|
| 163 |
+
def get_sinusoid_encoding_table(n_position, d_hid, ckpt_num_frame=-1, cur_frame=12):
    ''' Sinusoid position encoding table.

    Returns a (1, n_position, d_hid) float tensor. When the checkpoint was
    trained with a different frame count (ckpt_num_frame != cur_frame), the
    checkpoint-sized table is built first and then linearly interpolated
    along the temporal axis.
    '''
    # TODO: make it with torch instead of numpy
    def get_position_angle_vec(position):
        return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]

    def build_table(num_positions):
        # (num_positions, d_hid) table: sin on even dims, cos on odd dims.
        table = np.array([get_position_angle_vec(pos_i) for pos_i in range(num_positions)])
        table[:, 0::2] = np.sin(table[:, 0::2])  # dim 2i
        table[:, 1::2] = np.cos(table[:, 1::2])  # dim 2i+1
        return torch.tensor(table, dtype=torch.float, requires_grad=False).unsqueeze(0)

    if ckpt_num_frame == -1 or ckpt_num_frame == cur_frame:
        return build_table(n_position)

    logger.info(f"Interpolate position embedding")
    logger.info(f"Testing frame: {cur_frame}")
    logger.info(f"Checkpoint frame: {ckpt_num_frame}")

    T = ckpt_num_frame  # checkpoint frame
    new_T = cur_frame  # testing frame
    n_position = n_position // new_T * T  # generate checkpoint position embedding
    sinusoid_table = build_table(n_position)
    # interpolate temporally: reshape to (BHW, C, T), resize T -> new_T.
    P = int((n_position // T) ** 0.5)
    C = d_hid
    sinusoid_table = sinusoid_table.reshape(-1, T, P, P, C)
    sinusoid_table = sinusoid_table.permute(0, 2, 3, 4, 1).reshape(-1, C, T)  # BHW, C, T
    sinusoid_table = torch.nn.functional.interpolate(sinusoid_table, size=new_T, mode='linear')
    sinusoid_table = sinusoid_table.reshape(1, P, P, C, new_T).permute(0, 4, 1, 2, 3)  # B, T, H, W, C
    return sinusoid_table.flatten(1, 3)
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def get_sinusoid_encoding_table2(n_position=784, d_hid=1024, cur_frame=8, ckpt_num_frame=4, pre_n_position=784):
    ''' Sinusoid position encoding table '''
    # TODO: make it with torch instead of numpy
    def get_position_angle_vec(position):
        return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]

    # Build the sine-cosine table at the checkpoint resolution first
    # (pre_n_position rows of d_hid dims), then interpolate as needed.
    angle_table = np.array([get_position_angle_vec(pos) for pos in range(pre_n_position)])
    angle_table[:, 0::2] = np.sin(angle_table[:, 0::2])  # even dims -> sin
    angle_table[:, 1::2] = np.cos(angle_table[:, 1::2])  # odd dims -> cos
    pos_table = torch.tensor(angle_table, dtype=torch.float, requires_grad=False).unsqueeze(0)

    print(f"n_position: {n_position}")
    print(f"pre_n_position: {pre_n_position}")

    # Spatial interpolation: checkpoint was trained on a 14x14 patch grid;
    # resize (bicubic) to the grid implied by n_position / cur_frame.
    # NOTE(review): assumes pre_n_position == ckpt_num_frame * 14 * 14 and
    # n_position == cur_frame * new_P**2 — the reshapes below fail otherwise.
    if n_position != pre_n_position:
        T = ckpt_num_frame  # checkpoint frame
        P = 14              # checkpoint size
        C = d_hid
        new_P = int((n_position // cur_frame) ** 0.5)  # testing size
        print(f'Pretraining uses 14x14, but current version is {new_P}x{new_P}')
        print(f'Interpolate the position embedding')
        pos_table = pos_table.reshape(-1, T, P, P, C)
        pos_table = pos_table.reshape(-1, P, P, C).permute(0, 3, 1, 2)  # BT, C, H, W
        pos_table = torch.nn.functional.interpolate(
            pos_table, size=(new_P, new_P), mode='bicubic', align_corners=False)
        # BT, C, H, W -> BT, H, W, C -> B, T, H, W, C
        pos_table = pos_table.permute(0, 2, 3, 1).reshape(-1, T, new_P, new_P, C)
        pos_table = pos_table.flatten(1, 3)  # B, THW, C

    # Temporal interpolation: stretch the checkpoint's frame axis (linear)
    # from ckpt_num_frame to cur_frame.
    if cur_frame != ckpt_num_frame:
        print(f'Pretraining uses 4 frames, but current frame is {cur_frame}')
        print(f'Interpolate the position embedding')
        T = ckpt_num_frame  # checkpoint frame
        new_T = cur_frame   # testing frame
        # interpolate
        P = int((n_position // cur_frame) ** 0.5)  # testing size
        C = d_hid
        pos_table = pos_table.reshape(-1, T, P, P, C)
        pos_table = pos_table.permute(0, 2, 3, 4, 1).reshape(-1, C, T)  # BHW, C, T
        pos_table = torch.nn.functional.interpolate(pos_table, size=new_T, mode='linear')
        pos_table = pos_table.reshape(1, P, P, C, new_T).permute(0, 4, 1, 2, 3)  # B, T, H, W, C
        pos_table = pos_table.flatten(1, 3)  # B, THW, C

    return pos_table
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
class PretrainVisionTransformerEncoder(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage.

    Video encoder: patch-embeds a clip, adds (fixed sine-cosine or learnable)
    positional embeddings, then runs a truncated stack of transformer blocks.
    Only the first `depth + return_index + 1` blocks are built, so features
    are effectively taken from an intermediate layer when return_index < -1.
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, num_frames=8, tubelet_size=1,
                 use_learnable_pos_emb=False,
                 use_checkpoint=False, checkpoint_num=0,
                 ckpt_num_frame=-1, with_ln=True, return_index=-1
                 ):
        super().__init__()
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            num_frames=num_frames, tubelet_size=tubelet_size
        )
        num_patches = self.patch_embed.num_patches
        # Number of blocks actually instantiated: return_index=-1 keeps all
        # `depth` layers, -2 drops the last one, etc.
        self.depth = depth + return_index + 1
        self.use_checkpoint = use_checkpoint
        self.checkpoint_num = checkpoint_num
        logger.info(f"Use checkpoint: {use_checkpoint}")
        logger.info(f"Checkpoint number: {checkpoint_num}")
        logger.info(f"Real runing depth: {self.depth}")

        # TODO: Add the cls token
        if use_learnable_pos_emb:
            # Learnable embeddings: video table plus a separate single-frame
            # table (img_pos_embed) used when the input is a still image.
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
            self.img_pos_embed = nn.Parameter(torch.zeros(1, num_patches//(num_frames//tubelet_size) + 1, embed_dim))
        else:
            # sine-cosine positional embeddings
            # For non-224 inputs the table must also be spatially interpolated
            # from the pretrained 14x14 grid (handled by *_table2).
            if img_size != 224:
                self.pos_embed = get_sinusoid_encoding_table2(num_patches, embed_dim, ckpt_num_frame=ckpt_num_frame, cur_frame=num_frames//tubelet_size)
                self.img_pos_embed = get_sinusoid_encoding_table2(num_patches//(num_frames//tubelet_size), embed_dim, cur_frame=1, ckpt_num_frame=1, pre_n_position=14*14)
            else:
                self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim, ckpt_num_frame=ckpt_num_frame, cur_frame=num_frames//tubelet_size)
                self.img_pos_embed = get_sinusoid_encoding_table(num_patches//(num_frames//tubelet_size), embed_dim)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                init_values=init_values)
            for i in range(self.depth)])

        # Final norm is optional; nn.Identity keeps forward() uniform.
        if with_ln:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = nn.Identity()

        if use_learnable_pos_emb:
            trunc_normal_(self.pos_embed, std=.02)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Parameter names excluded from weight decay by the optimizer setup."""
        return {'pos_embed', 'cls_token'}

    def forward_features(self, x, use_image=False):
        """Embed patches, add positional encoding, and run the block stack.

        use_image selects the single-frame positional table instead of the
        video one. Returns the (optionally normalized) token sequence.
        """
        x = self.patch_embed(x)

        # Positional tables may be plain buffers (sine-cosine case); detach so
        # they never receive gradients.
        if use_image:
            x = x + self.img_pos_embed.type_as(x).to(x.device).clone().detach()
        else:
            x = x + self.pos_embed.type_as(x).to(x.device).clone().detach()

        B, _, C = x.shape
        x_vis = x

        # Gradient-checkpoint only the first `checkpoint_num` blocks.
        for idx, blk in enumerate(self.blocks):
            if self.use_checkpoint and idx < self.checkpoint_num:
                x_vis = checkpoint.checkpoint(blk, x_vis)
            else:
                x_vis = blk(x_vis)

        # Final norm, or identity when with_ln=False.
        x_vis = self.norm(x_vis)
        return x_vis

    def forward(self, x, use_image=False):
        x_vis = self.forward_features(x, use_image)
        return x_vis
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
class PretrainVisionTransformer(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage.

    Thin wrapper around PretrainVisionTransformerEncoder that reshapes the
    flat token sequence back into per-frame tokens: [B, T*L, C] -> [B, T, L, C].
    """
    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 encoder_in_chans=3,
                 encoder_embed_dim=768,
                 encoder_depth=12,
                 encoder_num_heads=12,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 norm_layer=partial(nn.LayerNorm, eps=1e-6),
                 init_values=0.,
                 use_learnable_pos_emb=False,
                 num_frames=8,
                 tubelet_size=1,
                 use_checkpoint=False,
                 checkpoint_num=0,
                 ckpt_num_frame=4, # the pretrained model uses 4 frames
                 return_index=-1,
                 with_ln=False
                 ):
        super().__init__()

        self.encoder = PretrainVisionTransformerEncoder(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=encoder_in_chans,
            embed_dim=encoder_embed_dim,
            depth=encoder_depth,
            num_heads=encoder_num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            norm_layer=norm_layer,
            init_values=init_values,
            num_frames=num_frames,
            tubelet_size=tubelet_size,
            use_learnable_pos_emb=use_learnable_pos_emb,
            use_checkpoint=use_checkpoint,
            checkpoint_num=checkpoint_num,
            ckpt_num_frame=ckpt_num_frame,
            with_ln=with_ln,
            return_index=return_index
        )
        logger.info(f'With LN: {with_ln}')
        logger.info(f'Total {encoder_depth} layer')
        logger.info(f'Return {encoder_depth+return_index+1}-th layer')

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier init for Linear layers; unit/zero init for LayerNorm."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Parameter names excluded from weight decay by the optimizer setup."""
        return {'pos_embed', 'cls_token', 'clip_pos_embed'}

    def forward(self, x, use_image=False):
        """Encode a clip of shape [B, C, T, H, W] into [B, T, L, C] tokens.

        Assumes the encoder's token count TL is divisible by the input frame
        count T (tubelet_size=1 case).
        """
        T = x.shape[2]
        x_vis = self.encoder(x, use_image)  # [B, N_vis, C_e]
        B, TL, C = x_vis.shape
        x_vis = x_vis.view(B, T, TL // T, C)

        return x_vis
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
def build_vit(config):
    """Build a PretrainVisionTransformer from `config.vision_encoder`.

    Optionally loads pretrained weights (non-strict) from the path given in
    `config.vision_encoder.pretrained`. Returns the constructed model.
    """
    vcfg = config.vision_encoder
    # NOTE(review): ckpt_num_frame from the config is not forwarded here;
    # the model falls back to its default — confirm this is intended.
    model = PretrainVisionTransformer(
        img_size=vcfg.img_size,
        patch_size=vcfg.patch_size,
        encoder_embed_dim=vcfg.encoder_embed_dim,
        encoder_depth=vcfg.encoder_depth,
        encoder_num_heads=vcfg.encoder_num_heads,
        drop_path_rate=vcfg.drop_path_rate,
        num_frames=vcfg.num_frames,
        tubelet_size=vcfg.tubelet_size,
        use_checkpoint=vcfg.use_checkpoint,
        checkpoint_num=vcfg.checkpoint_num,
        return_index=vcfg.get('return_index', -1),
        with_ln=vcfg.get('with_ln', False),
    )
    model.default_cfg = _cfg()
    if vcfg.pretrained:
        logger.info(f"Loading pretrained weights from {config.vision_encoder.pretrained}")
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # trusted checkpoints (consider weights_only=True where available).
        state_dict = torch.load(vcfg.pretrained, map_location='cpu')
        model.load_state_dict(state_dict, strict=False)
    else:
        logger.info("No pretrained weights!!!")
    return model
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
if __name__ == '__main__':
    # Smoke test: build the ViT from a hand-written config and run one
    # random clip through it, printing the output token shape.
    import time
    from fvcore.nn import FlopCountAnalysis
    from fvcore.nn import flop_count_table
    import numpy as np

    # Fix all RNG seeds for a reproducible forward pass.
    seed = 4217
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    num_frames = 4

    config = {
        'vision_encoder':
            {
                'img_size': 224,
                'patch_size': 16,
                'encoder_embed_dim': 768,
                'encoder_depth': 12,
                'encoder_num_heads': 12,
                'drop_path_rate': 0.1,
                'num_frames': num_frames,
                'tubelet_size': 1,
                'use_checkpoint': False,
                'checkpoint_num': 0,
                # Placeholder checkpoint path — replace before running.
                'pretrained': 'your_model_path/l16_25m.pth',
                'ckpt_num_frame': 8,
                'return_index': -1,
                'with_ln': False,
            }
    }
    from easydict import EasyDict
    model = build_vit(EasyDict(config))

    # Optional FLOP count (fvcore); left disabled.
    # flops = FlopCountAnalysis(model, torch.rand(1, 3, num_frames, 224, 224))
    # s = time.time()
    # print(flop_count_table(flops, max_depth=1))
    # print(time.time()-s)
    print(model(torch.rand(1, 3, num_frames, 224, 224), False).shape)
|