# e-chatbot / modules.py
# Author: SayedShaun
# Last commit: a66bf0e (verified) - "Update modules.py"
import functools
import glob
import os
import pickle
import sys
import time
from typing import Any, Callable, Iterator, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import streamlit as st
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.llms.ollama import Ollama
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
from langchain_google_genai.llms import GoogleGenerativeAI
from pydantic import SecretStr
# Streamlit sidebar widgets used to choose the model backend.
def model_selection() -> Tuple[str, Optional[str]]:
    """Render the sidebar model picker and, for Gemini, an API-key field.

    Returns:
        A ``(model_name, api_key)`` pair. ``model_name`` is ``"Ollama"`` or
        ``"Gemini"``. ``api_key`` is ``None`` for Ollama; for Gemini it is the
        value of the password text input (may be empty if nothing was entered,
        so callers should check it).

    Raises:
        ValueError: if the selectbox yields something other than the two
            listed options (defensive; should not happen in practice).
    """
    select = st.sidebar.selectbox("Select model", options=["Ollama", "Gemini"])
    if select == "Ollama":
        # The local Ollama backend needs no credentials.
        return select, None
    elif select == "Gemini":
        api_key = st.sidebar.text_input("Enter API key", type="password")
        return select, api_key
    else:
        raise ValueError("Invalid model name. Please choose 'Gemini' or 'Ollama'.")
# Decorator that reports how long each call of the decorated function takes.
def timer(func: Callable[..., Any]) -> Callable[..., Any]:
    """Wrap *func* so every call prints its elapsed wall-clock time.

    Timing uses :func:`time.perf_counter`; the message is printed to stdout
    after the call returns, and the wrapped function's return value is passed
    through unchanged. ``functools.wraps`` preserves ``__name__``, ``__doc__``
    and friends so decorated functions stay introspectable.
    """
    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        start = time.perf_counter()
        result = func(*args, **kwargs)
        end = time.perf_counter()
        print(f"{func.__name__} took {end - start} seconds")
        return result
    return wrapper
@timer
def text2embeddings(document: Union[str, List[str]], api_key: Optional[str], model: str) -> np.ndarray:
    """Embed *document* with the selected backend and return a NumPy array.

    Args:
        document: a single string (embedded via ``embed_query``) or a list of
            strings (embedded via ``embed_documents``).
        api_key: Google API key; only used when ``model == "Gemini"``.
        model: ``"Gemini"`` (``models/embedding-001``) or ``"Ollama"`` (``phi3``).

    Returns:
        ``np.array`` of the embedding(s). This is presumably 1-D for a single
        string and 2-D (one row per document) for a list; the width depends on
        the backend.

    Raises:
        ValueError: if *model* is not one of the supported names.
    """
    # Build only the client that is actually needed. The original built both
    # on every call; the Google client is wasted work on the Ollama path and
    # may reject a missing key.
    if model == "Gemini":
        embedder = GoogleGenerativeAIEmbeddings(model="models/embedding-001", google_api_key=api_key)
    elif model == "Ollama":
        embedder = OllamaEmbeddings(model="phi3")
    else:
        raise ValueError("Invalid model name. Please choose 'Gemini' or 'Ollama'.")

    if isinstance(document, str):
        return np.array(embedder.embed_query(document))
    return np.array(embedder.embed_documents(document))
@timer
def build_context_matrix(
    df: str,
    column: str,
    api_key: Optional[str],
    model: str,
    cache_path: str = "Data/context_matrix.pkl",
) -> Any:
    """Return the embedding matrix for ``column`` of the CSV at *df*, cached on disk.

    If *cache_path* exists, the pickled matrix stored there is loaded and
    returned. Otherwise the column is embedded with ``text2embeddings``, the
    result is pickled to *cache_path* (creating its directory if needed), and
    the freshly built matrix is returned.

    Args:
        df: path of the CSV file, read with ``pd.read_csv``.
        column: name of the text column to embed.
        api_key: Google API key, forwarded to ``text2embeddings``.
        model: ``"Gemini"`` or ``"Ollama"``, forwarded to ``text2embeddings``.
        cache_path: location of the pickle cache; defaults to the original
            hard-coded path.

    Returns:
        The embedding matrix, on both a cache hit and a fresh build.
    """
    if os.path.exists(cache_path):
        # NOTE(review): the cache is keyed only by path, not by df/column/model.
        # A matrix built with one backend is reused after switching to the
        # other, whose embeddings presumably have a different width. Delete the
        # file or pass another cache_path when changing the model or the data.
        # pickle.load can execute arbitrary code: only load files this app wrote.
        with open(cache_path, "rb") as fh:
            return pickle.load(fh)

    # The CSV is only needed when the embeddings have to be (re)built.
    dataframe = pd.read_csv(df)
    matrix = text2embeddings(document=dataframe[column].tolist(), api_key=api_key, model=model)

    cache_dir = os.path.dirname(cache_path)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)
    with open(cache_path, "wb") as fh:
        pickle.dump(matrix, fh)
    print("Context matrix created")
    # Return the matrix instead of None so callers can use it on the very
    # first run rather than crashing on a None result.
    return matrix
