theachyuttiwari committed
Commit 07cb325 · 1 Parent(s): 0d55c99

Upload main.py

Files changed (1)
main.py +43 -35
main.py CHANGED
@@ -1,11 +1,11 @@
 import torch
 from fastapi import FastAPI, Depends, status
 from fastapi.responses import PlainTextResponse
-from transformers import AutoTokenizer, AutoModel, DPRQuestionEncoder
-
-from datasets import load_from_disk
+from pydantic import BaseModel
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+
 import time
-from typing import Dict
+from typing import Dict, List, Optional
 
 import jwt
 from decouple import config
@@ -17,22 +17,12 @@ JWT_ALGORITHM = config("algorithm")
 
 app = FastAPI()
 app.ready = False
-columns = ['kilt_id', 'wikipedia_id', 'wikipedia_title', 'text', 'anchors', 'categories',
-           'wikidata_info', 'history']
 
-min_snippet_length = 20
-topk = 21
 device = ("cuda" if torch.cuda.is_available() else "cpu")
-model = DPRQuestionEncoder.from_pretrained("vblagoje/dpr-question_encoder-single-lfqa-wiki").to(device)
-tokenizer = AutoTokenizer.from_pretrained("vblagoje/dpr-question_encoder-single-lfqa-wiki")
+tokenizer = AutoTokenizer.from_pretrained('vblagoje/bart_lfqa')
+model = AutoModelForSeq2SeqLM.from_pretrained('vblagoje/bart_lfqa').to(device)
 _ = model.eval()
 
-index_file_name = "./data/kilt_wikipedia.faiss"
-
-kilt_wikipedia_paragraphs = load_from_disk("./data/kilt_wiki_prepared")
-# use paragraphs that are not simple fragments or very short sentences
-kilt_wikipedia_paragraphs = kilt_wikipedia_paragraphs.filter(lambda x: x["end_character"] > 200)
-
 
 class JWTBearer(HTTPBearer):
     def __init__(self, auto_error: bool = True):
@@ -85,27 +75,26 @@ def decodeJWT(token: str) -> dict:
     return {}
 
 
-def embed_questions_for_retrieval(questions):
-    query = tokenizer(questions, max_length=128, padding=True, truncation=True, return_tensors="pt")
-    with torch.no_grad():
-        q_reps = model(query["input_ids"].to(device), query["attention_mask"].to(device)).pooler_output
-    return q_reps.cpu().numpy()
+class LFQAParameters(BaseModel):
+    min_length: int = 50
+    max_length: int = 250
+    do_sample: bool = False
+    early_stopping: bool = True
+    num_beams: int = 8
+    temperature: float = 1.0
+    top_k: float = None
+    top_p: float = None
+    no_repeat_ngram_size: int = 3
+    num_return_sequences: int = 1
 
-def query_index(question):
-    question_embedding = embed_questions_for_retrieval([question])
-    scores, wiki_passages = kilt_wikipedia_paragraphs.get_nearest_examples("embeddings", question_embedding, k=topk)
-    columns = ['wikipedia_id', 'title', 'text', 'section', 'start_paragraph_id', 'end_paragraph_id',
-               'start_character', 'end_character']
-    retrieved_examples = []
-    r = list(zip(wiki_passages[k] for k in columns))
-    for i in range(topk):
-        retrieved_examples.append({k: v for k, v in zip(columns, [r[j][0][i] for j in range(len(columns))])})
-    return retrieved_examples
+
+class InferencePayload(BaseModel):
+    model_input: str
+    parameters: Optional[LFQAParameters] = LFQAParameters()
 
 
 @app.on_event("startup")
 def startup():
-    kilt_wikipedia_paragraphs.load_faiss_index("embeddings", index_file_name, device=0)
     app.ready = True
 
 
@@ -116,7 +105,26 @@ def healthz():
     return PlainTextResponse("service unavailable", status_code=status.HTTP_503_SERVICE_UNAVAILABLE)
 
 
-@app.get("/find_context", dependencies=[Depends(JWTBearer())])
-def find_context(question: str = None):
-    return [res for res in query_index(question) if len(res["text"].split()) > min_snippet_length][:int(topk / 3)]
-
+@app.post("/generate/", dependencies=[Depends(JWTBearer())])
+def generate(context: InferencePayload):
+
+    model_input = tokenizer(context.model_input, truncation=True, padding=True, return_tensors="pt")
+    param = context.parameters
+    generated_answers_encoded = model.generate(input_ids=model_input["input_ids"].to(device),
+                                               attention_mask=model_input["attention_mask"].to(device),
+                                               min_length=param.min_length,
+                                               max_length=param.max_length,
+                                               do_sample=param.do_sample,
+                                               early_stopping=param.early_stopping,
+                                               num_beams=param.num_beams,
+                                               temperature=param.temperature,
+                                               top_k=param.top_k,
+                                               top_p=param.top_p,
+                                               no_repeat_ngram_size=param.no_repeat_ngram_size,
+                                               num_return_sequences=param.num_return_sequences)
+    answers = tokenizer.batch_decode(generated_answers_encoded, skip_special_tokens=True,
+                                     clean_up_tokenization_spaces=True)
+    results = []
+    for answer in answers:
+        results.append({"generated_text": answer})
+    return results
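
The new /generate/ route accepts a free-form model_input string, so the client is responsible for formatting the question and any supporting passages. A minimal sketch of that formatting follows; the "question: ... context: ..." template and the "<P>" passage separator follow the convention documented for vblagoje/bart_lfqa, and the helper name build_model_input is hypothetical, not part of this commit.

# Hypothetical client-side helper, not part of this commit: builds the string
# sent as `model_input`. The template and "<P>" separator follow the
# vblagoje/bart_lfqa convention; the endpoint itself accepts any string.
def build_model_input(question: str, passages: list) -> str:
    conditioned_doc = "<P> " + " <P> ".join(passages)
    return "question: {} context: {}".format(question, conditioned_doc)

example_input = build_model_input(
    "Why does water expand when it freezes?",
    ["Ice is less dense than liquid water.",
     "Hydrogen bonding gives ice an open hexagonal lattice."],
)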
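
For completeness, here is a sketch of an end-to-end call to the route, assuming the service listens on localhost:8000. The JWT claim set ("expires") is an assumption: decodeJWT's actual validation rules sit outside this diff, and the secret, host, and parameter choices are illustrative only. The parameters object mirrors the LFQAParameters fields above.

# Sketch of a client call to POST /generate/. Everything not visible in the
# diff (host, port, claim set, parameter values) is an assumption.
import time
import jwt       # PyJWT, the same library main.py imports
import requests

JWT_SECRET = "change-me"   # must match config("secret") on the server
JWT_ALGORITHM = "HS256"    # must match config("algorithm")

# Illustrative claim set; sign with the server's secret and algorithm.
token = jwt.encode({"expires": time.time() + 600}, JWT_SECRET, algorithm=JWT_ALGORITHM)

payload = {
    "model_input": example_input,  # string built as sketched above
    "parameters": {"min_length": 64, "max_length": 256, "num_beams": 8},
}
response = requests.post(
    "http://localhost:8000/generate/",
    json=payload,
    headers={"Authorization": "Bearer {}".format(token)},
)
response.raise_for_status()
print(response.json())  # e.g. [{"generated_text": "..."}]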