SayedShaun committed on
Commit
fe0fe03
·
verified ·
1 Parent(s): 042f944

Update modules.py

Browse files
Files changed (1) hide show
  1. modules.py +109 -107
modules.py CHANGED
@@ -1,108 +1,110 @@
1
- import time
2
- from typing import Any, List, Optional, Tuple, Union
3
- import numpy as np
4
- import pandas as pd
5
- import pickle
6
- import glob
7
- from pydantic import SecretStr
8
- from langchain_community.llms.ollama import Ollama
9
- from langchain_community.embeddings import OllamaEmbeddings
10
- from langchain_google_genai.llms import GoogleGenerativeAI
11
- from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
12
- import streamlit as st
13
-
14
-
15
- # Streamlit functions to select the model
16
- def model_selection()->Union[str, Tuple[str, SecretStr]]:
17
- select = st.sidebar.selectbox("Select model", options=["Ollama", "Gemini"])
18
- if select == "Ollama":
19
- return select, None
20
- elif select == "Gemini":
21
- api_key = st.sidebar.text_input("Enter API key", type="password")
22
- return select, api_key
23
- else:
24
- raise ValueError("Invalid model name. Please choose 'Gemini' or 'Ollama'.")
25
-
26
-
27
- # This function will be used to time the functions
28
- def timer(func)->callable:
29
- def wrapper(*args, **kwargs):
30
- start = time.perf_counter()
31
- result = func(*args, **kwargs)
32
- end = time.perf_counter()
33
- print(f"{func.__name__} took {end - start} seconds")
34
- return result
35
- return wrapper
36
-
37
- @timer
38
- def text2embeddings(document: Union[str, List[str]], api_key:str, model:str) -> np.ndarray:
39
- google_embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001", google_api_key=api_key)
40
- ollama_embeddings = OllamaEmbeddings(model="phi3")
41
-
42
- if model == "Gemini":
43
- if isinstance(document, str):
44
- return np.array(google_embeddings.embed_query(document))
45
- return np.array(google_embeddings.embed_documents(document))
46
- elif model == "Ollama":
47
- if isinstance(document, str):
48
- return np.array(ollama_embeddings.embed_query(document))
49
- return np.array(ollama_embeddings.embed_documents(document))
50
- else:
51
- raise ValueError("Invalid model name. Please choose 'Gemini' or 'Ollama'.")
52
-
53
-
54
- @timer
55
- def build_context_matrix(df:str, column:str, api_key:str, model:str)->Any:
56
- dataframe = pd.read_csv(df)
57
- if glob.glob("Data/context_matrix.pkl"):
58
- pickle_file = pickle.load(open("Data/context_matrix.pkl", "rb"))
59
- return pickle_file
60
- else:
61
- matrix = text2embeddings(document=dataframe[column].tolist(), api_key=api_key, model=model)
62
- pickle.dump(matrix, open("Data/context_matrix.pkl", "wb"))
63
- print("Context matrix created, Run the code again")
64
- return None
65
-
66
-
67
- @timer
68
- def semantic_chunk(query:str, df:str, column:str, api_key:str, model:str, chunk_size:int)->str:
69
- dataframe = pd.read_csv(df)
70
- matrix = build_context_matrix(df=df, column=column, api_key=api_key, model=model)
71
- query_vector = text2embeddings(document=query, api_key=api_key, model=model)
72
- score = matrix @ query_vector
73
- top_k = np.argsort(score)[::-1][:chunk_size]
74
- context = dataframe[column].iloc[top_k]
75
- string = "\n".join(context)
76
- return string
77
-
78
-
79
- def select_llm(model:str, api_key:str=None, temperature: float=0.5)->Union[Ollama, GoogleGenerativeAI]:
80
- if model == "Ollama":
81
- return Ollama(model="phi3", temperature=temperature)
82
- elif model == "Gemini":
83
- return GoogleGenerativeAI(model="gemini-pro", temperature=temperature, google_api_key=api_key)
84
- else:
85
- raise ValueError("Invalid model name. Please choose 'Gemini' or 'Ollama'.")
86
-
87
-
88
- def get_prompt(context:str, query:str)->str:
89
- prompt = f"""You are a chatbot named Mira. You answer user's query about ecommerce products.
90
- Here are some rules you will follow:
91
- 1. Your response will be concise and informative.
92
- 2. You must answer the user's query from the given context.
93
- 3. Do not generate any incomplete sentences or duplicate sentences.
94
-
95
- -----------------------------------------------
96
- Context: '''{context}'''
97
- -----------------------------------------------
98
- Question: '''{query}'''
99
- -----------------------------------------------
100
- """
101
- return prompt
102
-
103
- def retrieve_query(query:str, df:str, column:str, api_key:str, model:str, temperature:float=0.5, chunk_size:int=100)->str:
104
- context = semantic_chunk(query=query, df=df, column=column, model=model, api_key=api_key, chunk_size=chunk_size)
105
- llm = select_llm(model=model, api_key=api_key, temperature=temperature)
106
- prompt = get_prompt(context, query)
107
- response = llm.stream(prompt)
 
 
108
  return response
 
