e-chatbot / modules.py
SayedShaun's picture
Update modules.py
59bc417 verified
raw
history blame
5.23 kB
import glob
import os
import pickle
import sys
import time
from functools import wraps
from typing import Any, Callable, Iterator, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import streamlit as st
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.llms.ollama import Ollama
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
from langchain_google_genai.llms import GoogleGenerativeAI
from pydantic import SecretStr
# Streamlit sidebar widgets that let the user pick the backend model.
def model_selection() -> Tuple[str, Optional[str]]:
    """Render the model picker (and, for Gemini, an API-key box) in the sidebar.

    Returns:
        A ``(model_name, api_key)`` tuple. ``model_name`` is ``"Ollama"`` or
        ``"Gemini"``. ``api_key`` is ``None`` for Ollama; for Gemini it is the
        raw string typed into the password field (a plain ``str``, not a
        ``SecretStr``).

    Raises:
        ValueError: if the selectbox yields anything other than the two
            offered options (defensive only; the widget offers exactly two).
    """
    select = st.sidebar.selectbox("Select model", options=["Ollama", "Gemini"])
    if select == "Ollama":
        # The local Ollama model needs no credentials.
        return select, None
    if select == "Gemini":
        api_key = st.sidebar.text_input("Enter API key", type="password")
        return select, api_key
    raise ValueError("Invalid model name. Please choose 'Gemini' or 'Ollama'.")
