File size: 8,227 Bytes
209c441 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 | import torch
import ollama
import os
from openai import OpenAI
import argparse
# ANSI escape codes for colors
PINK = '\033[95m'
CYAN = '\033[96m'
YELLOW = '\033[93m'
NEON_GREEN = '\033[92m'
RESET_COLOR = '\033[0m'
# Function to open a file and return its contents as a string
def open_file(filepath):
    """Read the text file at ``filepath`` as UTF-8 and return its contents.

    Args:
        filepath: Path of the file to read.

    Returns:
        The entire file content as a single string.
    """
    with open(filepath, mode='r', encoding='utf-8') as handle:
        contents = handle.read()
    return contents
# Function to get relevant context from the vault based on user input
def get_relevant_context(rewritten_input, vault_embeddings, vault_content, top_k=3):
    """Return the vault entries whose embeddings are most similar to the query.

    Args:
        rewritten_input: Query text to embed with the 'mxbai-embed-large' model.
        vault_embeddings: Tensor of vault embeddings; each row is compared
            against the query embedding.
        vault_content: Sequence of strings indexed in step with the rows of
            ``vault_embeddings``.
        top_k: Maximum number of entries to return (default 3).

    Returns:
        A list of whitespace-stripped strings from ``vault_content``, ordered
        from most to least similar. Empty when ``vault_embeddings`` has no
        elements.
    """
    # Nothing to search when the embeddings tensor is empty.
    if vault_embeddings.nelement() == 0:
        return []

    # Embed the query and give it a leading batch dimension for the comparison.
    query_vector = ollama.embeddings(model='mxbai-embed-large', prompt=rewritten_input)["embedding"]
    query_tensor = torch.tensor(query_vector).unsqueeze(0)

    # Cosine similarity of the query against every vault embedding.
    similarities = torch.cosine_similarity(query_tensor, vault_embeddings)

    # Never ask topk for more results than there are scores.
    limit = min(top_k, len(similarities))
    _, best_positions = torch.topk(similarities, k=limit)

    return [vault_content[position].strip() for position in best_positions.tolist()]
# Function to interact with the Ollama model
def ollama_chat(user_input, system_message, vault_embeddings, vault_content, ollama_model, conversation_history):
    """Answer ``user_input`` with the chat model, prefixing retrieved context.

    Up to three relevant vault entries are looked up via
    ``get_relevant_context`` and, when any are found, prepended to the user's
    message. The request is sent through the module-level ``client``.

    Args:
        user_input: The user's question.
        system_message: Content of the system message placed first in the request.
        vault_embeddings: Embeddings tensor forwarded to ``get_relevant_context``.
        vault_content: Vault text lines forwarded to ``get_relevant_context``.
        ollama_model: Model name passed to the chat completion request.
        conversation_history: List of message dicts; mutated in place with the
            new user message (including any context) and the assistant reply.

    Returns:
        The content of the first choice in the model's response.
    """
    snippets = get_relevant_context(user_input, vault_embeddings, vault_content, top_k=3)

    if not snippets:
        print(CYAN + "No relevant context found." + RESET_COLOR)
        prompt = user_input
    else:
        # Join the retrieved snippets into one newline-separated string.
        context_str = "\n".join(snippets)
        print("Context Pulled from Documents: \n\n" + CYAN + context_str + RESET_COLOR)
        # The context goes ahead of the question, separated by a blank line.
        prompt = context_str + "\n\n" + user_input

    # Record the user's turn (with context) in the running history.
    conversation_history.append({"role": "user", "content": prompt})

    # The system message leads, followed by the whole conversation so far.
    messages = [{"role": "system", "content": system_message}]
    messages.extend(conversation_history)

    completion = client.chat.completions.create(model=ollama_model, messages=messages)
    reply = completion.choices[0].message.content

    # Record the assistant's turn so later calls see it.
    conversation_history.append({"role": "assistant", "content": reply})
    return reply
def _read_vault_lines(temp_file_path):
    """Return every line of the vault file, or an empty list if it is missing."""
    if not os.path.exists(temp_file_path):
        return []
    with open(temp_file_path, "r", encoding='utf-8') as vault_file:
        return vault_file.readlines()