1
+ import time
2
+ from typing import Any, List, Optional, Tuple, Union
3
+ import numpy as np
4
+ import pandas as pd
5
+ import pickle
6
+ import glob
7
+ from pydantic import SecretStr
8
+ from langchain_community.llms.ollama import Ollama
9
+ from langchain_community.embeddings import OllamaEmbeddings
10
+ from langchain_google_genai.llms import GoogleGenerativeAI
11
+ from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
12
+ import streamlit as st
13
+
14
+
15
# Streamlit sidebar widgets used to choose the LLM backend
def model_selection() -> Tuple[str, Optional[str]]:
    """Let the user pick a backend (and, for Gemini, an API key) in the sidebar.

    Returns:
        A ``(model_name, api_key)`` pair. ``api_key`` is ``None`` when
        "Ollama" is selected; when "Gemini" is selected it is the value of
        the sidebar password field.

    Raises:
        ValueError: If the select box yields anything other than the two
            offered options.
    """
    select = st.sidebar.selectbox("Select model", options=["Ollama", "Gemini"])
    if select == "Ollama":
        # No key is asked for on the Ollama path.
        return select, None
    elif select == "Gemini":
        api_key = st.sidebar.text_input("Enter API key", type="password")
        return select, api_key
    else:
        raise ValueError("Invalid model name. Please choose 'Gemini' or 'Ollama'.")
25
+
26
+
27
# Decorator that reports how long the wrapped function takes to run
def timer(func)->callable:
    """Wrap ``func`` so that every call prints its elapsed time.

    The elapsed time is measured with ``time.perf_counter`` and printed as
    ``"<name> took <seconds> seconds"`` after ``func`` returns. If ``func``
    raises, the exception propagates and nothing is printed.

    Args:
        func: The callable to time.

    Returns:
        A wrapper that forwards all arguments to ``func`` and returns its
        result unchanged.
    """
    # Imported locally because functools is not among the module-level imports.
    import functools

    # wraps() keeps __name__/__doc__ of the original, so stacked decorators,
    # logs and introspection still see the real function instead of "wrapper".
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        result = func(*args, **kwargs)
        end = time.perf_counter()
        print(f"{func.__name__} took {end - start} seconds")
        return result
    return wrapper
36
+
37
@timer
def text2embeddings(document: Union[str, List[str]], api_key:str, model:str) -> np.ndarray:
    """Embed a single string or a list of strings with the selected backend.

    Args:
        document: One text (embedded with ``embed_query``) or a list of
            texts (embedded with ``embed_documents``).
        api_key: Google API key; only used when ``model`` is "Gemini".
        model: Either "Gemini" or "Ollama".

    Returns:
        The backend's embedding output wrapped in ``np.array``.
        Presumably a 1-D vector for a string and a 2-D matrix (one row per
        text) for a list -- confirm against the embedding classes.

    Raises:
        ValueError: If ``model`` is neither "Gemini" nor "Ollama".
    """
    # Build only the client that is actually needed: constructing the Google
    # client on the Ollama path would hand it api_key=None for no reason.
    if model == "Gemini":
        embedder = GoogleGenerativeAIEmbeddings(model="models/embedding-001", google_api_key=api_key)
    elif model == "Ollama":
        embedder = OllamaEmbeddings(model="phi3")
    else:
        raise ValueError("Invalid model name. Please choose 'Gemini' or 'Ollama'.")

    if isinstance(document, str):
        return np.array(embedder.embed_query(document))
    return np.array(embedder.embed_documents(document))
52
+
53
+
54
@timer
def build_context_matrix(df:str, column:str, api_key:str, model:str)->Any:
    """Return the embedding matrix for one CSV column, using a pickle cache.

    If ``Data/context_matrix.pkl`` exists it is loaded and returned as-is.
    Otherwise the CSV at ``df`` is read, ``column`` is embedded with
    ``text2embeddings``, the result is written to the cache file and returned.

    Args:
        df: Path of the CSV file to read.
        column: Name of the text column to embed.
        api_key: Forwarded to ``text2embeddings``.
        model: Forwarded to ``text2embeddings``.

    Returns:
        The cached or freshly computed embedding matrix.
    """
    # NOTE(review): the cache path is fixed, so it is not tied to df, column
    # or model -- a stale file is reused if any of them changes.
    cache_path = "Data/context_matrix.pkl"

    if glob.glob(cache_path):
        # NOTE(security): pickle.load can execute code embedded in the file;
        # this is only safe while the cache is written by this application.
        with open(cache_path, "rb") as cache_file:
            return pickle.load(cache_file)

    # Only read the CSV when the embeddings really have to be computed.
    dataframe = pd.read_csv(df)
    matrix = text2embeddings(document=dataframe[column].tolist(), api_key=api_key, model=model)
    with open(cache_path, "wb") as cache_file:
        pickle.dump(matrix, cache_file)
    # The matrix is returned directly, so a second run is no longer needed.
    print("Context matrix created")
    return matrix
