Junaidb commited on
Commit
f983ed9
·
verified ·
1 Parent(s): ba8fbb1

Create qa_agent.py

Browse files
Files changed (1) hide show
  1. qa_agent.py +26 -0
qa_agent.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoTokenizer, AutoModelForQuestionAnswering
3
+
4
+ tokenizer = AutoTokenizer.from_pretrained("Intel/dynamic_tinybert")
5
+ model = AutoModelForQuestionAnswering.from_pretrained("Intel/dynamic_tinybert")
6
+
7
+ def QA_Agent(context,question):
8
+
9
+
10
+ #Tokenize the context and question
11
+ tokens = tokenizer.encode_plus(question, context, return_tensors="pt", truncation=True)
12
+
13
+ #Get the input IDs and attention mask
14
+ input_ids = tokens["input_ids"]
15
+ attention_mask = tokens["attention_mask"]
16
+
17
+ #Perform question answering
18
+ outputs = model(input_ids, attention_mask=attention_mask)
19
+ start_scores = outputs.start_logits
20
+ end_scores = outputs.end_logits
21
+
22
+ #Find the start and end positions of the answer
23
+ answer_start = torch.argmax(start_scores)
24
+ answer_end = torch.argmax(end_scores) + 1
25
+ answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(input_ids[0][answer_start:answer_end]))
26
+ return answer