kamp0010 committed on
Commit
baaced2
·
verified ·
1 Parent(s): bd680a9

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +41 -18
main.py CHANGED
@@ -17,23 +17,20 @@ import numpy as np
17
  import faiss
18
 
19
  # ── Compatibility patches ──────────────────────────────────────────────────────
20
- # The jina-bert-v2 custom modeling code was written against an older transformers
21
- # API. Two things were removed / tightened in newer releases:
 
22
  #
23
- # 1. `find_pruneable_heads_and_indices` was removed from transformers.pytorch_utils.
24
- # 2. `PretrainedConfig` no longer sets is_decoder / add_cross_attention as instance
25
- # defaults in __init__. A tightened __getattribute__ now raises AttributeError
26
- # instead of the old silent fallback, breaking JinaBertConfig access patterns.
27
- #
28
- # Both patches are guarded with hasattr/flag checks so they are no-ops if a future
29
- # transformers version re-adds these symbols.
30
 
 
31
  import transformers.pytorch_utils as _pt_utils
32
  if not hasattr(_pt_utils, "find_pruneable_heads_and_indices"):
33
  def _find_pruneable_heads_and_indices(
34
  heads, n_heads: int, head_size: int, already_pruned_heads
35
  ):
36
- import torch
37
  mask = torch.ones(n_heads, head_size)
38
  heads = set(heads) - already_pruned_heads
39
  for head in heads:
@@ -44,25 +41,51 @@ if not hasattr(_pt_utils, "find_pruneable_heads_and_indices"):
44
  return heads, index
45
  _pt_utils.find_pruneable_heads_and_indices = _find_pruneable_heads_and_indices
46
 
 
47
  import transformers.configuration_utils as _cfg_utils
48
  _PC = _cfg_utils.PretrainedConfig
49
  if not hasattr(_PC, "_jina_compat_patched"):
50
- # Attributes that used to be set in PretrainedConfig.__init__ with defaults
51
- # but were removed from the base class in newer transformers versions.
52
- _LEGACY_DEFAULTS = {
53
- "is_decoder": False,
54
- "add_cross_attention": False,
55
  "cross_attention_hidden_size": None,
56
- "use_cache": True,
57
  }
58
  def _pc_getattr(self, key: str):
59
- if key in _LEGACY_DEFAULTS:
60
- return _LEGACY_DEFAULTS[key]
61
  raise AttributeError(
62
  f"'{type(self).__name__}' object has no attribute '{key}'"
63
  )
64
  _PC.__getattr__ = _pc_getattr
65
  _PC._jina_compat_patched = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  # ──────────────────────────────────────────────────────────────────────────────
67
 
68
  from fastapi import FastAPI, HTTPException, UploadFile, File, Form
 
17
  import faiss
18
 
19
  # ── Compatibility patches ──────────────────────────────────────────────────────
20
+ # jina-bert-v2 (trust_remote_code) was written against transformers 4.x.
21
+ # Transformers 5.x removed / broke three things the model relies on.
22
+ # All patches are no-ops when the symbol already exists.
23
  #
24
+ # 1. find_pruneable_heads_and_indices — removed from pytorch_utils
25
+ # 2. PretrainedConfig.is_decoder etc — no longer set as instance defaults
26
+ # 3. PreTrainedModel.get_head_mask — removed from modeling_utils in transformers 5.x
 
 
 
 
27
 
28
+ # ── patch 1: pytorch_utils ────────────────────────────────────────────────────
29
  import transformers.pytorch_utils as _pt_utils
30
  if not hasattr(_pt_utils, "find_pruneable_heads_and_indices"):
31
  def _find_pruneable_heads_and_indices(
32
  heads, n_heads: int, head_size: int, already_pruned_heads
33
  ):
 
34
  mask = torch.ones(n_heads, head_size)
35
  heads = set(heads) - already_pruned_heads
36
  for head in heads:
 
41
  return heads, index
42
  _pt_utils.find_pruneable_heads_and_indices = _find_pruneable_heads_and_indices
43
 
