File size: 4,549 Bytes
ba03107
fe0fe03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a66bf0e
727ed61
 
a66bf0e
 
727ed61
 
 
 
 
 
fe0fe03
 
 
 
 
ba03107
fe0fe03
 
 
 
0c3cd21
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import functools
import glob
import os
import pickle
import sys
import time
from typing import Any, Callable, Iterator, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import streamlit as st
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.llms.ollama import Ollama
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
from langchain_google_genai.llms import GoogleGenerativeAI
from pydantic import SecretStr


# Streamlit sidebar widget for choosing the model backend (and, for Gemini, the API key)
def model_selection() -> Tuple[str, Optional[str]]:
    """Render the sidebar model picker and return ``(model_name, api_key)``.

    Returns:
        ``("Ollama", None)`` when Ollama is chosen; ``("Gemini", key)`` when Gemini
        is chosen, where ``key`` is the text typed into the password box (it may be
        an empty string if nothing was entered yet).

    Raises:
        ValueError: only if the selectbox yields a value outside its option list.
    """
    choice = st.sidebar.selectbox("Select model", options=["Ollama", "Gemini"])
    if choice == "Ollama":
        # Local model: no credentials are requested.
        return choice, None
    if choice == "Gemini":
        api_key = st.sidebar.text_input("Enter API key", type="password")
        return choice, api_key
    # Not reachable with the fixed option list above; kept as a guard in case the
    # options change without this function being updated.
    raise ValueError("Invalid model name. Please choose 'Gemini' or 'Ollama'.")
    