# Decorator used to report how long the decorated functions take.
def timer(func: Callable[..., Any]) -> Callable[..., Any]:
    """Wrap ``func`` so that every call prints its wall-clock duration.

    ``functools.wraps`` copies ``__name__``, ``__doc__`` and friends onto the
    wrapper, so decorated functions keep their identity for introspection
    and for the timing message itself.

    Args:
        func: Any callable.

    Returns:
        A wrapper accepting the same arguments. It returns ``func``'s result
        unchanged and prints ``"<name> took <seconds> seconds"``, measured
        with ``time.perf_counter``. Nothing is printed if ``func`` raises.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        result = func(*args, **kwargs)
        end = time.perf_counter()
        print(f"{func.__name__} took {end - start} seconds")
        return result
    return wrapper
@timer
def text2embeddings(document: Union[str, List[str]], api_key: Optional[str], model: str) -> np.ndarray:
    """Embed a single query or a batch of documents with the selected backend.

    Only the client for the chosen backend is constructed. The original code
    always built the Gemini client too, which needs a Google API key and so
    could fail when the user picked Ollama (where ``api_key`` is ``None``).

    Args:
        document: A single string (embedded via ``embed_query``) or a list of
            strings (embedded via ``embed_documents``).
        api_key: Google API key; only used when ``model == "Gemini"``.
        model: ``"Gemini"`` or ``"Ollama"``.

    Returns:
        The embedding(s) as a NumPy array: one vector for a string, one row
        per document for a list. The vector width depends on the backend.

    Raises:
        ValueError: if ``model`` is not a supported name.
    """
    if model == "Gemini":
        embedder = GoogleGenerativeAIEmbeddings(model="models/embedding-001", google_api_key=api_key)
    elif model == "Ollama":
        embedder = OllamaEmbeddings(model="phi3")
    else:
        raise ValueError("Invalid model name. Please choose 'Gemini' or 'Ollama'.")

    if isinstance(document, str):
        return np.array(embedder.embed_query(document))
    return np.array(embedder.embed_documents(document))
@timer
def build_context_matrix(df: str, column: str, api_key: Optional[str], model: str) -> np.ndarray:
    """Return the embedding matrix for every row of ``column`` in a CSV file.

    The matrix is cached at ``Data/context_matrix.pkl``. If that file exists
    it is loaded and returned; otherwise the CSV is read, every row of
    ``column`` is embedded, the result is written to the cache, and the
    freshly computed matrix is returned. The original version returned
    ``None`` here, which crashed callers.

    NOTE(review): the cache key ignores ``df``, ``column`` and ``model``. A
    cache built with one backend is reused for the other, which will give a
    shape mismatch if their embedding widths differ. Delete the file when
    switching.

    Args:
        df: Path to the CSV file (despite the name, not a DataFrame).
        column: Name of the text column to embed.
        api_key: Google API key, used only for the Gemini backend.
        model: ``"Gemini"`` or ``"Ollama"``.

    Returns:
        The embedding matrix, one row per CSV row.
    """
    cache_path = "Data/context_matrix.pkl"
    if os.path.exists(cache_path):
        # pickle.load can execute arbitrary code: only trust a cache file
        # this application wrote itself.
        with open(cache_path, "rb") as fh:
            return pickle.load(fh)

    # Only pay for reading the CSV when the cache has to be built.
    dataframe = pd.read_csv(df)
    matrix = text2embeddings(document=dataframe[column].tolist(), api_key=api_key, model=model)
    os.makedirs(os.path.dirname(cache_path), exist_ok=True)
    with open(cache_path, "wb") as fh:
        pickle.dump(matrix, fh)
    print("Context matrix created")
    return matrix
@timer
def semantic_chunk(query: str, df: str, column: str, api_key: Optional[str], model: str, chunk_size: int) -> str:
    """Return the ``chunk_size`` rows of ``column`` that score highest for ``query``.

    The score is the raw dot product between each row's embedding and the
    query embedding. No normalisation is done here, so this equals cosine
    similarity only if the backend returns unit vectors (not verified).

    Args:
        query: The user's question.
        df: Path to the CSV file holding the product texts.
        column: Name of the text column to retrieve from.
        api_key: Google API key, used only for the Gemini backend.
        model: ``"Gemini"`` or ``"Ollama"``.
        chunk_size: Maximum number of rows to return.

    Returns:
        The selected rows, highest score first, joined with newlines.
    """
    texts = pd.read_csv(df)[column]
    matrix = build_context_matrix(df=df, column=column, api_key=api_key, model=model)
    if matrix is None:
        # Guard against build_context_matrix returning None instead of a
        # matrix: embed in memory rather than failing on ``None @ vector``.
        matrix = text2embeddings(document=texts.tolist(), api_key=api_key, model=model)
    query_vector = text2embeddings(document=query, api_key=api_key, model=model)
    scores = matrix @ query_vector
    # Indices of the best scores, descending, truncated to chunk_size.
    top_k = np.argsort(scores)[::-1][:chunk_size]
    # astype(str): str.join rejects non-string cells (e.g. numbers).
    return "\n".join(texts.iloc[top_k].astype(str))
def select_llm(model: str, api_key: Optional[str] = None, temperature: float = 0.5) -> Union[Ollama, GoogleGenerativeAI]:
    """Instantiate the LLM for the chosen backend.

    Args:
        model: ``"Ollama"`` (local ``phi3``) or ``"Gemini"`` (``gemini-pro``).
        api_key: Google API key; only passed to the Gemini client.
        temperature: Sampling temperature forwarded to the model.

    Returns:
        An ``Ollama`` or ``GoogleGenerativeAI`` instance.

    Raises:
        ValueError: if ``model`` is not a supported name.
    """
    if model == "Ollama":
        return Ollama(model="phi3", temperature=temperature)
    if model == "Gemini":
        return GoogleGenerativeAI(model="gemini-pro", temperature=temperature, google_api_key=api_key)
    raise ValueError("Invalid model name. Please choose 'Gemini' or 'Ollama'.")
def get_prompt(context: str, query: str) -> str:
    """Assemble the prompt sent to the LLM.

    The prompt is a fixed instruction paragraph (the "Mira" persona, told to
    answer only from the supplied context), followed by one line with the
    retrieved context and one line with the user's question. Each value is
    wrapped in backticks.

    NOTE(review): ``context`` and ``query`` are inserted verbatim; backticks
    inside them are not escaped.

    Args:
        context: Retrieved product text.
        query: The user's question.

    Returns:
        The complete prompt string.
    """
    persona = (
        "As Mira, an expert e-commerce chatbot, your role is to provide users "
        "with precise and informative responses to their queries about "
        "products. You are required to answer questions based solely on the "
        "provided context, ensuring that your responses are both concise and "
        "relevant. It's essential to communicate clearly, avoiding any "
        "incomplete or repetitive sentences. Your goal is to deliver a "
        "seamless and insightful user experience by offering well-structured "
        "answers that directly address the user's needs."
    )
    sections = [persona, f"Context: `{context}`", f"Question: `{query}`"]
    return "\n".join(sections)
def retrieve_query(query: str, df: str, column: str, api_key: Optional[str], model: str, temperature: float = 0.5, chunk_size: int = 100) -> Iterator[str]:
    """Answer ``query`` by retrieving context from the CSV and streaming an LLM reply.

    Only the Gemini backend needs an API key. ``model_selection`` returns
    ``None`` as the key for Ollama, so the original "key required" check
    made the Ollama path unreachable. If Gemini is selected without a key,
    the user is prompted and the Streamlit run is halted with ``st.stop()``,
    Streamlit's supported way to end a script run (``sys.exit`` was used
    before).

    Args:
        query: The user's question.
        df: Path to the CSV file holding the product texts.
        column: Name of the text column to retrieve from.
        api_key: Google API key (required for Gemini, ignored for Ollama).
        model: ``"Gemini"`` or ``"Ollama"``.
        temperature: Sampling temperature for the LLM.
        chunk_size: Number of context rows to include in the prompt.

    Returns:
        The iterator produced by ``llm.stream(prompt)``, yielding the
        response incrementally.
    """
    if model == "Gemini" and not api_key:
        st.write("Please enter your key")
        st.stop()
    context = semantic_chunk(query=query, df=df, column=column, model=model, api_key=api_key, chunk_size=chunk_size)
    llm = select_llm(model=model, api_key=api_key, temperature=temperature)
    prompt = get_prompt(context, query)
    return llm.stream(prompt)