dzenzzz commited on
Commit
8385b36
·
1 Parent(s): 5834fcc

updates app

Browse files
Files changed (3) hide show
  1. app.py +20 -41
  2. ner.py +0 -96
  3. senatus_client.py +0 -136
app.py CHANGED
@@ -1,12 +1,11 @@
1
  import asyncio
2
  import time
3
  from fastapi import FastAPI, Request, HTTPException
4
- # # from fastapi.responses import JSONResponse
5
  from starlette.status import HTTP_504_GATEWAY_TIMEOUT
6
  from neural_searcher import NeuralSearcher
7
- # from fastapi.middleware.cors import CORSMiddleware
8
  from huggingface_hub import login
9
- from config import HUGGING_FACE_API_KEY,COLLECTION_NAME, ALLOWED_ORIGINS, API_KEY
10
 
11
  login(HUGGING_FACE_API_KEY)
12
 
@@ -14,53 +13,33 @@ app = FastAPI()
14
 
15
  neural_searcher = NeuralSearcher(collection_name=COLLECTION_NAME)
16
 
17
- # REQUEST_TIMEOUT_ERROR = 30
18
 
19
- # ALLOWED_ORIGINS = [ALLOWED_ORIGINS]
20
- # ALLOWED_API_KEY = API_KEY
21
 
22
 
23
  @app.get("/api/search")
24
  async def search(q: str):
25
  data = await neural_searcher.search(text=q)
26
  return data
27
-
28
 
29
- # app.add_middleware(
30
- # CORSMiddleware,
31
- # allow_origins=ALLOWED_ORIGINS,
32
- # allow_credentials=True,
33
- # allow_methods=["GET"],
34
- # allow_headers=["*"],
35
- # )
36
 
37
- # @app.middleware("http")
38
- # async def security_middleware(request: Request, call_next):
39
- # referer = request.headers.get("referer", "")
40
- # origin = request.headers.get("origin", "")
41
- # user_agent = request.headers.get("user-agent", "")
42
- # api_key = request.headers.get("X-API-KEY", "")
43
 
44
-
45
- # if not (referer.startswith(ALLOWED_ORIGINS[0]) or origin.startswith(ALLOWED_ORIGINS[0])):
46
- # raise HTTPException(status_code=403, detail="Access denied: Invalid source")
47
 
48
- # if not user_agent or "Mozilla" not in user_agent:
49
- # raise HTTPException(status_code=403, detail="Access denied: Suspicious client")
 
 
 
50
 
51
- # if api_key != ALLOWED_API_KEY:
52
- # raise HTTPException(status_code=403, detail="Access denied: Invalid API Key")
53
-
54
- # return await call_next(request)
55
-
56
- # @app.middleware("http")
57
- # async def timeout_middleware(request: Request, call_next):
58
- # try:
59
- # start_time = time.time()
60
- # return await asyncio.wait_for(call_next(request), timeout=REQUEST_TIMEOUT_ERROR)
61
-
62
- # except asyncio.TimeoutError:
63
- # process_time = time.time() - start_time
64
- # return JSONResponse({'detail': 'Request processing time excedeed limit',
65
- # 'processing_time': process_time},
66
- # status_code=HTTP_504_GATEWAY_TIMEOUT)
 
1
  import asyncio
2
  import time
3
  from fastapi import FastAPI, Request, HTTPException
4
+ from fastapi.responses import JSONResponse
5
  from starlette.status import HTTP_504_GATEWAY_TIMEOUT
6
  from neural_searcher import NeuralSearcher
 
7
  from huggingface_hub import login
8
+ from config import HUGGING_FACE_API_KEY,COLLECTION_NAME, API_KEY
9
 
10
  login(HUGGING_FACE_API_KEY)
11
 
 
13
 
14
  neural_searcher = NeuralSearcher(collection_name=COLLECTION_NAME)
15
 
16
+ REQUEST_TIMEOUT_ERROR = 30
17
 
18
+ ALLOWED_API_KEY = API_KEY
 
19
 
20
 
21
  @app.get("/api/search")
22
  async def search(q: str):
23
  data = await neural_searcher.search(text=q)
24
  return data
 
25
 
26
+ @app.middleware("http")
27
+ async def security_middleware(request: Request, call_next):
28
+ api_key = request.headers.get("X-API-KEY", "")
 
 
 
 
29
 
30
+ if api_key != ALLOWED_API_KEY:
31
+ raise HTTPException(status_code=403, detail="Access denied.")
 
 
 
 
32
 
33
+ return await call_next(request)
 
 
34
 