# Decorator that prints how long each call of the wrapped function takes
def timer(func: Callable[..., Any]) -> Callable[..., Any]:
    """Wrap ``func`` so every successful call prints its wall-clock duration.

    The wrapper forwards all positional and keyword arguments, returns ``func``'s
    result unchanged, and prints ``"<name> took <seconds> seconds"`` measured with
    ``time.perf_counter``. ``functools.wraps`` copies ``__name__``, ``__doc__``
    and related metadata onto the wrapper (and sets ``__wrapped__``), so the
    decorated function still introspects as itself.

    Args:
        func: The callable to time.

    Returns:
        The timing wrapper around ``func``.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        result = func(*args, **kwargs)
        end = time.perf_counter()
        print(f"{func.__name__} took {end - start} seconds")
        return result
    return wrapper

@timer
def text2embeddings(document: Union[str, List[str]], api_key: Optional[str], model: str) -> np.ndarray:
    """Embed one text or a list of texts with the selected embedding backend.

    A single ``str`` goes through ``embed_query``; a list goes through
    ``embed_documents``. Either result is converted with ``np.array``. Presumably
    this yields a 1-D vector for a string and a 2-D array (one row per text) for a
    list.

    Args:
        document: A single text or a list of texts.
        api_key: Google API key; used only when ``model == "Gemini"`` (may be None
            for ``"Ollama"``).
        model: ``"Gemini"`` (``models/embedding-001``) or ``"Ollama"`` (``phi3``).

    Returns:
        The embedding(s) as a NumPy array.

    Raises:
        ValueError: if ``model`` is not ``"Gemini"`` or ``"Ollama"``.
    """
    # Construct only the client the caller asked for. Building the Gemini client
    # for an Ollama call would pass api_key=None to GoogleGenerativeAIEmbeddings,
    # which validates its key when it is constructed (unless GOOGLE_API_KEY is set).
    if model == "Gemini":
        embedder = GoogleGenerativeAIEmbeddings(model="models/embedding-001", google_api_key=api_key)
    elif model == "Ollama":
        embedder = OllamaEmbeddings(model="phi3")
    else:
        raise ValueError("Invalid model name. Please choose 'Gemini' or 'Ollama'.")

    if isinstance(document, str):
        return np.array(embedder.embed_query(document))
    return np.array(embedder.embed_documents(document))


@timer
def build_context_matrix(df: str, column: str, api_key: str, model: str,
                         cache_path: str = "Data/context_matrix.pkl") -> Any:
    """Return the embedding matrix for every row of ``column`` in the CSV at ``df``.

    The matrix is cached with pickle at ``cache_path``. If that file exists it is
    loaded and returned as-is. Otherwise every row is embedded with
    ``text2embeddings``, the result is written to ``cache_path`` (creating the
    parent directory if needed) and then returned.

    Args:
        df: Path to the CSV file (a path string, not a DataFrame).
        column: Name of the text column to embed.
        api_key: Key forwarded to ``text2embeddings`` (needed for Gemini).
        model: ``"Gemini"`` or ``"Ollama"``.
        cache_path: Location of the pickle cache; defaults to the original
            hard-coded path.

    Returns:
        The embedding matrix built by ``text2embeddings`` for the column's values.

    NOTE(review): the cache is not keyed on ``df``, ``column`` or ``model``. A
    matrix built with one embedding model will be reused for another, whose query
    vectors may have a different dimension. Delete the cache or pass a distinct
    ``cache_path`` when switching.
    """
    if os.path.isfile(cache_path):
        # SECURITY: pickle.load can execute arbitrary code; only load a cache file
        # this application wrote itself.
        with open(cache_path, "rb") as fh:
            return pickle.load(fh)

    # Only read the CSV when the matrix actually has to be built.
    dataframe = pd.read_csv(df)
    matrix = text2embeddings(document=dataframe[column].tolist(), api_key=api_key, model=model)

    cache_dir = os.path.dirname(cache_path)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)
    with open(cache_path, "wb") as fh:
        pickle.dump(matrix, fh)
    print("Context matrix created")
    # Return the fresh matrix so the caller can use it right away instead of
    # receiving None on the first run.
    return matrix
    

@timer
def semantic_chunk(query: str, df: str, column: str, api_key: str, model: str, chunk_size: int) -> str:
    """Return the ``chunk_size`` rows of ``column`` that score highest against ``query``.

    Every row's embedding (from ``build_context_matrix``) is dotted with the
    query's embedding. Rows are ranked by that score, highest first, and their
    text is joined with newlines.

    Args:
        query: The user's question.
        df: Path to the CSV file.
        column: Text column to search.
        api_key: Key forwarded to the embedding helpers.
        model: ``"Gemini"`` or ``"Ollama"``.
        chunk_size: How many top-scoring rows to return.

    Returns:
        The selected rows' text, newline-separated.

    Raises:
        RuntimeError: if no context matrix can be obtained.
    """
    dataframe = pd.read_csv(df)
    matrix = build_context_matrix(df=df, column=column, api_key=api_key, model=model)
    if matrix is None:
        # Defensive: if the builder came back empty-handed (e.g. on the run that
        # first writes its cache), the cache should now exist, so ask once more
        # and fail loudly if there is still nothing to score against.
        matrix = build_context_matrix(df=df, column=column, api_key=api_key, model=model)
        if matrix is None:
            raise RuntimeError("Context matrix could not be built or loaded.")
    query_vector = text2embeddings(document=query, api_key=api_key, model=model)
    # Raw dot product; embeddings are not normalised here, so this is not cosine
    # similarity unless the backend already returns unit vectors.
    scores = matrix @ query_vector
    top_k = np.argsort(scores)[::-1][:chunk_size]
    # astype(str) keeps str.join from failing on non-string cells (numbers, NaN).
    context = dataframe[column].iloc[top_k].astype(str)
    return "\n".join(context)


def select_llm(model: str, api_key: Optional[str] = None, temperature: float = 0.5) -> Union[Ollama, GoogleGenerativeAI]:
    """Instantiate the chat LLM for the chosen backend.

    Args:
        model: ``"Ollama"`` (local ``phi3``) or ``"Gemini"`` (``gemini-pro``).
        api_key: Google API key, passed only to the Gemini client; unused for Ollama.
        temperature: Sampling temperature given to either client.

    Returns:
        An ``Ollama`` or ``GoogleGenerativeAI`` instance.

    Raises:
        ValueError: if ``model`` is not ``"Gemini"`` or ``"Ollama"``.
    """
    if model == "Ollama":
        return Ollama(model="phi3", temperature=temperature)
    if model == "Gemini":
        return GoogleGenerativeAI(model="gemini-pro", temperature=temperature, google_api_key=api_key)
    raise ValueError("Invalid model name. Please choose 'Gemini' or 'Ollama'.")


def get_prompt(context: str, query: str) -> str:
    """Assemble the instruction prompt handed to the LLM.

    The template introduces the "Mira" persona and lists three response rules.
    It then embeds ``context`` and ``query``, each wrapped in triple single
    quotes, with a dashed rule line before and after each section. Every line
    after the first starts with four spaces, and the text ends with a newline
    followed by four spaces; that exact layout is preserved.
    """
    rule = "-----------------------------------------------"
    pieces = [
        "You are a chatbot named Mira. You answer user's queries about e-commerce products.",
        "Here are some rules you will follow:",
        "1. Your response will be concise and informative.",
        "2. Do not generate any incomplete sentences or duplicate sentences.",
        "3. If you give a list of products, do not list more than 5 products at once.",
        rule,
        f"Context: '''{context}'''",
        rule,
        f"Question: '''{query}'''",
        rule,
        "",  # empty tail reproduces the closing newline + four-space indent
    ]
    return "\n    ".join(pieces)

def retrieve_query(query: str, df: str, column: str, api_key: Optional[str], model: str,
                   temperature: float = 0.5, chunk_size: int = 100) -> Iterator[str]:
    """Answer ``query`` using the most relevant CSV rows as context, streaming the reply.

    Args:
        query: The user's question.
        df: Path to the CSV file holding the knowledge base.
        column: Text column used for retrieval and as prompt context.
        api_key: Google API key; required for ``"Gemini"``, may be None for ``"Ollama"``.
        model: ``"Gemini"`` or ``"Ollama"``.
        temperature: Sampling temperature forwarded to the LLM.
        chunk_size: Number of top-scoring rows placed in the prompt.

    Returns:
        The object returned by ``llm.stream(prompt)``. It is an iterator of
        response chunks (``str`` for LangChain LLMs) for the caller to consume.
    """
    # Only Gemini needs credentials. model_selection() returns None as the key for
    # Ollama, so an unconditional check would make the Ollama option unusable.
    if model == "Gemini" and not api_key:
        st.write("Please enter your key")
        # NOTE(review): st.stop() is Streamlit's documented way to end a script
        # run; sys.exit is kept here to preserve the existing behavior.
        sys.exit(0)
    context = semantic_chunk(query=query, df=df, column=column, model=model, api_key=api_key, chunk_size=chunk_size)
    llm = select_llm(model=model, api_key=api_key, temperature=temperature)
    prompt = get_prompt(context, query)
    return llm.stream(prompt)