kamp0010 commited on
Commit
1e4d629
Β·
verified Β·
1 Parent(s): 8079c4e

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +26 -0
main.py CHANGED
@@ -12,8 +12,34 @@ os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
12
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
13
  os.environ["HF_HUB_VERBOSITY"] = "error"
14
 
 
15
  import numpy as np
16
  import faiss
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  from fastapi import FastAPI, HTTPException, UploadFile, File, Form
18
  from fastapi.middleware.cors import CORSMiddleware
19
  from pydantic import BaseModel, Field
 
12
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
13
  os.environ["HF_HUB_VERBOSITY"] = "error"
14
 
15
+ import torch
16
  import numpy as np
17
  import faiss
18
+
19
+ # ── Compatibility patch ────────────────────────────────────────────────────────
20
+ # `find_pruneable_heads_and_indices` was removed from transformers.pytorch_utils
21
+ # in newer releases, but the jina-bert-v2 custom modeling code still imports it
22
+ # from there. We re-inject the original implementation before anything loads.
23
+ import transformers.pytorch_utils as _pt_utils
24
+ if not hasattr(_pt_utils, "find_pruneable_heads_and_indices"):
25
+ from typing import Set, List, Tuple
26
+ def _find_pruneable_heads_and_indices(
27
+ heads: List[int],
28
+ n_heads: int,
29
+ head_size: int,
30
+ already_pruned_heads: Set[int],
31
+ ) -> Tuple[Set[int], torch.LongTensor]:
32
+ mask = torch.ones(n_heads, head_size)
33
+ heads = set(heads) - already_pruned_heads
34
+ for head in heads:
35
+ head = head - sum(1 if h < head else 0 for h in already_pruned_heads)
36
+ mask[head] = 0
37
+ mask = mask.view(-1).contiguous().eq(1)
38
+ index: torch.LongTensor = torch.arange(len(mask))[mask].long()
39
+ return heads, index
40
+ _pt_utils.find_pruneable_heads_and_indices = _find_pruneable_heads_and_indices
41
+ # ──────────────────────────────────────────────────────────────────────────────
42
+
43
  from fastapi import FastAPI, HTTPException, UploadFile, File, Form
44
  from fastapi.middleware.cors import CORSMiddleware
45
  from pydantic import BaseModel, Field