Update main.py
Browse files
main.py
CHANGED
|
@@ -17,23 +17,20 @@ import numpy as np
|
|
| 17 |
import faiss
|
| 18 |
|
| 19 |
# ββ Compatibility patches ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 20 |
-
#
|
| 21 |
-
#
|
|
|
|
| 22 |
#
|
| 23 |
-
# 1.
|
| 24 |
-
# 2.
|
| 25 |
-
#
|
| 26 |
-
# instead of the old silent fallback, breaking JinaBertConfig access patterns.
|
| 27 |
-
#
|
| 28 |
-
# Both patches are guarded with hasattr/flag checks so they are no-ops if a future
|
| 29 |
-
# transformers version re-adds these symbols.
|
| 30 |
|
|
|
|
| 31 |
import transformers.pytorch_utils as _pt_utils
|
| 32 |
if not hasattr(_pt_utils, "find_pruneable_heads_and_indices"):
|
| 33 |
def _find_pruneable_heads_and_indices(
|
| 34 |
heads, n_heads: int, head_size: int, already_pruned_heads
|
| 35 |
):
|
| 36 |
-
import torch
|
| 37 |
mask = torch.ones(n_heads, head_size)
|
| 38 |
heads = set(heads) - already_pruned_heads
|
| 39 |
for head in heads:
|
|
@@ -44,25 +41,51 @@ if not hasattr(_pt_utils, "find_pruneable_heads_and_indices"):
|
|
| 44 |
return heads, index
|
| 45 |
_pt_utils.find_pruneable_heads_and_indices = _find_pruneable_heads_and_indices
|
| 46 |
|
|
|
|
| 47 |
import transformers.configuration_utils as _cfg_utils
|
| 48 |
_PC = _cfg_utils.PretrainedConfig
|
| 49 |
if not hasattr(_PC, "_jina_compat_patched"):
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
"is_decoder": False,
|
| 54 |
-
"add_cross_attention": False,
|
| 55 |
"cross_attention_hidden_size": None,
|
| 56 |
-
"use_cache":
|
| 57 |
}
|
| 58 |
def _pc_getattr(self, key: str):
|
| 59 |
-
if key in
|
| 60 |
-
return
|
| 61 |
raise AttributeError(
|
| 62 |
f"'{type(self).__name__}' object has no attribute '{key}'"
|
| 63 |
)
|
| 64 |
_PC.__getattr__ = _pc_getattr
|
| 65 |
_PC._jina_compat_patched = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 67 |
|
| 68 |
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
|
|
|
|
| 17 |
import faiss
|
| 18 |
|
| 19 |
# ── Compatibility patches ─────────────────────────────────────────────────────
# jina-bert-v2 (trust_remote_code) was written against transformers 4.x.
# Transformers 5.x removed / broke three things the model relies on.
# All patches are no-ops when the symbol already exists.
#
# 1. find_pruneable_heads_and_indices — removed from pytorch_utils
# 2. PretrainedConfig.is_decoder etc — no longer set as instance defaults
# 3. PreTrainedModel.get_head_mask — removed from modeling_utils
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
+
# ββ patch 1: pytorch_utils ββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 29 |
import transformers.pytorch_utils as _pt_utils
|
| 30 |
if not hasattr(_pt_utils, "find_pruneable_heads_and_indices"):
|
| 31 |
def _find_pruneable_heads_and_indices(
|
| 32 |
heads, n_heads: int, head_size: int, already_pruned_heads
|
| 33 |
):
|
|
|
|
| 34 |
mask = torch.ones(n_heads, head_size)
|
| 35 |
heads = set(heads) - already_pruned_heads
|
| 36 |
for head in heads:
|
|
|
|
| 41 |
return heads, index
|
| 42 |
_pt_utils.find_pruneable_heads_and_indices = _find_pruneable_heads_and_indices
|
| 43 |
|
| 44 |
# ── patch 2: PretrainedConfig legacy defaults ─────────────────────────────────
# Transformers 5.x no longer sets these attributes as instance defaults on
# every config, so the remote jina-bert-v2 code that reads e.g.
# `config.is_decoder` now hits AttributeError. Restore the 4.x defaults via
# a class-level __getattr__ fallback (only consulted for *missing* attributes,
# so explicitly-set config values still win).
import transformers.configuration_utils as _cfg_utils

_PC = _cfg_utils.PretrainedConfig
if not hasattr(_PC, "_jina_compat_patched"):
    # Defaults transformers 4.x used to set on every config instance.
    _LEGACY_CFG_DEFAULTS = {
        "is_decoder": False,
        "add_cross_attention": False,
        "cross_attention_hidden_size": None,
        "use_cache": True,
    }

    def _pc_getattr(self, key: str):
        """Fall back to the 4.x instance defaults for attributes 5.x dropped.

        Raises AttributeError (with the standard message shape) for anything
        else, preserving normal missing-attribute behavior.
        """
        if key in _LEGACY_CFG_DEFAULTS:
            return _LEGACY_CFG_DEFAULTS[key]
        raise AttributeError(
            f"'{type(self).__name__}' object has no attribute '{key}'"
        )

    _PC.__getattr__ = _pc_getattr
    # Flag on the class makes the patch idempotent across repeated imports.
    _PC._jina_compat_patched = True
| 62 |
# ── patch 3: PreTrainedModel.get_head_mask ────────────────────────────────────
# get_head_mask was removed from transformers.modeling_utils; restore the 4.x
# implementation so jina-bert-v2's remote modeling code can still call it.
import transformers.modeling_utils as _mod_utils

_PTM = _mod_utils.PreTrainedModel
if not hasattr(_PTM, "get_head_mask"):

    def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):
        """Broadcast a 1-D or 2-D head mask to 5 dims.

        1-D ([num_heads]) is expanded across all layers; 2-D
        ([num_layers, num_heads]) gets broadcast dims added per layer.
        """
        if head_mask.dim() == 1:
            head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
        elif head_mask.dim() == 2:
            # Per-layer mask: add batch and sequence broadcast dims.
            head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
        assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}"
        # Match the model's dtype (e.g. fp16) — mirrors the upstream 4.x code.
        head_mask = head_mask.to(dtype=self.dtype)
        return head_mask

    def _get_head_mask(self, head_mask, num_hidden_layers, is_attention_chunked=False):
        """4.x-compatible helper: return a 5-D mask tensor, or a list of
        ``None`` (one per layer) when no mask was given — exactly what the
        remote attention code iterates over.
        """
        if head_mask is not None:
            head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
            if is_attention_chunked:
                head_mask = head_mask.unsqueeze(-1)
        else:
            head_mask = [None] * num_hidden_layers
        return head_mask

    # _convert_head_mask_to_5d may survive independently in some versions;
    # only install it when genuinely missing.
    if not hasattr(_PTM, "_convert_head_mask_to_5d"):
        _PTM._convert_head_mask_to_5d = _convert_head_mask_to_5d
    _PTM.get_head_mask = _get_head_mask
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 90 |
|
| 91 |
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
|