File size: 3,678 Bytes
7f7d885
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from transformers import StoppingCriteriaList, StoppingCriteria
from sentence_transformers import SentenceTransformer
from pinecone import Pinecone
import warnings

# Silence every UserWarning raised from here on.
warnings.filterwarnings("ignore", category=UserWarning)

# Hugging Face identifier of the checkpoint to load.
# Alternative checkpoint (disabled):
# model_name = "AI-Sweden-Models/gpt-sw3-126m-instruct"
model_name = "AI-Sweden-Models/gpt-sw3-1.3b-instruct"


# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Tokenizer belonging to the checkpoint named above.
tokenizer = AutoTokenizer.from_pretrained(model_name)


def read_file(file_path: str) -> str:
    """Read the contents of a text file.

    Args:
        file_path: Path of the file to read.

    Returns:
        The complete file contents, decoded as UTF-8.

    Raises:
        OSError: If the file cannot be opened or read.
        UnicodeDecodeError: If the file content is not valid UTF-8.
    """
    # An explicit encoding keeps the result independent of the platform's
    # default locale encoding.
    with open(file_path, "r", encoding="utf-8") as file:
        return file.read()


# Load the causal language model, put it in evaluation mode and move it to the
# selected device.
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()
model.to(device)

# Sentence encoder used to embed user questions for the vector search.
document_encoder_model = SentenceTransformer("KBLab/sentence-bert-swedish-cased")


# Note: 'index1' has been pre-created in the pinecone console
# Read the Pinecone API key from a file. The path uses a forward slash because
# "\p" is an invalid escape sequence in a regular string literal and a
# backslash separator only works on Windows. Surrounding whitespace (such as a
# trailing newline in the key file) is stripped so it is not sent as part of
# the key.
pinecone_api_key = read_file("language_model/pincecone_api_key.txt").strip()
pc = Pinecone(api_key=pinecone_api_key)
index = pc.Index("index1")


def get_user_input() -> str:
    """Prompt on standard input for a question and return the entered line."""
    question = input("Skriv in din fråga: ")
    return question


def query_pincecone_namespace(vector_databse_index, q_embedding, namespace: str) -> str:
    """Return the paragraph text of the closest match in a Pinecone namespace.

    Args:
        vector_databse_index: Index handle exposing ``query(...)``. The caller
            in this file passes the object returned by ``pc.Index(...)``, not
            the ``Pinecone`` client itself.
        q_embedding: Query embedding; it must provide ``tolist()`` because the
            vector is sent to the index as a plain list.
        namespace: Namespace inside the index to search.

    Returns:
        The ``"paragraph"`` metadata value of the top match, or an empty
        string when the query yields no matches.

    Raises:
        KeyError: If the top match has no ``"paragraph"`` metadata entry.
    """
    result = vector_databse_index.query(
        namespace=namespace,
        vector=q_embedding.tolist(),
        top_k=1,
        # Only the metadata is read below, so the stored vector values are
        # not requested.
        include_values=False,
        include_metadata=True,
    )
    matches = result.matches
    if not matches:
        # Nothing was found; avoid an IndexError on the empty result.
        return ""
    return matches[0].metadata["paragraph"]


def generate_prompt(prompt: str) -> str:
    """Wrap the user text in the chat layout fed to the language model.

    The result is the start marker, a ``User:`` turn holding ``prompt``, a
    turn separator and an opening ``Bot:`` turn with no trailing newline.
    """
    opening_marker = "<|endoftext|><s>"
    turn_separator = "<s>"
    chat_lines = (
        opening_marker,
        "User:",
        f"{prompt}",
        turn_separator,
        "Bot:",
    )
    return "\n".join(chat_lines)


def encode_query(query: str) -> torch.Tensor:
    """Embed the query text with the module-level sentence encoder.

    NOTE(review): ``SentenceTransformer.encode`` presumably returns a NumPy
    array unless told otherwise, so the ``torch.Tensor`` annotation may be
    inaccurate -- confirm against the library.
    """
    embedding = document_encoder_model.encode(query)
    return embedding


class StopOnTokenCriteria(StoppingCriteria):
    """Stopping criterion that triggers on one specific token id.

    Only the first row of ``input_ids`` is inspected.
    """

    def __init__(self, stop_token_id):
        # Token id whose appearance as the latest token should end generation.
        self.stop_token_id = stop_token_id

    def __call__(self, input_ids, scores, **kwargs):
        """Compare the last token of the first sequence with the stop id.

        ``scores`` and any extra keyword arguments are accepted but unused.
        """
        newest_token = input_ids[0, -1]
        return newest_token == self.stop_token_id


# Stop generating as soon as the model emits the tokenizer's BOS token id.
# NOTE(review): presumably this is the "<s>" turn separator used in the
# prompt layout -- confirm against the tokenizer.
stop_on_token_criteria = StopOnTokenCriteria(stop_token_id=tokenizer.bos_token_id)


print(
    "Hej och välkommen till Specter! Fråga mig vad som helst om familjerätt \nså jag ska försöka svara på bästa sätt jag kan. Skriv 'exit' för att avsluta."
)

# Interactive loop: read a question, retrieve a reference paragraph, build the
# prompt, generate an answer and print it. Runs until the user types 'exit'.
while True:
    user_input = get_user_input()
    command = user_input.strip()
    # Accept 'exit' regardless of surrounding whitespace or letter case.
    if command.lower() == "exit":
        break
    # Nothing to answer; ask again instead of querying with an empty question.
    if not command:
        continue

    # Paragraph retrieved from the vector index, used as reference material.
    reference_text = query_pincecone_namespace(
        vector_databse_index=index,
        q_embedding=encode_query(query=user_input),
        namespace="ns-parent-balk",
    )
    prompt = (
        "Besvara följande fråga på ett sakligt, kortfattat och formellt vis: "
        + user_input
        + "\n"
        + "Använd följande text som referens när du besvarar frågan och hänvisa fakta i texten: \n"
        + reference_text
    )
    prompt = generate_prompt(prompt=prompt)
    print(prompt)

    # Convert prompt to tokens
    input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(device)

    # Generate tokens based on prompt
    generated_token_ids = model.generate(
        inputs=input_ids,
        max_new_tokens=128,
        do_sample=True,
        temperature=0.8,
        top_p=1,
        stopping_criteria=StoppingCriteriaList([stop_on_token_criteria]),
    )[0]

    # Keep only the newly generated tokens (everything after the prompt).
    answer_token_ids = generated_token_ids[input_ids.shape[1] :]
    # Remove the final token only when it really is an end marker. When
    # generation stops because max_new_tokens was reached, the last token is
    # ordinary text and must be kept.
    end_marker_ids = {tokenizer.bos_token_id, tokenizer.eos_token_id}
    if len(answer_token_ids) > 0 and answer_token_ids[-1].item() in end_marker_ids:
        answer_token_ids = answer_token_ids[:-1]

    # Decode the generated tokens
    print("Generated answer: ")
    generated_text = tokenizer.decode(answer_token_ids)

    print(generated_text)