kamp0010 committed on
Commit
216005d
·
verified ·
1 Parent(s): 1e4d629

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +36 -11
main.py CHANGED
@@ -16,28 +16,53 @@ import torch
16
  import numpy as np
17
  import faiss
18
 
19
- # ── Compatibility patch ────────────────────────────────────────────────────────
20
- # `find_pruneable_heads_and_indices` was removed from transformers.pytorch_utils
21
- # in newer releases, but the jina-bert-v2 custom modeling code still imports it
22
- # from there. We re-inject the original implementation before anything loads.
 
 
 
 
 
 
 
 
23
  import transformers.pytorch_utils as _pt_utils
24
  if not hasattr(_pt_utils, "find_pruneable_heads_and_indices"):
25
- from typing import Set, List, Tuple
26
  def _find_pruneable_heads_and_indices(
27
- heads: List[int],
28
- n_heads: int,
29
- head_size: int,
30
- already_pruned_heads: Set[int],
31
- ) -> Tuple[Set[int], torch.LongTensor]:
32
  mask = torch.ones(n_heads, head_size)
33
  heads = set(heads) - already_pruned_heads
34
  for head in heads:
35
  head = head - sum(1 if h < head else 0 for h in already_pruned_heads)
36
  mask[head] = 0
37
  mask = mask.view(-1).contiguous().eq(1)
38
- index: torch.LongTensor = torch.arange(len(mask))[mask].long()
39
  return heads, index
40
  _pt_utils.find_pruneable_heads_and_indices = _find_pruneable_heads_and_indices
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  # ──────────────────────────────────────────────────────────────────────────────
42
 
43
  from fastapi import FastAPI, HTTPException, UploadFile, File, Form
 
16
  import numpy as np
17
  import faiss
18
 
19
+ # ── Compatibility patches ──────────────────────────────────────────────────────
20
+ # The jina-bert-v2 custom modeling code was written against an older transformers
21
+ # API. Two things were removed / tightened in newer releases:
22
+ #
23
+ # 1. `find_pruneable_heads_and_indices` was removed from transformers.pytorch_utils.
24
+ # 2. `PretrainedConfig` no longer sets is_decoder / add_cross_attention as instance
25
+ # defaults in __init__. A tightened __getattribute__ now raises AttributeError
26
+ # instead of the old silent fallback, breaking JinaBertConfig access patterns.
27
+ #
28
+ # Both patches are guarded with hasattr/flag checks so they are no-ops if a future
29
+ # transformers version re-adds these symbols.
30
+
31
import transformers.pytorch_utils as _pt_utils

if not hasattr(_pt_utils, "find_pruneable_heads_and_indices"):
    # Backport of the helper that was removed from transformers.pytorch_utils;
    # the jina-bert-v2 remote modeling code still imports it from there.  The
    # hasattr guard above makes this a no-op if a future transformers release
    # restores the symbol.  `torch` is imported at the top of this file, so no
    # function-local import is needed.
    def _find_pruneable_heads_and_indices(
        heads: list[int],
        n_heads: int,
        head_size: int,
        already_pruned_heads: set[int],
    ) -> tuple[set[int], torch.LongTensor]:
        """Return the heads left to prune and the flat weight indices to keep.

        Mirrors the implementation shipped with older transformers releases.
        """
        mask = torch.ones(n_heads, head_size)
        # Ignore heads that were already pruned by an earlier call.
        heads = set(heads) - already_pruned_heads
        for head in heads:
            # Shift the head index left by the number of already-pruned heads
            # that sat before it in the original layout.
            head = head - sum(1 if h < head else 0 for h in already_pruned_heads)
            mask[head] = 0
        mask = mask.view(-1).contiguous().eq(1)
        index: torch.LongTensor = torch.arange(len(mask))[mask].long()
        return heads, index

    _pt_utils.find_pruneable_heads_and_indices = _find_pruneable_heads_and_indices
46
+
47
import transformers.configuration_utils as _cfg_utils

_PC = _cfg_utils.PretrainedConfig
if not hasattr(_PC, "_jina_compat_patched"):
    # Attributes that used to be initialised with defaults in
    # PretrainedConfig.__init__ but were removed from the base class in newer
    # transformers releases.  __getattr__ is only consulted when normal lookup
    # fails, so any value actually set on a config instance still wins.
    _LEGACY_DEFAULTS = {
        "is_decoder": False,
        "add_cross_attention": False,
        "cross_attention_hidden_size": None,
        "use_cache": True,
    }
    # Preserve whatever __getattr__ the installed transformers version defines
    # (may be None) so we extend it instead of clobbering it — overwriting it
    # unconditionally would break any attributes that handler serves.
    _prev_getattr = getattr(_PC, "__getattr__", None)

    def _pc_getattr(self, key: str):
        """Serve legacy defaults, then defer to the original __getattr__."""
        if key in _LEGACY_DEFAULTS:
            return _LEGACY_DEFAULTS[key]
        if _prev_getattr is not None:
            # Delegate to the pre-patch handler for everything we don't cover.
            return _prev_getattr(self, key)
        raise AttributeError(
            f"'{type(self).__name__}' object has no attribute '{key}'"
        )

    _PC.__getattr__ = _pc_getattr
    # Flag so re-running this module section never double-wraps __getattr__.
    _PC._jina_compat_patched = True
66
  # ──────────────────────────────────────────────────────────────────────────────
67
 
68
  from fastapi import FastAPI, HTTPException, UploadFile, File, Form