|
|
import warnings

import streamlit as st
import torch
from pinecone import Pinecone
from sentence_transformers import SentenceTransformer
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)

warnings.filterwarnings("ignore", category=UserWarning)
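
# GPT-SW3 instruct model used to generate the answers.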
model_name = "AI-Sweden-Models/gpt-sw3-1.3b-instruct"
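
# Run on the GPU when available, otherwise on the CPU.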
device = "cuda:0" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_name)


def read_file(file_path: str) -> str:
    """Read the contents of a file."""
    with open(file_path, "r") as file:
        return file.read()
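

# Load the generation model and move it to the selected device.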
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()
model.to(device)
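
# Sentence encoder used to embed queries for retrieval.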
document_encoder_model = SentenceTransformer("KBLab/sentence-bert-swedish-cased")
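
# Connect to the Pinecone index that stores the document embeddings.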
pinecone_api_key = read_file("language_model/pinecone_api_key.txt")
pc = Pinecone(api_key=pinecone_api_key)
index = pc.Index("index1")


def query_pinecone_namespace(
    vector_database_index, q_embedding, namespace: str
) -> str:
    """Return the best-matching paragraph for the query embedding in a namespace."""
    result = vector_database_index.query(
        namespace=namespace,
        vector=q_embedding.tolist(),
        top_k=1,
        include_values=True,
        include_metadata=True,
    )
    results = []
    for match in result.matches:
        results.append(match.metadata["paragraph"])
    return results[0]


def generate_prompt(llmprompt: str) -> str:
    """Wrap the prompt in the GPT-SW3 instruct chat format."""
    start_token = "<|endoftext|><s>"
    end_token = "<s>"
    return f"{start_token}\nUser:\n{llmprompt}\n{end_token}\nBot:\n".strip()


def encode_query(query: str):
    """Embed the query with the sentence-transformer document encoder."""
    return document_encoder_model.encode(query)


class StopOnTokenCriteria(StoppingCriteria):
    """Stop generation as soon as the given token id is produced."""

    def __init__(self, stop_token_id):
        self.stop_token_id = stop_token_id

    def __call__(self, input_ids, scores, **kwargs):
        return input_ids[0, -1] == self.stop_token_id
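

# The <s> (BOS) token separates turns in the prompt format above, so stop generating when it appears.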
stop_on_token_criteria = StopOnTokenCriteria(stop_token_id=tokenizer.bos_token_id)

st.title("Paralegal Assistant")
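
# Keep the chat history in the Streamlit session state across reruns.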
if "messages" not in st.session_state:
    st.session_state.messages = []
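
# Replay the conversation so far.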
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

if prompt := st.chat_input("Skriv din fråga..."):
    st.chat_message("user").markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})
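
    # Retrieve the most relevant paragraph from the vector database.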
    query = query_pinecone_namespace(
        vector_database_index=index,
        q_embedding=encode_query(query=prompt),
        namespace="ns-parent-balk",
    )

    # Build the instruction prompt with the retrieved paragraph as reference.
    llmprompt = (
        "Besvara följande fråga på ett sakligt, kortfattat och formellt vis: "
        + prompt
        + "\n"
        + "Använd följande text som referens när du besvarar frågan och hänvisa till fakta i texten: \n"
        + query
    )
    llmprompt = generate_prompt(llmprompt=llmprompt)
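
    # Tokenize the prompt and generate a reply.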
    input_ids = tokenizer(llmprompt, return_tensors="pt")["input_ids"].to(device)

    generated_token_ids = model.generate(
        inputs=input_ids,
        max_new_tokens=128,
        do_sample=True,
        temperature=0.8,
        top_p=1,
        stopping_criteria=StoppingCriteriaList([stop_on_token_criteria]),
    )[0]
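
    # Decode only the newly generated tokens, dropping the prompt and the trailing stop token.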
    generated_text = tokenizer.decode(generated_token_ids[len(input_ids[0]) : -1])

    response = generated_text
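
    # Show the answer and store it in the chat history.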
    with st.chat_message("assistant"):
        st.markdown(response)

    st.session_state.messages.append({"role": "assistant", "content": response})