import torch
import ollama
import os
from openai import OpenAI
import argparse

# ANSI escape codes for colors
PINK = '\033[95m'
CYAN = '\033[96m'
YELLOW = '\033[93m'
NEON_GREEN = '\033[92m'
RESET_COLOR = '\033[0m'

# Function to open a file and return its contents as a string
def open_file(filepath):
    with open(filepath, 'r', encoding='utf-8') as infile:
        return infile.read()

# Function to get relevant context from the vault based on user input
def get_relevant_context(rewritten_input, vault_embeddings, vault_content, top_k=3):
    if vault_embeddings.nelement() == 0:  # Check if the tensor has any elements
        return []
    # Encode the rewritten input
    input_embedding = ollama.embeddings(model='mxbai-embed-large', prompt=rewritten_input)["embedding"]
    # Compute cosine similarity between the input and vault embeddings
    cos_scores = torch.cosine_similarity(torch.tensor(input_embedding).unsqueeze(0), vault_embeddings)
    # Adjust top_k if it's greater than the number of available scores
    top_k = min(top_k, len(cos_scores))
    # Sort the scores and get the top-k indices
    top_indices = torch.topk(cos_scores, k=top_k)[1].tolist()
    # Get the corresponding context from the vault
    relevant_context = [vault_content[idx].strip() for idx in top_indices]
    return relevant_context
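
# Illustrative usage of get_relevant_context (a sketch, not invoked by this
# script; assumes a local Ollama server with the mxbai-embed-large model
# already pulled):
#
#   docs = ["Iterators produce items lazily.", "Lists hold all items in memory."]
#   embs = torch.tensor([
#       ollama.embeddings(model='mxbai-embed-large', prompt=d)["embedding"]
#       for d in docs
#   ])
#   get_relevant_context("what is an iterator?", embs, docs, top_k=1)
#   # -> ["Iterators produce items lazily."]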

# Function to interact with the Ollama model
def ollama_chat(user_input, system_message, vault_embeddings, vault_content, ollama_model, conversation_history):
    # Get relevant context from the vault
    relevant_context = get_relevant_context(user_input, vault_embeddings, vault_content, top_k=3)
    if relevant_context:
        # Convert list to a single string with newlines between items
        context_str = "\n".join(relevant_context)
        print("Context Pulled from Documents: \n\n" + CYAN + context_str + RESET_COLOR)
    else:
        print(CYAN + "No relevant context found." + RESET_COLOR)
    
    # Prepare the user's input by concatenating it with the relevant context
    user_input_with_context = user_input
    if relevant_context:
        user_input_with_context = context_str + "\n\n" + user_input
    
    # Append the user's input to the conversation history
    conversation_history.append({"role": "user", "content": user_input_with_context})
    
    # Create a message history including the system message and the conversation history
    messages = [
        {"role": "system", "content": system_message},
        *conversation_history
    ]
    
    # Send the completion request to the Ollama model
    response = client.chat.completions.create(
        model=ollama_model,
        messages=messages
    )
    
    # Append the model's response to the conversation history
    conversation_history.append({"role": "assistant", "content": response.choices[0].message.content})
    
    # Return the content of the response from the model
    return response.choices[0].message.content
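
# Illustrative usage of ollama_chat (a sketch; assumes the module-level
# `client` defined below has been created, and that vault_embeddings_tensor
# and vault_content were built as in process_text_files):
#
#   history = []
#   reply = ollama_chat("what is an iterator?",
#                       "You are a helpful assistant.",
#                       vault_embeddings_tensor, vault_content,
#                       "llama3", history)
#   print(reply)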

def process_text_files(user_input):
    text_parse_directory = os.path.join("local-rag", "text_parse")
    temp_file_path = os.path.join("local-rag", "temp.txt")

    # Check if text_parse directory exists
    if not os.path.exists(text_parse_directory):
        print(f"Directory '{text_parse_directory}' does not exist.")
        return False

    # Check if temp.txt exists
    if not os.path.exists(temp_file_path):
        print("temp.txt does not exist.")
        return False
    
    # Read the first line of temp.txt
    with open(temp_file_path, 'r', encoding='utf-8') as temp_file:
        first_line = temp_file.readline().strip()

    # Get all text files in the text_parse directory
    text_files = [f for f in os.listdir(text_parse_directory) if f.endswith('.txt')]
    
    # Check if the first line matches any of the text files
    if f"{first_line}" not in text_files:
        print(f"No matching file found for '{first_line}.txt' in text_parse directory.")
        return False

    # Proceed to check for the NOT FINISHED flag
    file_path = os.path.join(text_parse_directory, first_line)
    with open(file_path, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    # Drop trailing blank lines so the flag check is not sensitive to how the
    # file happens to end
    while lines and not lines[-1].strip():
        lines.pop()

    # Check whether the last line carries the NOT FINISHED flag
    if lines and lines[-1].strip() == "====================NOT FINISHED====================":
        print(f"'{first_line}' contains the 'NOT FINISHED' flag. Computing embeddings.")

        vault_content = []
        if os.path.exists(temp_file_path):
            with open(temp_file_path, "r", encoding='utf-8') as vault_file:
                vault_content = vault_file.readlines()

        # Generate embeddings for the vault content using Ollama
        vault_embeddings = []
        for content in vault_content:
            response = ollama.embeddings(model='mxbai-embed-large', prompt=content)
            vault_embeddings.append(response["embedding"])

        # Convert to tensor and print embeddings
        vault_embeddings_tensor = torch.tensor(vault_embeddings) 
        print("Embeddings for each line in the vault:")
        print(vault_embeddings_tensor)
        
        # Save the tensor result to a file or variable as needed
        with open(os.path.join(text_parse_directory, f"{first_line}_embedding.pt"), "wb") as tensor_file:
            torch.save(vault_embeddings_tensor, tensor_file)

        # Remove the NOT FINISHED line from the original file
        with open(file_path, 'w', encoding='utf-8') as f:
            f.writelines(lines[:-1])  # Write back all lines except the NOT FINISHED line

    else:
        print(f"'{first_line}' does not contain the 'NOT FINISHED' flag or is already complete. Loading tensor if it exists.")

        # Try to load the tensor from the corresponding file
        tensor_file_path = os.path.join(text_parse_directory, f"{first_line}_embedding.pt")
        if os.path.exists(tensor_file_path):
            vault_embeddings_tensor = torch.load(tensor_file_path)
            print("Loaded Vault Embedding Tensor:")
            print(vault_embeddings_tensor)

            vault_content = []
            
            if os.path.exists(temp_file_path):
                with open(temp_file_path, "r", encoding='utf-8') as vault_file:
                    vault_content = vault_file.readlines()

        else:
            # Without a saved tensor there is nothing to retrieve against, so
            # bail out instead of falling through with undefined variables
            print(f"No tensor file found for '{first_line}'.")
            return False

    # Conversation loop
    conversation_history = []
    system_message = "You are a helpful assistant that is an expert at extracting the most useful information from a given text"

    response = ollama_chat(user_input, system_message, vault_embeddings_tensor, vault_content, args.model, conversation_history)
    
    return response


    # # Read each file in the text_parse directory and check for the NOT FINISHED flag
    # for txt_file in text_files:
    #     file_path = os.path.join(text_parse_directory, txt_file)
    #     with open(file_path, 'r', encoding='utf-8') as f:
    #         lines = f.readlines()
    #         # Check if the last line contains the "NOT FINISHED" flag
    #         if lines and lines[-1].strip() == "==========NOT FINISHED==========":
    #             print(f"'{txt_file}' contains the 'NOT FINISHED' flag. Proceeding to next step.")
    #             # Append the content of this file to the vault
    #             with open(temp_file_path, 'a', encoding='utf-8') as vault_file:
    #                 vault_file.write('\n'.join(lines[:-1]) + '\n')  # Append content without the last flag line
    #         else:
    #             print(f"'{txt_file}' does not contain the 'NOT FINISHED' flag. Skipping.")

# Parse command-line arguments
parser = argparse.ArgumentParser(description="Ollama Chat")
parser.add_argument("--model", default="llama3", help="Ollama model to use (default: llama3)")
args = parser.parse_args()

# Configuration for the Ollama API client
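# (Ollama exposes an OpenAI-compatible endpoint at /v1; the api_key argument
# is required by the OpenAI client but is not validated by Ollama, so any
# placeholder value works.)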
client = OpenAI(
    base_url='http://localhost:11434/v1',
    api_key='llama3'
)

if __name__ == "__main__":
    print(process_text_files("tell me about iterators"))
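
# Example invocation (a sketch; the script filename is hypothetical, and the
# local-rag/ directory with text_parse/ and temp.txt must already exist):
#
#   python localrag.py --model llama3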