"""Helpers for a Streamlit e-commerce chatbot ("Mira").

Provides model selection widgets, embedding-based retrieval of context rows
from a CSV column, prompt construction, and streaming LLM answers using either
a local Ollama model or Google Gemini.
"""
import functools
import glob
import os
import pickle
import sys
import time
from typing import Any, Callable, Iterator, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import streamlit as st
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.llms.ollama import Ollama
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
from langchain_google_genai.llms import GoogleGenerativeAI
from pydantic import SecretStr

# On-disk cache for the embedding matrix of the context column.
# NOTE: the cache is keyed only on this path, not on the CSV, column or model;
# delete the file after changing any of those.
_CONTEXT_MATRIX_PATH = "Data/context_matrix.pkl"


def model_selection() -> Tuple[str, Optional[str]]:
    """Render sidebar widgets for choosing the backend model.

    Returns:
        ``(model_name, api_key)``. ``api_key`` is ``None`` for Ollama and the
        text entered in the password box (possibly ``""``) for Gemini.

    Raises:
        ValueError: if the select box yields an unexpected option.
    """
    select = st.sidebar.selectbox("Select model", options=["Ollama", "Gemini"])
    if select == "Ollama":
        # Ollama runs locally and needs no key.
        return select, None
    elif select == "Gemini":
        api_key = st.sidebar.text_input("Enter API key", type="password")
        return select, api_key
    else:
        raise ValueError("Invalid model name. Please choose 'Gemini' or 'Ollama'.")


def timer(func: Callable) -> Callable:
    """Decorator that prints the wall-clock duration of each call to ``func``."""

    @functools.wraps(func)  # keep __name__/__doc__ of the wrapped function
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        result = func(*args, **kwargs)
        end = time.perf_counter()
        print(f"{func.__name__} took {end - start} seconds")
        return result

    return wrapper


@timer
def text2embeddings(document: Union[str, List[str]], api_key: Optional[str], model: str) -> np.ndarray:
    """Embed a query string or a list of documents with the selected backend.

    Args:
        document: a single query string, or a list of document strings.
        api_key: Google API key; only used when ``model == "Gemini"``.
        model: ``"Gemini"`` or ``"Ollama"``.

    Returns:
        A 1-D array for a single string, or a 2-D array with one row per
        document for a list.

    Raises:
        ValueError: if ``model`` is not a supported backend.
    """
    # Build only the client that is actually needed: constructing the Google
    # client without a key fails, which used to break the Ollama path.
    if model == "Gemini":
        embedder = GoogleGenerativeAIEmbeddings(model="models/embedding-001", google_api_key=api_key)
    elif model == "Ollama":
        embedder = OllamaEmbeddings(model="phi3")
    else:
        raise ValueError("Invalid model name. Please choose 'Gemini' or 'Ollama'.")

    if isinstance(document, str):
        return np.array(embedder.embed_query(document))
    return np.array(embedder.embed_documents(document))


@timer
def build_context_matrix(df: str, column: str, api_key: Optional[str], model: str) -> np.ndarray:
    """Return the embedding matrix for ``column`` of the CSV at ``df``.

    The matrix is loaded from ``_CONTEXT_MATRIX_PATH`` if that file exists;
    otherwise it is computed, written there, and returned.

    Args:
        df: path to the CSV file.
        column: name of the text column to embed.
        api_key: Google API key (Gemini only).
        model: ``"Gemini"`` or ``"Ollama"``.

    Returns:
        A 2-D array with one embedding row per CSV row.
    """
    if os.path.isfile(_CONTEXT_MATRIX_PATH):
        # SECURITY: pickle.load can execute arbitrary code; only load a cache
        # file this application wrote itself.
        with open(_CONTEXT_MATRIX_PATH, "rb") as fh:
            return pickle.load(fh)

    dataframe = pd.read_csv(df)
    matrix = text2embeddings(document=dataframe[column].tolist(), api_key=api_key, model=model)

    cache_dir = os.path.dirname(_CONTEXT_MATRIX_PATH)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)
    with open(_CONTEXT_MATRIX_PATH, "wb") as fh:
        pickle.dump(matrix, fh)
    print("Context matrix created and cached")
    # Return the fresh matrix instead of None so the first run works too.
    return matrix