35
+ @app.middleware("http")
36
+ async def timeout_middleware(request: Request, call_next):
37
+ try:
38
+ start_time = time.time()
39
+ return await asyncio.wait_for(call_next(request), timeout=REQUEST_TIMEOUT_ERROR)
40
 
41
+ except asyncio.TimeoutError:
42
+ process_time = time.time() - start_time
43
+ return JSONResponse({'detail': 'Request processing time excedeed limit',
44
+ 'processing_time': process_time},
45
+ status_code=HTTP_504_GATEWAY_TIMEOUT)
 
 
 
 
 
 
 
 
 
 
 
ner.py DELETED
@@ -1,96 +0,0 @@
1
- from transformers import AutoModelForTokenClassification, AutoTokenizer
2
- from config import NER_MODEL
3
- import torch
4
-
5
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
6
- tokenizer = AutoTokenizer.from_pretrained(NER_MODEL, use_auth_token=True)
7
- model = AutoModelForTokenClassification.from_pretrained(NER_MODEL, use_auth_token=True).to(device)
8
-
9
- id_to_label = {
10
- 0: 'O',
11
- 1: 'B-COURT',
12
- 2: 'B-DATE',
13
- 3: 'B-DECISION',
14
- 4: 'B-LAW',
15
- 5: 'B-MONEY',
16
- 6: 'B-OFFICIAL GAZZETE',
17
- 7: 'B-PERSON',
18
- 8: 'B-REFERENCE',
19
- 9: 'I-COURT',
20
- 10: 'I-LAW',
21
- 11: 'I-MONEY',
22
- 12: 'I-OFFICIAL GAZZETE',
23
- 13: 'I-PERSON',
24
- 14: 'I-REFERENCE'
25
- }
26
-
27
- def perform_ner(text):
28
- try:
29
- inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
30
- with torch.no_grad():
31
- outputs = model(**inputs)
32
- logits = outputs.logits
33
- predictions = torch.argmax(logits, dim=2).squeeze().tolist()
34
-
35
- except RuntimeError as e:
36
- if "CUDA out of memory" in str(e):
37
- print("Switching to CPU due to memory constraints.")
38
- inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
39
- with torch.no_grad():
40
- outputs = model.cpu()(**inputs) # Run model on CPU
41
- logits = outputs.logits
42
- predictions = torch.argmax(logits, dim=2).squeeze().tolist()
43
- else:
44
- raise e
45
- tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"].squeeze())
46
- labels = [id_to_label[pred] for pred in predictions]
47
-
48
- results = [
49
- (token, label)
50
- for token, label in zip(tokens, labels)
51
- if token not in tokenizer.all_special_tokens
52
- ]
53
- return results
54
-
55
- text = ""
56
-
57
- def merge_entities(token_label_pairs):
58
- merged_words, merged_labels = [], []
59
- current_word, current_label = "", None
60
-
61
- for token, label in token_label_pairs:
62
- if token.startswith("##"):
63
- current_word += token[2:]
64
- else:
65
- if current_word:
66
- merged_words.append(current_word)
67
- merged_labels.append(current_label)
68
-
69
- current_word, current_label = token, label
70
-
71
- if current_word:
72
- merged_words.append(current_word)
73
- merged_labels.append(current_label)
74
-
75
- final_words, final_labels = [], []
76
-
77
- for i, (word, label) in enumerate(zip(merged_words, merged_labels)):
78
- if final_labels and (
79
- label == final_labels[-1] or
80
- (label.startswith("I-") and final_labels[-1].endswith(label[2:])) or
81
- (label.startswith("B-") and final_labels[-1].endswith(label[2:]))
82
- ):
83
-
84
- final_words[-1] += " " + word
85
- else:
86
- final_words.append(word)
87
- final_labels.append(label)
88
-
89
- return final_words, final_labels
90
-
91
- results = perform_ner(text)
92
-
93
- words,labels = merge_entities(results)
94
-
95
- for i,b in zip(words,labels):
96
- print(i + " ### " + b)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
senatus_client.py DELETED
@@ -1,136 +0,0 @@
1
- import json
2
- import uuid
3
- import numpy as np
4
- import os
5
- from huggingface_hub import login
6
- from fastembed import SparseTextEmbedding,LateInteractionTextEmbedding
7
- from qdrant_client import QdrantClient, models
8
- from sentence_transformers import SentenceTransformer
9
- from tqdm import tqdm
10
- from huggingface_hub import login
11
- from config import HUGGING_FACE_API_KEY, DENSE_MODEL, SPARSE_MODEL, LATE_INTERACTION_MODEL, QDRANT_URL, QDRANT_API_KEY, COLLECTION_NAME
12
-
13
- login(HUGGING_FACE_API_KEY)
14
-
15
- folder_path = 'data'
16
-
17
- dense_model = SentenceTransformer(DENSE_MODEL)
18
- sparse_model = SparseTextEmbedding(SPARSE_MODEL)
19
- # late_interaction_embedding_model = LateInteractionTextEmbedding(LATE_INTERACTION_MODEL)
20
-
21
- data = []
22
- for filename in os.listdir(folder_path):
23
- if filename.endswith('.json'):
24
- file_path = os.path.join(folder_path, filename)
25
- with open(file_path,encoding='utf-8') as f:
26
- data = json.load(f)
27
-
28
-
29
- client = QdrantClient(QDRANT_URL,api_key=QDRANT_API_KEY)
30
-
31
-
32
- data_array = np.array(data)
33
-
34
- split_data = np.array_split(data_array, 1000)
35
-
36
- collection_name = COLLECTION_NAME
37
-
38
- for local_data in split_data:
39
-
40
- payload = []
41
- documents = []
42
- for obj in local_data:
43
- documents.append(obj["tekst"])
44
- payload.append(obj)
45
-
46
- sparse_embeddings = list(
47
- tqdm(
48
- sparse_model.passage_embed(doc for doc in documents),
49
- total=len(documents),
50
- desc="🔨 Encoding Sparse Embeddings"
51
- )
52
- )
53
-
54
- # late_interaction_embeddings = list(
55
- # tqdm(
56
- # late_interaction_embedding_model.passage_embed(doc for doc in documents),
57
- # total=len(documents),
58
- # desc="🔨 Encoding Late Interaction Embeddings"
59
- # )
60
- # )
61
-
62
- dense_embeddings = dense_model.encode(documents, show_progress_bar=True, device="cuda")
63
-
64
- existing_collections = client.get_collections().collections
65
- collection_names = [col.name for col in existing_collections]
66
-
67
- if collection_name not in collection_names:
68
- client.create_collection(
69
- collection_name=collection_name,
70
- vectors_config={
71
- DENSE_MODEL: models.VectorParams(
72
- size=len(dense_embeddings[0]),
73
- distance=models.Distance.COSINE,
74
- on_disk=True
75
- ),
76
- # LATE_INTERACTION_MODEL: models.VectorParams(
77
- # size=len(late_interaction_embeddings[0][0]),
78
- # distance=models.Distance.COSINE,
79
- # multivector_config=models.MultiVectorConfig(
80
- # comparator=models.MultiVectorComparator.MAX_SIM,
81
- # ),
82
- # hnsw_config=models.HnswConfigDiff(
83
- # m=0, # Disable HNSW graph creation
84
- # ),
85
- # on_disk=True
86
- # ),
87
- },
88
- sparse_vectors_config={
89
- SPARSE_MODEL: models.SparseVectorParams(
90
- modifier=models.Modifier.IDF,
91
- ),
92
- },
93
- quantization_config=models.ScalarQuantization(
94
- scalar=models.ScalarQuantizationConfig(
95
- type=models.ScalarType.INT8,
96
- always_ram=True
97
- )
98
- ),
99
- optimizers_config=models.OptimizersConfigDiff(
100
- indexing_threshold=10000,
101
- ),
102
- shard_number = 4,
103
- hnsw_config=models.HnswConfigDiff(on_disk=True),
104
- )
105
-
106
- print("🚀 Uploading to qdrant collection: " + collection_name)
107
- client.upload_points(
108
- collection_name=collection_name,
109
- batch_size = 32,
110
- parallel = 16,
111
- points=[
112
- models.PointStruct(
113
- id=uuid.uuid4().hex,
114
- vector={
115
- DENSE_MODEL: dense_embedding,
116
- SPARSE_MODEL: sparse_embedding.as_object(),
117
- # "answerdotai/answerai-colbert-small-v1":late_interaction_embedding
118
- },
119
- payload=doc,
120
- )
121
- for doc, dense_embedding, sparse_embedding in zip(
122
- payload, dense_embeddings, sparse_embeddings
123
- )
124
- ],
125
- )
126
-
127
- client.create_payload_index(
128
- collection_name=collection_name,
129
- field_name="dbid",
130
- field_schema=models.PayloadSchemaType.INTEGER
131
- )
132
-
133
- client.update_collection(
134
- collection_name=collection_name,
135
- optimizer_config=models.OptimizersConfigDiff(indexing_threshold=20000),
136
- )