| import torch
|
| import ollama
|
| import os
|
| from openai import OpenAI
|
| import argparse
|
|
|
|
|
# ANSI escape sequences for coloured terminal output (ESC written as \x1b).
PINK = '\x1b[95m'         # bright magenta
CYAN = '\x1b[96m'         # bright cyan
YELLOW = '\x1b[93m'       # bright yellow
NEON_GREEN = '\x1b[92m'   # bright green
RESET_COLOR = '\x1b[0m'   # restore the terminal's default attributes
|
|
|
|
|
def open_file(filepath):
    """Read *filepath* as UTF-8 text and return its entire contents as one string."""
    with open(filepath, encoding='utf-8') as handle:
        contents = handle.read()
    return contents
|
|
|
|
|
def get_relevant_context(rewritten_input, vault_embeddings, vault_content, top_k=3,
                         embed_model='mxbai-embed-large'):
    """Return the vault entries most similar to *rewritten_input*.

    The query is embedded with Ollama and scored against every row of
    *vault_embeddings* by cosine similarity. The matching entries of
    *vault_content* are returned whitespace-stripped, highest score first.

    Args:
        rewritten_input: Query text to embed.
        vault_embeddings: Tensor whose i-th row is the embedding of
            ``vault_content[i]``. It is presumably 2-D (n_entries, dim), as
            built by ``torch.tensor(list_of_embeddings)``; TODO confirm for
            all callers.
        vault_content: Text entries aligned with the rows of
            *vault_embeddings*.
        top_k: Maximum number of entries to return.
        embed_model: Ollama embedding model used for the query. It must be
            the same model that produced *vault_embeddings*.

    Returns:
        A list of at most *top_k* stripped strings. The list is empty when
        the vault is ``None``/empty or *top_k* is not positive.
    """
    # Nothing to search, or nothing requested: return before paying for an
    # embedding request to the Ollama server.
    if vault_embeddings is None or vault_embeddings.nelement() == 0 or top_k <= 0:
        return []

    input_embedding = ollama.embeddings(model=embed_model, prompt=rewritten_input)["embedding"]

    # unsqueeze(0) makes the query a single row so that it is scored against
    # every vault row in one call.
    query = torch.tensor(input_embedding).unsqueeze(0)
    cos_scores = torch.cosine_similarity(query, vault_embeddings)

    # topk raises if k exceeds the number of scores, so clamp it.
    k = min(top_k, len(cos_scores))
    top_indices = torch.topk(cos_scores, k=k).indices.tolist()

    return [vault_content[idx].strip() for idx in top_indices]
|
|
|
|
|
def ollama_chat(user_input, system_message, vault_embeddings, vault_content, ollama_model, conversation_history):
    """Send *user_input* to the chat model, prefixed with retrieved vault text.

    Up to three vault entries are retrieved. If any are found, they are
    printed and prepended to the user message. The (possibly augmented) user
    message and then the assistant's reply are appended to
    *conversation_history* in place. The request is sent with
    *system_message* first, followed by the whole history.

    Returns:
        The message content of the first completion choice.
    """
    snippets = get_relevant_context(user_input, vault_embeddings, vault_content, top_k=3)

    if not snippets:
        print(CYAN + "No relevant context found." + RESET_COLOR)
        prompt = user_input
    else:
        joined = "\n".join(snippets)
        print("Context Pulled from Documents: \n\n" + CYAN + joined + RESET_COLOR)
        prompt = joined + "\n\n" + user_input

    conversation_history.append({"role": "user", "content": prompt})

    completion = client.chat.completions.create(
        model=ollama_model,
        messages=[{"role": "system", "content": system_message}] + conversation_history,
    )
    reply = completion.choices[0].message.content

    conversation_history.append({"role": "assistant", "content": reply})
    return reply
|
|
|
# Marker line that flags a text_parse document whose embeddings still need computing.
_NOT_FINISHED_FLAG = "====================NOT FINISHED===================="


def _read_lines(path):
    """Return the lines of the UTF-8 text file at *path*, line endings kept."""
    with open(path, 'r', encoding='utf-8') as handle:
        return handle.readlines()


def _embed_lines(lines):
    """Embed each line with Ollama's mxbai-embed-large model and stack the results into a tensor."""
    embeddings = [
        ollama.embeddings(model='mxbai-embed-large', prompt=line)["embedding"]
        for line in lines
    ]
    return torch.tensor(embeddings)