65
+
66
+
67
@timer
def semantic_chunk(query:str, df:str, column:str, api_key:str, model:str, chunk_size:int)->str:
    """Return the ``chunk_size`` rows of ``column`` that score highest for ``query``.

    Each row's score is the product of the context matrix with the query
    embedding (``matrix @ query_vector``); rows are taken in descending
    score order and joined with newlines.

    Args:
        query: The user's question.
        df: Path of the CSV file holding the texts.
        column: Name of the text column.
        api_key: Forwarded to the embedding helpers.
        model: Forwarded to the embedding helpers.
        chunk_size: Maximum number of rows to include.

    Returns:
        The selected rows joined by ``"\\n"``.

    Raises:
        RuntimeError: If no context matrix could be obtained.
    """
    dataframe = pd.read_csv(df)
    matrix = build_context_matrix(df=df, column=column, api_key=api_key, model=model)
    if matrix is None:
        # The builder may return None on the call that first writes the
        # cache file; a second call then loads the file it just wrote.
        matrix = build_context_matrix(df=df, column=column, api_key=api_key, model=model)
    if matrix is None:
        raise RuntimeError("Context matrix could not be built or loaded.")

    query_vector = text2embeddings(document=query, api_key=api_key, model=model)
    score = matrix @ query_vector
    # argsort is ascending, so reverse it before taking the best chunk_size rows.
    top_k = np.argsort(score)[::-1][:chunk_size]
    context = dataframe[column].iloc[top_k]
    return "\n".join(context)
77
+
78
+
79
def select_llm(model:str, api_key:Optional[str]=None, temperature: float=0.5)->Union[Ollama, GoogleGenerativeAI]:
    """Create the LLM client for the chosen backend.

    Args:
        model: Either "Ollama" or "Gemini".
        api_key: Google API key; passed to the Gemini client only and
            ignored for Ollama.
        temperature: Sampling temperature passed to the client.

    Returns:
        ``Ollama(model="phi3")`` or ``GoogleGenerativeAI(model="gemini-pro")``
        configured with ``temperature``.

    Raises:
        ValueError: If ``model`` is neither "Ollama" nor "Gemini".
    """
    if model == "Ollama":
        return Ollama(model="phi3", temperature=temperature)
    elif model == "Gemini":
        return GoogleGenerativeAI(model="gemini-pro", temperature=temperature, google_api_key=api_key)
    else:
        raise ValueError("Invalid model name. Please choose 'Gemini' or 'Ollama'.")
86
+
87
+
88
def get_prompt(context:str, query:str)->str:
    """Fill the "Mira" chatbot prompt template with a context and a question.

    Args:
        context: Retrieved text the model must answer from.
        query: The user's question.

    Returns:
        The complete prompt: the persona and rules, followed by the context
        and the question, each wrapped in triple single quotes and separated
        by dashed rules.
    """
    # The template lines start at column 0 on purpose: any indentation here
    # would become part of the prompt text.
    return f"""You are a chatbot named Mira. You answer user's query about ecommerce products.
Here are some rules you will follow:
1. Your response will be concise and informative.
2. You must answer the user's query from the given context.
3. Do not generate any incomplete sentences or duplicate sentences.

-----------------------------------------------
Context: '''{context}'''
-----------------------------------------------
Question: '''{query}'''
-----------------------------------------------
"""
102
+
103
def retrieve_query(query:str, df:str, column:str, api_key:str, model:str, temperature:float=0.5, chunk_size:int=100)->str:
    """Answer ``query`` from the CSV contents and stream the model's reply.

    The best-matching rows are selected with ``semantic_chunk``, placed into
    the prompt from ``get_prompt`` and sent to the client from ``select_llm``.

    Args:
        query: The user's question.
        df: Path of the CSV file holding the texts.
        column: Name of the text column.
        api_key: Google API key; required only when ``model`` is "Gemini".
        model: Either "Ollama" or "Gemini".
        temperature: Sampling temperature for the LLM.
        chunk_size: Number of rows to include as context.

    Returns:
        The result of ``llm.stream(prompt)``. NOTE(review): despite the
        ``str`` annotation this is the stream object, not a string; confirm
        how callers consume it. When a Gemini key is missing, an empty
        iterator is returned after the user has been prompted.
    """
    # Only Gemini needs a key (the Ollama path always has api_key=None), and
    # without one there is no point in calling the embedding/LLM services.
    if model == "Gemini" and not api_key:
        st.write("Please enter your key")
        return iter(())

    context = semantic_chunk(query=query, df=df, column=column, model=model, api_key=api_key, chunk_size=chunk_size)
    llm = select_llm(model=model, api_key=api_key, temperature=temperature)
    prompt = get_prompt(context, query)
    return llm.stream(prompt)