perbon committed on
Commit
7f7d885
·
verified ·
1 Parent(s): 487fd83

Upload chatbot that works with pinecone

Browse files
Files changed (1) hide show
  1. chat_bot_pinecone.py +123 -0
chat_bot_pinecone.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from transformers import StoppingCriteriaList, StoppingCriteria
from sentence_transformers import SentenceTransformer
from pinecone import Pinecone
import warnings

# Silence noisy UserWarnings emitted by transformers / sentence-transformers.
warnings.filterwarnings("ignore", category=UserWarning)

# Swedish instruction-tuned GPT-SW3 checkpoint. The smaller 126m variant is
# kept commented out for quick local experiments.
# model_name = "AI-Sweden-Models/gpt-sw3-126m-instruct"
model_name = "AI-Sweden-Models/gpt-sw3-1.3b-instruct"


# Use the first CUDA GPU when available, otherwise fall back to CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Initialize Tokenizer & Model
tokenizer = AutoTokenizer.from_pretrained(model_name)
20
def read_file(file_path: str) -> str:
    """Return the full contents of a text file as a single string.

    Args:
        file_path: Path of the file to read.

    Returns:
        The file's contents, including any trailing newline.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    # Explicit UTF-8 so behaviour does not depend on the platform's default
    # locale encoding (Windows would otherwise typically use cp1252).
    with open(file_path, "r", encoding="utf-8") as file:
        return file.read()
24
+
25
+
26
# Load the causal LM once at start-up; eval() disables dropout etc. for
# deterministic inference behaviour, then move it to the chosen device.
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()
model.to(device)

# Swedish sentence-BERT encoder used to embed user questions for retrieval.
document_encoder_model = SentenceTransformer("KBLab/sentence-bert-swedish-cased")


# Note: 'index1' has been pre-created in the pinecone console
# read the pinecone api key from a file
# Raw string avoids the invalid "\p" escape sequence in the Windows-style
# path (a SyntaxWarning on modern Python); strip() drops any trailing
# newline so the key is passed to Pinecone verbatim.
pinecone_api_key = read_file(r"language_model\pincecone_api_key.txt").strip()
pc = Pinecone(api_key=pinecone_api_key)
index = pc.Index("index1")
38
+
39
+
40
def get_user_input() -> str:
    """Read one line from stdin (Swedish prompt: "Type in your question:")."""
    answer = input("Skriv in din fråga: ")
    return answer
43
+
44
+
45
def query_pincecone_namespace(
    vector_databse_index, q_embedding, namespace: str
) -> str:
    """Return the best-matching paragraph for a query embedding.

    Args:
        vector_databse_index: Pinecone index handle (``pc.Index(...)``).
            (The original annotation said ``Pinecone``, but the object
            passed in is an index, so the annotation is dropped here.)
        q_embedding: Query embedding exposing ``tolist()`` (e.g. the numpy
            array produced by ``SentenceTransformer.encode``); the original
            ``str`` annotation was wrong.
        namespace: Pinecone namespace to search within.

    Returns:
        The ``paragraph`` metadata field of the top-scoring match.

    Raises:
        ValueError: If the query returns no matches (previously this
            surfaced as an opaque ``IndexError``).
    """
    result = vector_databse_index.query(
        namespace=namespace,
        vector=q_embedding.tolist(),
        top_k=1,
        include_values=True,
        include_metadata=True,
    )
    if not result.matches:
        raise ValueError(f"No matches found in Pinecone namespace '{namespace}'")
    # top_k=1, so the first (and only) match is the answer.
    return result.matches[0].metadata["paragraph"]
59
+
60
+
61
def generate_prompt(prompt: str) -> str:
    """Wrap *prompt* in the GPT-SW3 instruct chat template.

    (The original docstring said "GPT-3", but this is the GPT-SW3 format.)
    The produced string is::

        <|endoftext|><s>\\nUser:\\n{prompt}\\n<s>\\nBot:\\n

    with the trailing newline stripped so generation continues directly
    after "Bot:".

    Args:
        prompt: The fully composed instruction text from the user.

    Returns:
        The chat-formatted prompt string.
    """
    start_token = "<|endoftext|><s>"
    end_token = "<s>"
    return f"{start_token}\nUser:\n{prompt}\n{end_token}\nBot:\n".strip()
66
+
67
+
68
def encode_query(query: str) -> torch.Tensor:
    """Embed *query* with the module-level sentence-BERT document encoder."""
    embedding = document_encoder_model.encode(query)
    return embedding
71
+
72
+
73
class StopOnTokenCriteria(StoppingCriteria):
    """Halt generation once the newest token equals ``stop_token_id``."""

    def __init__(self, stop_token_id):
        # Token id (here: BOS) whose appearance ends the bot's turn.
        self.stop_token_id = stop_token_id

    def __call__(self, input_ids, scores, **kwargs):
        # Only the most recent token of the first sequence is inspected.
        newest_token = input_ids[0, -1]
        return newest_token == self.stop_token_id
79
+
80
+
81
# Stop generation when the model emits the tokenizer's BOS token — the same
# "<s>" marker the prompt template uses to separate chat turns.
stop_on_token_criteria = StopOnTokenCriteria(stop_token_id=tokenizer.bos_token_id)


# Greeting (Swedish): "ask me anything about family law; type 'exit' to quit".
print(
    "Hej och välkommen till Specter! Fråga mig vad som helst om familjerätt \nså jag ska försöka svara på bästa sätt jag kan. Skriv 'exit' för att avsluta."
)
user_input = get_user_input()
88
+
89
# Main chat loop: retrieve context from Pinecone, build a RAG prompt, generate.
while user_input != "exit":
    # Fetch the single best-matching paragraph for the user's question.
    query = query_pincecone_namespace(
        vector_databse_index=index,
        q_embedding=encode_query(query=user_input),
        namespace="ns-parent-balk",
    )
    # Compose the instruction (Swedish): answer factually and concisely,
    # using the retrieved paragraph as the factual reference.
    prompt = (
        "Besvara följande fråga på ett sakligt, kortfattat och formellt vis: "
        + user_input
        + "\n"
        + "Använd följande text som referens när du besvarar frågan och hänvisa fakta i texten: \n"
        + query
    )
    prompt = generate_prompt(prompt=prompt)
    print(prompt)

    # Convert prompt to tokens
    input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(device)

    # Generate tokens based on the prompt; sampling stops early once the
    # stop-token criterion fires.
    generated_token_ids = model.generate(
        inputs=input_ids,
        max_new_tokens=128,
        do_sample=True,
        temperature=0.8,
        top_p=1,
        stopping_criteria=StoppingCriteriaList([stop_on_token_criteria]),
    )[0]

    # Decode the generated tokens: slice off the echoed prompt tokens at the
    # front and the final stop token at the end before decoding.
    print("Generated answer: ")
    generated_text = tokenizer.decode(generated_token_ids[len(input_ids[0]) : -1])

    print(generated_text)
    user_input = get_user_input()