@timer
def semantic_chunk(query: str, df: str, column: str, api_key: Optional[str], model: str, chunk_size: int) -> str:
    """Return the ``chunk_size`` rows of ``column`` that score highest against ``query``.

    Scores are dot products between the query embedding and each row
    embedding. The selected rows are joined with newlines, highest score first.

    Raises:
        ValueError: if the cached matrix does not match the CSV row count or
            the query embedding size (e.g. it was built with another model).
    """
    dataframe = pd.read_csv(df)
    matrix = build_context_matrix(df=df, column=column, api_key=api_key, model=model)
    query_vector = text2embeddings(document=query, api_key=api_key, model=model)

    # Guard against a stale cache (different CSV or embedding model), which
    # would otherwise surface as an opaque matmul or iloc error.
    if (
        matrix.ndim != 2
        or matrix.shape[0] != len(dataframe)
        or matrix.shape[1] != query_vector.shape[0]
    ):
        raise ValueError(
            f"Context matrix of shape {matrix.shape} does not match {len(dataframe)} CSV rows "
            f"and query embedding of shape {query_vector.shape}; delete {_CONTEXT_MATRIX_PATH} and rerun."
        )

    score = matrix @ query_vector
    top_k = np.argsort(score)[::-1][:chunk_size]
    context = dataframe[column].iloc[top_k]
    return "\n".join(context.astype(str))


def select_llm(model: str, api_key: Optional[str] = None, temperature: float = 0.5) -> Union[Ollama, GoogleGenerativeAI]:
    """Instantiate the chat LLM for ``model`` ("Ollama" or "Gemini").

    Raises:
        ValueError: if ``model`` is not a supported backend.
    """
    if model == "Ollama":
        return Ollama(model="phi3", temperature=temperature)
    elif model == "Gemini":
        return GoogleGenerativeAI(model="gemini-pro", temperature=temperature, google_api_key=api_key)
    else:
        raise ValueError("Invalid model name. Please choose 'Gemini' or 'Ollama'.")


def get_prompt(context: str, query: str) -> str:
    """Build the system-style prompt embedding the retrieved ``context`` and user ``query``."""
    prompt = f"""You are a chatbot named Mira. You answer user's queries about e-commerce products.
Here are some rules you will follow:
1. Your response will be concise and informative.
2. Do not generate any incomplete sentences or duplicate sentences.
3. If you give a list of products, do not list more than 5 products at once.
-----------------------------------------------
Context: '''{context}'''
-----------------------------------------------
Question: '''{query}'''
-----------------------------------------------
"""
    return prompt


def retrieve_query(query: str, df: str, column: str, api_key: Optional[str], model: str,
                   temperature: float = 0.5, chunk_size: int = 100) -> Iterator[str]:
    """Answer ``query`` with the selected LLM, grounded on the top matching CSV rows.

    Args:
        query: the user's question.
        df: path to the CSV file with the product data.
        column: text column used as retrieval context.
        api_key: Google API key; required only for ``model == "Gemini"``.
        model: ``"Gemini"`` or ``"Ollama"``.
        temperature: sampling temperature passed to the LLM.
        chunk_size: number of context rows to include in the prompt.

    Returns:
        The stream produced by ``llm.stream(prompt)``.
    """
    # Only Gemini needs a key; model_selection passes None for Ollama, so the
    # old unconditional check made the Ollama path unusable.
    if model == "Gemini" and not api_key:
        st.write("Please enter your key")
        sys.exit(0)
    context = semantic_chunk(query=query, df=df, column=column, model=model, api_key=api_key, chunk_size=chunk_size)
    llm = select_llm(model=model, api_key=api_key, temperature=temperature)
    prompt = get_prompt(context, query)
    response = llm.stream(prompt)
    return response