Update main.py
Browse files
main.py
CHANGED
|
@@ -12,8 +12,34 @@ os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
|
|
| 12 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 13 |
os.environ["HF_HUB_VERBOSITY"] = "error"
|
| 14 |
|
|
|
|
| 15 |
import numpy as np
|
| 16 |
import faiss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
|
| 18 |
from fastapi.middleware.cors import CORSMiddleware
|
| 19 |
from pydantic import BaseModel, Field
|
|
|
|
| 12 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 13 |
os.environ["HF_HUB_VERBOSITY"] = "error"
|
| 14 |
|
| 15 |
+
import torch
|
| 16 |
import numpy as np
|
| 17 |
import faiss
|
| 18 |
+
|
| 19 |
+
# ββ Compatibility patch ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 20 |
+
# `find_pruneable_heads_and_indices` was removed from transformers.pytorch_utils
|
| 21 |
+
# in newer releases, but the jina-bert-v2 custom modeling code still imports it
|
| 22 |
+
# from there. We re-inject the original implementation before anything loads.
|
| 23 |
+
import transformers.pytorch_utils as _pt_utils
|
| 24 |
+
if not hasattr(_pt_utils, "find_pruneable_heads_and_indices"):
|
| 25 |
+
from typing import Set, List, Tuple
|
| 26 |
+
def _find_pruneable_heads_and_indices(
|
| 27 |
+
heads: List[int],
|
| 28 |
+
n_heads: int,
|
| 29 |
+
head_size: int,
|
| 30 |
+
already_pruned_heads: Set[int],
|
| 31 |
+
) -> Tuple[Set[int], torch.LongTensor]:
|
| 32 |
+
mask = torch.ones(n_heads, head_size)
|
| 33 |
+
heads = set(heads) - already_pruned_heads
|
| 34 |
+
for head in heads:
|
| 35 |
+
head = head - sum(1 if h < head else 0 for h in already_pruned_heads)
|
| 36 |
+
mask[head] = 0
|
| 37 |
+
mask = mask.view(-1).contiguous().eq(1)
|
| 38 |
+
index: torch.LongTensor = torch.arange(len(mask))[mask].long()
|
| 39 |
+
return heads, index
|
| 40 |
+
_pt_utils.find_pruneable_heads_and_indices = _find_pruneable_heads_and_indices
|
| 41 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 42 |
+
|
| 43 |
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
|
| 44 |
from fastapi.middleware.cors import CORSMiddleware
|
| 45 |
from pydantic import BaseModel, Field
|