|
|
|
|
|
from langchain_google_genai import GoogleGenerativeAIEmbeddings
|
|
|
import google.generativeai as genai
|
|
|
import numpy as np
|
|
|
import os
|
|
|
from langchain.vectorstores import FAISS
|
|
|
# Gemini API key; intentionally read with [] so a missing variable fails
# fast at import time with KeyError rather than at first request.
api_key = os.environ["GOOGLE_GEMINI_API"]
|
|
|
|
|
|
def format_chat_history(chat_history):
    """Convert chat history from (role, text) pairs to the Gemini format.

    Args:
        chat_history: Iterable of (role, message) pairs where role is
            typically "user" or "bot".

    Returns:
        List of ``{"role": ..., "parts": ...}`` dicts. "bot" is mapped to
        Gemini's "model" role; any other role passes through unchanged
        (the original code's ``role == "user"`` branch was a no-op).
    """
    # Only "bot" needs renaming; everything else is forwarded verbatim.
    role_map = {"bot": "model"}
    return [
        {"role": role_map.get(message[0], message[0]), "parts": message[1]}
        for message in chat_history
    ]
|
|
|
|
|
|
# Configure the shared google.generativeai client once at import time.
genai.configure(api_key=api_key)
|
|
|
class GeminiModel:
    """Thin wrapper around Gemini 1.5 Pro for chat replies and title generation."""

    def __init__(self) -> None:
        # Pinned to the latest published 1.5 Pro snapshot.
        self.model = genai.GenerativeModel('gemini-1.5-pro-latest')

    def predict(self, inp, history, grounding_threshold=1.0):
        """Send *inp* into a chat seeded with *history*.

        Args:
            inp: The user's new message.
            history: Chat history in (role, text) pair form; converted via
                ``format_chat_history``.
            grounding_threshold: Dynamic threshold passed to Google Search
                retrieval grounding.

        Returns:
            Tuple of (cleaned response text, estimated cost in dollars).
        """
        # Google Search grounding config; the threshold controls how eagerly
        # retrieval kicks in.
        grounding_tools = {
            "google_search_retrieval": {
                "dynamic_retrieval_config": {
                    "mode": "unspecified",
                    "dynamic_threshold": grounding_threshold,
                }
            }
        }

        session = self.model.start_chat(history=format_chat_history(history))
        response = session.send_message(inp, tools=grounding_tools)

        # Flat estimate: $10 per million total tokens (prompt + completion).
        cost = (response.usage_metadata.total_token_count / 1_000_000) * 10

        # Strip code fences and newlines, then drop a leading "json"
        # language tag left over from a fenced ```json block.
        cleaned = response.text.replace('`', '').replace("\n", "")
        if "json" in cleaned[:4]:
            cleaned = cleaned[4:]

        return cleaned, cost

    def generate_title(self, initial_message):
        """Ask the model for a short, descriptive title for *initial_message*."""
        prompt = f"Generate a concise and descriptive title for the following conversation:\n\n{initial_message}\n\nTitle:"
        return self.model.generate_content(prompt).text.strip()
|
|
|
|
|
|
class GeminiEmbeddings:
    """Wrapper over Google's embedding-001 model producing FAISS-ready vectors."""

    def __init__(self) -> None:
        self.model = GoogleGenerativeAIEmbeddings(model="models/embedding-001", google_api_key=api_key)

    def predict(self, input):
        """Embed *input* text.

        Returns:
            A float32 numpy array of shape (1, dim), ready for FAISS.
        """
        vector = self.model.embed_query(input)
        # Row-vector layout with float32, as FAISS expects.
        return np.asarray(vector, dtype='float32').reshape(1, -1)
|
|
|
# Module-level embeddings instance shared by query_all_indexes() when
# loading FAISS indexes from disk.
embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001", google_api_key=api_key)
|
|
|
|
|
|
def query_all_indexes(query):
    """Run a similarity search for *query* across every saved FAISS index.

    Scans each subdirectory of ``faiss_indexes`` (one directory per index),
    loads it with the module-level Gemini embeddings, and collects the top-2
    matches from each index.

    Args:
        query: Natural-language search string.

    Returns:
        List of (page_content, index_dir_name) tuples; empty when the
        indexes directory is missing or contains no index subdirectories.
    """
    indexes_path = "faiss_indexes"

    # Fix: previously raised FileNotFoundError when the directory did not
    # exist yet (e.g. before the first index was built).
    if not os.path.isdir(indexes_path):
        return []

    results = []
    for index_dir in os.listdir(indexes_path):
        index_path = os.path.join(indexes_path, index_dir)
        if not os.path.isdir(index_path):
            continue  # skip stray files next to the index folders

        # NOTE(review): allow_dangerous_deserialization uses pickle under
        # the hood -- only acceptable because these indexes are produced
        # locally, never from untrusted sources.
        faiss_vectorstore = FAISS.load_local(
            index_path, embeddings, allow_dangerous_deserialization=True
        )

        search_results = faiss_vectorstore.similarity_search(query, k=2)
        results.extend((res.page_content, index_dir) for res in search_results)

    return results
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: query every local index for a sample German question
    # ("What do parties say about climate change?").
    print(query_all_indexes("Was sagen Parteien zum Klimawandel"))