44
# ── patch 2: PretrainedConfig legacy defaults ─────────────────────────────────
# Older transformers set these attributes as instance defaults in
# PretrainedConfig.__init__; newer releases removed them, so the jina-bert-v2
# remote code now hits AttributeError.  Restore the old defaults through
# __getattr__, which Python consults only AFTER normal attribute lookup fails,
# so explicitly-set config values always take precedence.
import transformers.configuration_utils as _cfg_utils

_PC = _cfg_utils.PretrainedConfig
if not hasattr(_PC, "_jina_compat_patched"):
    # Instance defaults that were dropped from the base class.
    _LEGACY_CFG_DEFAULTS = {
        "is_decoder": False,
        "add_cross_attention": False,
        "cross_attention_hidden_size": None,
        "use_cache": True,
    }

    # BUGFIX: the previous patch overwrote __getattr__ unconditionally, which
    # clobbers any __getattr__ the installed transformers version defines
    # (e.g. attribute_map delegation on PretrainedConfig).  Keep a reference
    # and chain to it so that behavior is preserved.
    _orig_pc_getattr = getattr(_PC, "__getattr__", None)

    def _pc_getattr(self, key: str):
        """Serve legacy defaults, then defer to the original __getattr__ (if any)."""
        if key in _LEGACY_CFG_DEFAULTS:
            return _LEGACY_CFG_DEFAULTS[key]
        if _orig_pc_getattr is not None:
            # Let the library's own fallback (attribute_map etc.) run and
            # raise its own AttributeError when appropriate.
            return _orig_pc_getattr(self, key)
        raise AttributeError(
            f"'{type(self).__name__}' object has no attribute '{key}'"
        )

    _PC.__getattr__ = _pc_getattr
    # Flag so re-importing this module (or a future transformers that ships
    # the defaults again) does not double-patch.
    _PC._jina_compat_patched = True
62
+
63
# ── patch 3: PreTrainedModel.get_head_mask ────────────────────────────────────
# get_head_mask() and its helper _convert_head_mask_to_5d() were removed from
# transformers.modeling_utils, but the jina-bert-v2 remote code still calls
# self.get_head_mask(...).  Re-add the classic implementation when missing.
import transformers.modeling_utils as _mod_utils

_PTM = _mod_utils.PreTrainedModel
if not hasattr(_PTM, "get_head_mask"):

    def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):
        """Broadcast a 1-D or 2-D head mask to a 5-D tensor.

        Output shape: [num_hidden_layers, batch, num_heads, seq_len, seq_len]
        (the batch/seq dims are size-1 and broadcast in attention).
        """
        if head_mask.dim() == 1:
            # [num_heads] -> same mask replicated for every layer
            head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
        elif head_mask.dim() == 2:
            # [num_hidden_layers, num_heads] -> one mask per layer
            head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
        if head_mask.dim() != 5:
            # Explicit raise instead of `assert`: still enforced under `python -O`.
            raise ValueError(f"head_mask.dim != 5, instead {head_mask.dim()}")
        # Cast to the model dtype so the mask multiplies cleanly in fp16/bf16.
        return head_mask.to(dtype=self.dtype)

    def _get_head_mask(self, head_mask, num_hidden_layers, is_attention_chunked=False):
        """Return a per-layer head mask.

        None input -> a list of `num_hidden_layers` Nones (no masking);
        otherwise the 5-D expansion, with a trailing dim appended when
        attention is chunked.
        """
        if head_mask is not None:
            head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
            if is_attention_chunked:
                head_mask = head_mask.unsqueeze(-1)
        else:
            head_mask = [None] * num_hidden_layers
        return head_mask

    # Install the helper only if a future transformers hasn't re-added it.
    if not hasattr(_PTM, "_convert_head_mask_to_5d"):
        _PTM._convert_head_mask_to_5d = _convert_head_mask_to_5d
    _PTM.get_head_mask = _get_head_mask
89
  # ──────────────────────────────────────────────────────────────────────────────
90
 
91
  from fastapi import FastAPI, HTTPException, UploadFile, File, Form