def process_text_files(user_input, model=None):
    """Answer *user_input* using the document named on the first line of temp.txt.

    ``local-rag/temp.txt`` must exist. Its first line must be the exact name
    of a ``.txt`` file in ``local-rag/text_parse``. Every line of temp.txt,
    including that first line, forms the vault searched for context.

    If the named document's second-to-last line is the NOT FINISHED flag,
    embeddings are computed for every vault line and saved to
    ``text_parse/<name>_embedding.pt``. The document's last line is then
    dropped. Otherwise the previously saved tensor is loaded.

    Args:
        user_input: Question to answer.
        model: Chat model name. Defaults to ``args.model`` from the command
            line.

    Returns:
        The assistant's reply. Returns ``False`` when the directory, temp.txt,
        the named document, or (for an unflagged document) its saved tensor
        is missing.
    """
    text_parse_directory = os.path.join("local-rag", "text_parse")
    temp_file_path = os.path.join("local-rag", "temp.txt")

    if not os.path.exists(text_parse_directory):
        print(f"Directory '{text_parse_directory}' does not exist.")
        return False

    if not os.path.exists(temp_file_path):
        print("temp.txt does not exist.")
        return False

    # temp.txt doubles as the vault; read it once and take the document name from it.
    vault_content = _read_lines(temp_file_path)
    first_line = vault_content[0].strip() if vault_content else ""

    text_files = [f for f in os.listdir(text_parse_directory) if f.endswith('.txt')]
    if first_line not in text_files:
        # Report the name that was actually looked up (it already carries any .txt suffix).
        print(f"No matching file found for '{first_line}' in text_parse directory.")
        return False

    file_path = os.path.join(text_parse_directory, first_line)
    tensor_file_path = os.path.join(text_parse_directory, f"{first_line}_embedding.pt")
    lines = _read_lines(file_path)

    # Check the length first so that documents with fewer than two lines do not raise IndexError.
    if len(lines) >= 2 and lines[-2].strip() == _NOT_FINISHED_FLAG:
        print(f"'{first_line}' contains the 'NOT FINISHED' flag. Computing embeddings.")

        vault_embeddings_tensor = _embed_lines(vault_content)
        print("Embeddings for each line in the vault:")
        print(vault_embeddings_tensor)

        with open(tensor_file_path, "wb") as tensor_file:
            torch.save(vault_embeddings_tensor, tensor_file)

        # The document is rewritten only after the tensor is on disk.
        # NOTE(review): this drops the *last* line, not the flag at lines[-2].
        # The flag stays in the file but is no longer second-to-last, so later
        # runs treat the document as complete. Confirm this is the intended format.
        with open(file_path, 'w', encoding='utf-8') as f:
            f.writelines(lines[:-1])
    else:
        print(f"'{first_line}' does not contain the 'NOT FINISHED' flag or is already complete. Loading tensor if it exists.")

        if not os.path.exists(tensor_file_path):
            # Without a tensor there is nothing to search; previously this fell
            # through and raised UnboundLocalError below.
            print(f"No tensor file found for '{first_line}'.")
            return False

        # torch.load unpickles; the file is one this function wrote itself.
        vault_embeddings_tensor = torch.load(tensor_file_path)
        print("Loaded Vault Embedding Tensor:")
        print(vault_embeddings_tensor)

    if model is None:
        model = args.model

    conversation_history = []
    system_message = "You are a helpful assistant that is an expert at extracting the most useful information from a given text"

    return ollama_chat(user_input, system_message, vault_embeddings_tensor, vault_content, model, conversation_history)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Command-line options: which chat model the completion requests use.
parser = argparse.ArgumentParser(description="Ollama Chat")
parser.add_argument(
    "--model",
    default="llama3",
    help="Ollama model to use (default: llama3)",
)
args = parser.parse_args()

# OpenAI-compatible client aimed at the local server on port 11434.
client = OpenAI(
    api_key='llama3',
    base_url='http://localhost:11434/v1',
)
|
|
|
if __name__ == "__main__":
    # Demo: ask a fixed question against the document temp.txt points at.
    answer = process_text_files("tell me about iterators")
    print(answer)
|
|
|
|
|
|
|