Update main.py
Browse files
main.py
CHANGED
|
@@ -16,28 +16,53 @@ import torch
|
|
| 16 |
import numpy as np
|
| 17 |
import faiss
|
| 18 |
|
| 19 |
-
# ββ Compatibility
|
| 20 |
-
#
|
| 21 |
-
#
|
| 22 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
import transformers.pytorch_utils as _pt_utils
|
| 24 |
if not hasattr(_pt_utils, "find_pruneable_heads_and_indices"):
|
| 25 |
-
from typing import Set, List, Tuple
|
| 26 |
def _find_pruneable_heads_and_indices(
|
| 27 |
-
heads:
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
already_pruned_heads: Set[int],
|
| 31 |
-
) -> Tuple[Set[int], torch.LongTensor]:
|
| 32 |
mask = torch.ones(n_heads, head_size)
|
| 33 |
heads = set(heads) - already_pruned_heads
|
| 34 |
for head in heads:
|
| 35 |
head = head - sum(1 if h < head else 0 for h in already_pruned_heads)
|
| 36 |
mask[head] = 0
|
| 37 |
mask = mask.view(-1).contiguous().eq(1)
|
| 38 |
-
index
|
| 39 |
return heads, index
|
| 40 |
_pt_utils.find_pruneable_heads_and_indices = _find_pruneable_heads_and_indices
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 42 |
|
| 43 |
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
|
|
|
|
| 16 |
import numpy as np
|
| 17 |
import faiss
|
| 18 |
|
| 19 |
+
# ── Compatibility patches ──────────────────────────────────────────────────────
|
| 20 |
+
# The jina-bert-v2 custom modeling code was written against an older transformers
|
| 21 |
+
# API. Two things were removed / tightened in newer releases:
|
| 22 |
+
#
|
| 23 |
+
# 1. `find_pruneable_heads_and_indices` was removed from transformers.pytorch_utils.
|
| 24 |
+
# 2. `PretrainedConfig` no longer sets is_decoder / add_cross_attention as instance
|
| 25 |
+
# defaults in __init__. A tightened __getattribute__ now raises AttributeError
|
| 26 |
+
# instead of the old silent fallback, breaking JinaBertConfig access patterns.
|
| 27 |
+
#
|
| 28 |
+
# Both patches are guarded with hasattr/flag checks so they are no-ops if a future
|
| 29 |
+
# transformers version re-adds these symbols.
|
| 30 |
+
|
| 31 |
import transformers.pytorch_utils as _pt_utils

if not hasattr(_pt_utils, "find_pruneable_heads_and_indices"):
    # Shim for a helper that newer transformers releases removed from
    # pytorch_utils; the jina-bert-v2 custom modeling code still imports it.
    # Guarded with hasattr so this is a no-op if the symbol ever returns.
    def _find_pruneable_heads_and_indices(
        heads, n_heads: int, head_size: int, already_pruned_heads
    ):
        """Return (heads still to prune, flat index of weight rows to keep)."""
        import torch

        # Ignore heads that were pruned in an earlier round.
        to_prune = set(heads) - already_pruned_heads

        # Keep-mask over the (n_heads, head_size) weight layout; zero out the
        # rows belonging to each head scheduled for pruning.
        mask = torch.ones(n_heads, head_size)
        for head in to_prune:
            # Earlier-pruned heads with a smaller index shift this head left.
            offset = sum(h < head for h in already_pruned_heads)
            mask[head - offset] = 0

        keep = mask.view(-1).contiguous().eq(1)
        index = torch.arange(len(keep))[keep].long()
        return to_prune, index

    _pt_utils.find_pruneable_heads_and_indices = _find_pruneable_heads_and_indices
|
| 46 |
+
|
| 47 |
+
import transformers.configuration_utils as _cfg_utils

_PC = _cfg_utils.PretrainedConfig
if not hasattr(_PC, "_jina_compat_patched"):
    # Attributes that used to be set in PretrainedConfig.__init__ with defaults
    # but were removed from the base class in newer transformers versions.
    _LEGACY_DEFAULTS = {
        "is_decoder": False,
        "add_cross_attention": False,
        "cross_attention_hidden_size": None,
        "use_cache": True,
    }

    # BUGFIX: do not clobber PretrainedConfig's own __getattr__ —
    # transformers uses it to resolve `attribute_map` aliases (e.g. a config
    # exposing `num_hidden_layers` backed by `n_layer`). Chain to the original
    # first and fall back to the legacy defaults only when it raises.
    _orig_getattr = getattr(_PC, "__getattr__", None)

    def _pc_getattr(self, key: str):
        """Fallback lookup: original __getattr__ first, then legacy defaults.

        Raises AttributeError (matching normal attribute semantics) when the
        key is neither resolvable upstream nor a legacy default.
        """
        if _orig_getattr is not None:
            try:
                return _orig_getattr(self, key)
            except AttributeError:
                pass  # not an aliased attribute — try the legacy defaults
        if key in _LEGACY_DEFAULTS:
            return _LEGACY_DEFAULTS[key]
        raise AttributeError(
            f"'{type(self).__name__}' object has no attribute '{key}'"
        )

    _PC.__getattr__ = _pc_getattr
    # Flag so re-importing this module never stacks the patch twice.
    _PC._jina_compat_patched = True
|
| 66 |
# ──────────────────────────────────────────────────────────────────────────────
|
| 67 |
|
| 68 |
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
|