@timer
def semantic_chunk(query: str, df: str, column: str, api_key: Optional[str], model: str, chunk_size: int) -> str:
    """Return the ``chunk_size`` rows of ``column`` that best match *query*.

    Each row's embedding (from ``build_context_matrix``) is scored against the
    query embedding (from ``text2embeddings``). Rows are taken in descending
    score order, converted to ``str`` and joined with newlines.

    Args:
        query: the user's question.
        df: path of the CSV file holding the corpus.
        column: name of the text column that was embedded.
        api_key: Google API key, forwarded to the embedding helpers.
        model: ``"Gemini"`` or ``"Ollama"``.
        chunk_size: number of rows to return; values below 0 are treated as 0.

    Returns:
        The selected rows as a single newline-separated string.
    """
    matrix = build_context_matrix(df=df, column=column, api_key=api_key, model=model)
    if matrix is None:
        # A cache builder that only writes the pickle and returns None would
        # otherwise crash the matrix product below; the second call loads the
        # freshly written cache.
        matrix = build_context_matrix(df=df, column=column, api_key=api_key, model=model)
    query_vector = text2embeddings(document=query, api_key=api_key, model=model)

    # NOTE(review): raw dot product, not cosine similarity. The two rank the
    # same only if the backend returns unit-length vectors; confirm per model.
    scores = np.asarray(matrix) @ query_vector
    top_k = np.argsort(scores)[::-1][:max(chunk_size, 0)]

    dataframe = pd.read_csv(df)
    context = dataframe[column].iloc[top_k]
    # astype(str) so non-string cells (e.g. NaN) don't break str.join.
    return "\n".join(context.astype(str))
def select_llm(
    model: str,
    api_key: Optional[str] = None,
    temperature: float = 0.5,
) -> Union[Ollama, GoogleGenerativeAI]:
    """Instantiate the LLM client for *model*.

    Args:
        model: ``"Ollama"`` (local ``phi3``) or ``"Gemini"`` (``gemini-pro``).
        api_key: Google API key; only passed to the Gemini client.
        temperature: forwarded unchanged to the chosen client.

    Returns:
        An ``Ollama`` or ``GoogleGenerativeAI`` instance.

    Raises:
        ValueError: if *model* is not one of the supported names.
    """
    if model == "Ollama":
        return Ollama(model="phi3", temperature=temperature)
    if model == "Gemini":
        return GoogleGenerativeAI(model="gemini-pro", temperature=temperature, google_api_key=api_key)
    raise ValueError("Invalid model name. Please choose 'Gemini' or 'Ollama'.")
def get_prompt(context:str, query:str)->str:
    """Build the instruction prompt for the "Mira" shopping assistant.

    The prompt is a fixed header of behavioural rules, followed by the
    retrieved *context* and the user's *query*, each wrapped in ``'''`` and
    separated by dashed rules. Every line, including the last, ends with a
    newline. *context* and *query* are inserted verbatim (no escaping).
    """
    prompt_lines = [
        "You are a chatbot named Mira. You answer user's queries about e-commerce products.",
        "Here are some rules you will follow:",
        "1. Your response will be concise and informative.",
        "2. Do not generate any incomplete sentences or duplicate sentences.",
        "3. If you give a list of products, do not list more than 5 products at once.",
        "-----------------------------------------------",
        "Context: '''" + context + "'''",
        "-----------------------------------------------",
        "Question: '''" + query + "'''",
        "-----------------------------------------------",
    ]
    return "".join(line + "\n" for line in prompt_lines)
def retrieve_query(
    query: str,
    df: str,
    column: str,
    api_key: Optional[str],
    model: str,
    temperature: float = 0.5,
    chunk_size: int = 100,
) -> Iterator[str]:
    """Answer *query* with retrieval-augmented generation and stream the reply.

    The function selects the ``chunk_size`` most relevant rows of ``column``
    with ``semantic_chunk``, builds the prompt with ``get_prompt``, and
    returns ``llm.stream(prompt)`` from the model chosen by ``select_llm``.

    Returns:
        Whatever ``llm.stream`` returns: an iterator of response chunks
        (presumably strings), not a single ``str``.

    Side effects:
        If Gemini is selected without an API key, it writes a message to the
        Streamlit page and exits via ``sys.exit(0)``.
    """
    # Only Gemini needs a key. model_selection() returns api_key=None for
    # Ollama, so an unconditional `if not api_key` would make that backend
    # unusable.
    if model == "Gemini" and not api_key:
        st.write("Please enter your key")
        # NOTE(review): st.stop() is Streamlit's documented way to halt a run;
        # consider it instead of sys.exit.
        sys.exit(0)

    context = semantic_chunk(query=query, df=df, column=column, model=model, api_key=api_key, chunk_size=chunk_size)
    llm = select_llm(model=model, api_key=api_key, temperature=temperature)
    prompt = get_prompt(context, query)
    return llm.stream(prompt)