import glob
import pickle
import sys
import time
from typing import Any, Callable, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import streamlit as st
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.llms.ollama import Ollama
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
from langchain_google_genai.llms import GoogleGenerativeAI
# Streamlit sidebar widget to pick the backing model and, for Gemini, collect an API key
def model_selection() -> Tuple[str, Optional[str]]:
    select = st.sidebar.selectbox("Select model", options=["Ollama", "Gemini"])
    if select == "Ollama":
        return select, None
    elif select == "Gemini":
        api_key = st.sidebar.text_input("Enter API key", type="password")
        return select, api_key
    else:
        raise ValueError("Invalid model name. Please choose 'Gemini' or 'Ollama'.")
# Decorator that reports how long the wrapped function takes to run
def timer(func) -> Callable:
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        result = func(*args, **kwargs)
        end = time.perf_counter()
        print(f"{func.__name__} took {end - start} seconds")
        return result
    return wrapper
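# Usage sketch (the decorator is not applied anywhere in this file): wrap any helper
# whose latency you want logged, e.g.
#
#     @timer
#     def embed_all_rows(...):   # hypothetical helper, for illustration only
#         ...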
def text2embeddings(document: Union[str, List[str]], api_key: Optional[str], model: str) -> np.ndarray:
    """Embed a single query string or a list of documents with the selected backend."""
    if model == "Gemini":
        embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001", google_api_key=api_key)
    elif model == "Ollama":
        embeddings = OllamaEmbeddings(model="phi3")
    else:
        raise ValueError("Invalid model name. Please choose 'Gemini' or 'Ollama'.")
    # A plain string is treated as a query; a list is treated as a batch of documents.
    if isinstance(document, str):
        return np.array(embeddings.embed_query(document))
    return np.array(embeddings.embed_documents(document))
def build_context_matrix(df: str, column: str, api_key: Optional[str], model: str) -> np.ndarray:
    """Embed every row of the CSV column once and cache the matrix on disk."""
    if glob.glob("Data/context_matrix.pkl"):
        with open("Data/context_matrix.pkl", "rb") as file:
            return pickle.load(file)
    dataframe = pd.read_csv(df)
    matrix = text2embeddings(document=dataframe[column].tolist(), api_key=api_key, model=model)
    with open("Data/context_matrix.pkl", "wb") as file:
        pickle.dump(matrix, file)
    print("Context matrix created and cached")
    return matrix
def semantic_chunk(query: str, df: str, column: str, api_key: Optional[str], model: str, chunk_size: int) -> str:
    """Return the chunk_size rows most similar to the query, joined into one context string."""
    dataframe = pd.read_csv(df)
    matrix = build_context_matrix(df=df, column=column, api_key=api_key, model=model)
    query_vector = text2embeddings(document=query, api_key=api_key, model=model)
    # Dot-product similarity between every row embedding and the query embedding
    score = matrix @ query_vector
    top_k = np.argsort(score)[::-1][:chunk_size]
    context = dataframe[column].iloc[top_k]
    return "\n".join(context)
def select_llm(model: str, api_key: Optional[str] = None, temperature: float = 0.5) -> Union[Ollama, GoogleGenerativeAI]:
    if model == "Ollama":
        return Ollama(model="phi3", temperature=temperature)
    elif model == "Gemini":
        return GoogleGenerativeAI(model="gemini-pro", temperature=temperature, google_api_key=api_key)
    else:
        raise ValueError("Invalid model name. Please choose 'Gemini' or 'Ollama'.")
def get_prompt(context: str, query: str) -> str:
    prompt = f"""You are a chatbot named Mira. You answer user's queries about e-commerce products.
Here are some rules you will follow:
1. Your response will be concise and informative.
2. Do not generate any incomplete sentences or duplicate sentences.
3. If you give a list of products, do not list more than 5 products at once.
-----------------------------------------------
Context: '''{context}'''
-----------------------------------------------
Question: '''{query}'''
-----------------------------------------------
"""
    return prompt
def retrieve_query(query: str, df: str, column: str, api_key: Optional[str], model: str, temperature: float = 0.5, chunk_size: int = 100) -> Any:
    # Only Gemini needs an API key; Ollama runs locally without one.
    if model == "Gemini" and not api_key:
        st.write("Please enter your key")
        sys.exit(0)
    context = semantic_chunk(query=query, df=df, column=column, model=model, api_key=api_key, chunk_size=chunk_size)
    llm = select_llm(model=model, api_key=api_key, temperature=temperature)
    prompt = get_prompt(context, query)
    # llm.stream returns an iterator of text chunks rather than a single string
    return llm.stream(prompt)
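
# A minimal wiring sketch (not part of the original file) showing how these helpers
# could be combined into a Streamlit app. The CSV path "Data/products.csv" and the
# column name "description" are assumptions -- this module does not define them.
if __name__ == "__main__":
    st.title("Mira - e-commerce assistant")
    model, api_key = model_selection()
    user_query = st.text_input("Ask about a product")
    if user_query:
        stream = retrieve_query(
            query=user_query,
            df="Data/products.csv",        # assumed dataset location
            column="description",          # assumed text column to embed
            api_key=api_key,
            model=model,
            chunk_size=20,
        )
        # st.write_stream consumes the chunk iterator and renders the reply incrementally
        st.write_stream(stream)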