SayedShaun commited on
Commit
0c3cd21
·
verified ·
1 Parent(s): fa3b041

Upload modules.py

Browse files
Files changed (1) hide show
  1. modules.py +108 -0
modules.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from typing import Any, List, Optional, Tuple, Union
3
+ import numpy as np
4
+ import pandas as pd
5
+ import pickle
6
+ import glob
7
+ from pydantic import SecretStr
8
+ from langchain_community.llms.ollama import Ollama
9
+ from langchain_community.embeddings import OllamaEmbeddings
10
+ from langchain_google_genai.llms import GoogleGenerativeAI
11
+ from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
12
+ import streamlit as st
13
+
14
+
15
+ # Streamlit functions to select the model
16
+ def model_selection()->Union[str, Tuple[str, SecretStr]]:
17
+ select = st.sidebar.selectbox("Select model", options=["Ollama", "Gemini"])
18
+ if select == "Ollama":
19
+ return select, None
20
+ elif select == "Gemini":
21
+ api_key = st.sidebar.text_input("Enter API key", type="password")
22
+ return select, api_key
23
+ else:
24
+ raise ValueError("Invalid model name. Please choose 'Gemini' or 'Ollama'.")
25
+
26
+
27
+ # This function will be used to time the functions
28
+ def timer(func)->callable:
29
+ def wrapper(*args, **kwargs):
30
+ start = time.perf_counter()
31
+ result = func(*args, **kwargs)
32
+ end = time.perf_counter()
33
+ print(f"{func.__name__} took {end - start} seconds")
34
+ return result
35
+ return wrapper
36
+
37
+ @timer
38
+ def text2embeddings(document: Union[str, List[str]], api_key:str, model:str) -> np.ndarray:
39
+ google_embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001", google_api_key=api_key)
40
+ ollama_embeddings = OllamaEmbeddings(model="phi3")
41
+
42
+ if model == "Gemini":
43
+ if isinstance(document, str):
44
+ return np.array(google_embeddings.embed_query(document))
45
+ return np.array(google_embeddings.embed_documents(document))
46
+ elif model == "Ollama":
47
+ if isinstance(document, str):
48
+ return np.array(ollama_embeddings.embed_query(document))
49
+ return np.array(ollama_embeddings.embed_documents(document))
50
+ else:
51
+ raise ValueError("Invalid model name. Please choose 'Gemini' or 'Ollama'.")
52
+
53
+
54
+ @timer
55
+ def build_context_matrix(df:str, column:str, api_key:str, model:str)->Any:
56
+ dataframe = pd.read_csv(df)
57
+ if glob.glob("Data/context_matrix.pkl"):
58
+ pickle_file = pickle.load(open("Data/context_matrix.pkl", "rb"))
59
+ return pickle_file
60
+ else:
61
+ matrix = text2embeddings(document=dataframe[column].tolist(), api_key=api_key, model=model)
62
+ pickle.dump(matrix, open("Data/context_matrix.pkl", "wb"))
63
+ print("Context matrix created, Run the code again")
64
+ return None
65
+
66
+
67
+ @timer
68
+ def semantic_chunk(query:str, df:str, column:str, api_key:str, model:str, chunk_size:int)->str:
69
+ dataframe = pd.read_csv(df)
70
+ matrix = build_context_matrix(df=df, column=column, api_key=api_key, model=model)
71
+ query_vector = text2embeddings(document=query, api_key=api_key, model=model)
72
+ score = matrix @ query_vector
73
+ top_k = np.argsort(score)[::-1][:chunk_size]
74
+ context = dataframe[column].iloc[top_k]
75
+ string = "\n".join(context)
76
+ return string
77
+
78
+
79
+ def select_llm(model:str, api_key:str=None, temperature: float=0.5)->Union[Ollama, GoogleGenerativeAI]:
80
+ if model == "Ollama":
81
+ return Ollama(model="phi3", temperature=temperature)
82
+ elif model == "Gemini":
83
+ return GoogleGenerativeAI(model="gemini-pro", temperature=temperature, google_api_key=api_key)
84
+ else:
85
+ raise ValueError("Invalid model name. Please choose 'Gemini' or 'Ollama'.")
86
+
87
+
88
+ def get_prompt(context:str, query:str)->str:
89
+ prompt = f"""You are a chatbot named Mira. You answer user's query about ecommerce products.
90
+ Here are some rules you will follow:
91
+ 1. Your response will be concise and informative.
92
+ 2. You must answer the user's query from the given context.
93
+ 3. Do not generate any incomplete sentences or duplicate sentences.
94
+
95
+ -----------------------------------------------
96
+ Context: '''{context}'''
97
+ -----------------------------------------------
98
+ Question: '''{query}'''
99
+ -----------------------------------------------
100
+ """
101
+ return prompt
102
+
103
+ def retrieve_query(query:str, df:str, column:str, api_key:str, model:str, temperature:float=0.5, chunk_size:int=100)->str:
104
+ context = semantic_chunk(query=query, df=df, column=column, model=model, api_key=api_key, chunk_size=chunk_size)
105
+ llm = select_llm(model=model, api_key=api_key, temperature=temperature)
106
+ prompt = get_prompt(context, query)
107
+ response = llm.stream(prompt)
108
+ return response