def process_text_files(user_input):
    """Build or load embeddings for the current document and answer a question.

    The first line of ``local-rag/temp.txt`` names a file inside
    ``local-rag/text_parse``. If that file's second-to-last line is the
    'NOT FINISHED' marker, an embedding is computed for every line of
    ``temp.txt``, saved next to the document as ``<name>_embedding.pt``, and
    the document is rewritten without its final line. Otherwise a previously
    saved tensor is loaded. The question is then sent to ``ollama_chat`` using
    the module-level ``args.model``.

    Args:
        user_input: The question to ask about the vault content.

    Returns:
        The model's reply, or ``False`` when a required directory, file or
        saved tensor is missing.
    """
    text_parse_directory = os.path.join("local-rag", "text_parse")
    temp_file_path = os.path.join("local-rag", "temp.txt")
    not_finished_flag = "====================NOT FINISHED===================="

    # Check if text_parse directory exists
    if not os.path.exists(text_parse_directory):
        print(f"Directory '{text_parse_directory}' does not exist.")
        return False

    # Check if temp.txt exists
    if not os.path.exists(temp_file_path):
        print("temp.txt does not exist.")
        return False

    # The first line of temp.txt names the document being processed.
    with open(temp_file_path, 'r', encoding='utf-8') as temp_file:
        first_line = temp_file.readline().strip()

    # Get all text files in the text_parse directory
    text_files = [f for f in os.listdir(text_parse_directory) if f.endswith('.txt')]

    # Check if the first line matches any of the text files
    if first_line not in text_files:
        print(f"No matching file found for '{first_line}.txt' in text_parse directory.")
        return False

    file_path = os.path.join(text_parse_directory, first_line)
    tensor_file_path = os.path.join(text_parse_directory, f"{first_line}_embedding.pt")

    with open(file_path, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    # The marker is looked for on the second-to-last line; the length guard
    # keeps documents with fewer than two lines from raising IndexError.
    if len(lines) >= 2 and lines[-2].strip() == not_finished_flag:
        print(f"'{first_line}' contains the 'NOT FINISHED' flag. Computing embeddings.")
        vault_content = _read_vault_lines(temp_file_path)

        # Generate embeddings for the vault content using Ollama
        vault_embeddings = [
            ollama.embeddings(model='mxbai-embed-large', prompt=content)["embedding"]
            for content in vault_content
        ]

        # Convert to tensor and print embeddings
        vault_embeddings_tensor = torch.tensor(vault_embeddings)
        print("Embeddings for each line in the vault:")
        print(vault_embeddings_tensor)

        # Persist the tensor so later runs can load it instead of recomputing.
        with open(tensor_file_path, "wb") as tensor_file:
            torch.save(vault_embeddings_tensor, tensor_file)

        # Rewrite the document without its final line. The marker line itself
        # (second-to-last) is kept, but it then becomes the last line, so the
        # check above no longer matches on the next run.
        # NOTE(review): confirm that keeping the marker text is intended.
        with open(file_path, 'w', encoding='utf-8') as f:
            f.writelines(lines[:-1])
    else:
        print(f"'{first_line}' does not contain the 'NOT FINISHED' flag or is already complete. Loading tensor if it exists.")

        # Without a saved tensor there is nothing to query; bail out like the
        # other failure paths instead of continuing with unassigned variables.
        if not os.path.exists(tensor_file_path):
            print(f"No tensor file found for '{first_line}'.")
            return False

        # NOTE: torch.load unpickles the file; only load tensors written by
        # this script.
        vault_embeddings_tensor = torch.load(tensor_file_path)
        print("Loaded Vault Embedding Tensor:")
        print(vault_embeddings_tensor)
        vault_content = _read_vault_lines(temp_file_path)

    # Single-turn conversation: a fresh history is used for every call.
    conversation_history = []
    system_message = "You are a helpful assistant that is an expert at extracting the most useful information from a given text"
    response = ollama_chat(user_input, system_message, vault_embeddings_tensor, vault_content, args.model, conversation_history)
    return response
# # Read each file in the text_parse directory and check for the NOT FINISHED flag
# for txt_file in text_files:
# file_path = os.path.join(text_parse_directory, txt_file)
# with open(file_path, 'r', encoding='utf-8') as f:
# lines = f.readlines()
# # Check if the last line contains the "NOT FINISHED" flag
# if lines and lines[-1].strip() == "==========NOT FINISHED==========":
# print(f"'{txt_file}' contains the 'NOT FINISHED' flag. Proceeding to next step.")
# # Append the content of this file to the vault
# with open(temp_file_path, 'a', encoding='utf-8') as vault_file:
# vault_file.write('\n'.join(lines[:-1]) + '\n') # Append content without the last flag line
# else:
# print(f"'{txt_file}' does not contain the 'NOT FINISHED' flag. Skipping.")
# Parse command-line arguments (runs at import time; `args` is read by
# process_text_files).
parser = argparse.ArgumentParser(description="Ollama Chat")
parser.add_argument(
    "--model",
    default="llama3",
    help="Ollama model to use (default: llama3)",
)
args = parser.parse_args()

# OpenAI-compatible client pointed at the local Ollama endpoint; `client` is
# read by ollama_chat.
# NOTE(review): the api_key value is presumably a placeholder for the local
# server — confirm.
client = OpenAI(
    api_key='llama3',
    base_url='http://localhost:11434/v1',
)

if __name__ == "__main__":
    answer = process_text_files("tell me about iterators")
    print(answer)
|