# Uploaded by Sadique5 ("Upload 23 files", commit d2224c7, verified)
from langchain_google_genai import GoogleGenerativeAIEmbeddings
import google.generativeai as genai
import numpy as np
import os
from langchain.vectorstores import FAISS
# Read the Gemini API key eagerly; raises KeyError at import time if unset.
api_key = os.environ["GOOGLE_GEMINI_API"]
def format_chat_history(chat_history):
    """Convert (role, text) chat-history pairs into the Gemini message format.

    Args:
        chat_history: Iterable of sequences where item [0] is the role
            ("user" or "bot") and item [1] is the message text.

    Returns:
        A list of ``{"role": ..., "parts": ...}`` dicts suitable for
        ``GenerativeModel.start_chat(history=...)``. The "bot" role is
        renamed to Gemini's "model"; all other roles (including "user")
        pass through unchanged.
    """
    # Only "bot" needs translating; the original also had a no-op
    # `if role == "user": role = "user"` branch, removed here.
    role_map = {"bot": "model"}
    return [
        {"role": role_map.get(message[0], message[0]), "parts": message[1]}
        for message in chat_history
    ]
# Configure the google.generativeai client once, at import time.
genai.configure(api_key=api_key)
class GeminiModel:
    """Thin wrapper around Gemini 1.5 Pro with Google-Search grounding."""

    def __init__(self) -> None:
        self.model = genai.GenerativeModel('gemini-1.5-pro-latest')

    def predict(self, inp, history, grounding_threshold=1.0):
        """Send `inp` to a chat session seeded with `history`.

        Args:
            inp: User message to send.
            history: Prior (role, text) pairs; converted via format_chat_history.
            grounding_threshold: Dynamic-retrieval threshold forwarded to the
                google_search_retrieval tool.

        Returns:
            Tuple of (response text with backticks/newlines removed and any
            leading "json" tag stripped, estimated dollar cost).
        """
        grounding_tools = {
            "google_search_retrieval": {
                "dynamic_retrieval_config": {
                    "mode": "unspecified",
                    "dynamic_threshold": grounding_threshold,
                }
            }
        }
        session = self.model.start_chat(history=format_chat_history(history))
        response = session.send_message(inp, tools=grounding_tools)
        # Flat rate of $10 per million tokens (input + output combined)
        # — presumably a pricing assumption; verify against current rates.
        cost = response.usage_metadata.total_token_count / 1_000_000 * 10
        cleaned = response.text.replace('`', '').replace("\n", "")
        # Drop a leading "json" tag left over from a fenced ```json block.
        if "json" in cleaned[:4]:
            cleaned = cleaned[4:]
        return cleaned, cost

    def generate_title(self, initial_message):
        """Ask the model for a concise title summarizing `initial_message`."""
        prompt = f"Generate a concise and descriptive title for the following conversation:\n\n{initial_message}\n\nTitle:"
        return self.model.generate_content(prompt).text.strip()
class GeminiEmbeddings:
    """Query-embedding helper backed by Google's embedding-001 model."""

    def __init__(self) -> None:
        self.model = GoogleGenerativeAIEmbeddings(model="models/embedding-001", google_api_key=api_key)

    def predict(self, input):
        """Embed `input` text and return a (1, dim) float32 numpy array."""
        vector = np.asarray(self.model.embed_query(input), dtype='float32')
        return vector.reshape(1, -1)
embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001", google_api_key=api_key) # Shared embedding client — requires the GOOGLE_GEMINI_API env var (Google Gemini, not OpenAI)
def query_all_indexes(query, k=2, indexes_path="faiss_indexes"):
    """Run a similarity search across every FAISS index under `indexes_path`.

    Args:
        query: Natural-language search string.
        k: Number of results to fetch from each index (default 2,
           matching the previous hard-coded value).
        indexes_path: Directory containing one sub-directory per index
           (default "faiss_indexes", matching the previous hard-coded path).

    Returns:
        List of (page_content, index_dir_name) tuples, in directory-listing
        order. Empty list if `indexes_path` does not exist (previously this
        raised FileNotFoundError).
    """
    results = []
    if not os.path.isdir(indexes_path):
        return results
    for index_dir in os.listdir(indexes_path):
        index_path = os.path.join(indexes_path, index_dir)
        if not os.path.isdir(index_path):
            continue
        # NOTE(review): allow_dangerous_deserialization uses pickle under the
        # hood — only load indexes this application created itself.
        faiss_vectorstore = FAISS.load_local(
            index_path, embeddings, allow_dangerous_deserialization=True
        )
        search_results = faiss_vectorstore.similarity_search(query, k=k)
        results.extend((res.page_content, index_dir) for res in search_results)
    return results
if __name__ == "__main__":
    # Smoke test: query every local FAISS index with a sample German question
    # ("What do parties say about climate change?").
    print(query_all_indexes("Was sagen Parteien zum Klimawandel"))