File size: 5,471 Bytes
2fd14d8
 
 
 
56caa41
2fd14d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import os
import glob
from dotenv import load_dotenv
from langchain_community.document_loaders import PyPDFLoader, TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_google_genai import ChatGoogleGenerativeAI
import logging

# Configure logger: INFO-level messages go both to a persistent file
# (neuronap.log) and to the console.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('neuronap.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# Load .env and expose the Gemini key under the name the
# langchain_google_genai client expects (GOOGLE_API_KEY).
load_dotenv()
os.environ["GOOGLE_API_KEY"] = os.getenv("GEMINI_API_KEY", "")
# Fail fast at import time: every downstream LLM call requires this key.
if not os.environ["GOOGLE_API_KEY"]:
    logger.error("GEMINI_API_KEY not found in .env file")
    raise ValueError("GEMINI_API_KEY not found in .env file")
logger.info("Loaded environment variables")

# Initialize embeddings and text splitter.
# chunk_size/chunk_overlap are in characters; the 200-char overlap keeps
# context continuous across adjacent chunks for retrieval.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
try:
    # Local sentence-transformers model — downloads on first use.
    # NOTE(review): model name is unqualified; presumably resolves to
    # sentence-transformers/all-MiniLM-L6-v2 — confirm.
    embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    logger.info("Initialized sentence-transformers embeddings")
except Exception as e:
    logger.error(f"Failed to initialize embeddings: {e}")
    raise

def create_knowledge_embeddings():
    """Build a FAISS index from every PDF under knowledge/ and save it.

    Loads each PDF, splits it into overlapping text chunks, embeds the
    chunks, and persists the resulting index to the 'faiss_index' folder.

    Raises:
        FileNotFoundError: if knowledge/ contains no .pdf files.
        ValueError: if no PDF could be loaded successfully.
    """
    logger.info("Starting creation of knowledge embeddings")
    pdf_paths = glob.glob("knowledge/*.pdf")
    if not pdf_paths:
        logger.error("No .pdf files found in 'knowledge' folder")
        raise FileNotFoundError("No .pdf files found in 'knowledge' folder.")

    # Collect split chunks from every readable PDF; a bad file is logged
    # and skipped rather than aborting the whole build.
    chunks = []
    for pdf_path in pdf_paths:
        try:
            logger.info(f"Loading PDF: {pdf_path}")
            pages = PyPDFLoader(pdf_path).load()
            page_chunks = text_splitter.split_documents(pages)
            chunks.extend(page_chunks)
            logger.info(f"Successfully loaded and split {pdf_path} with {len(page_chunks)} chunks")
        except Exception as e:
            logger.error(f"Error loading {pdf_path}: {e}")

    # If every file failed to load there is nothing to index.
    if not chunks:
        logger.error("No valid documents loaded from knowledge folder")
        raise ValueError("No valid documents loaded from knowledge folder.")

    logger.info(f"Creating FAISS index with {len(chunks)} document chunks")
    try:
        index = FAISS.from_documents(chunks, embeddings)
        index.save_local("faiss_index")
        logger.info("FAISS embeddings created and saved to 'faiss_index'")
    except Exception as e:
        logger.error(f"Error creating FAISS index: {e}")
        raise

def load_vectorstore():
    """Load and return the persisted FAISS vectorstore from 'faiss_index'.

    Raises:
        FileNotFoundError: if the index folder does not exist yet.
    """
    logger.info("Loading FAISS vectorstore")
    # Guard: the index must have been built by create_knowledge_embeddings.
    if not os.path.exists("faiss_index"):
        logger.error("FAISS index not found")
        raise FileNotFoundError("FAISS index not found. Run create_knowledge_embeddings first.")
    # The index was written by this same application, so deserializing it
    # is trusted here despite the pickle-based format.
    store = FAISS.load_local("faiss_index", embeddings, allow_dangerous_deserialization=True)
    logger.info("FAISS vectorstore loaded successfully")
    return store

def update_vectorstore_with_user_results(user_results):
    """Append user-specific results to the FAISS index and re-save it.

    The results string is written to user_results.txt so TextLoader can
    turn it into documents, which are then split, embedded, and added to
    the existing index.

    Args:
        user_results: plain-text summary of the user's analysis results.

    Returns:
        The updated FAISS vectorstore.
    """
    logger.info("Updating vectorstore with user results")
    store = load_vectorstore()
    # Round-trip through a file because TextLoader reads from disk.
    with open("user_results.txt", "w", encoding="utf-8") as out:
        out.write(user_results)
    docs = TextLoader("user_results.txt").load()
    doc_chunks = text_splitter.split_documents(docs)
    store.add_documents(doc_chunks)
    store.save_local("faiss_index")
    logger.info("User results added to vectorstore and saved")
    return store

def setup_rag():
    """Construct and return the RetrievalQA chain (Gemini + FAISS retriever).

    Raises:
        Exception: propagated if the LLM cannot be initialized (bad key /
        quota) or the FAISS index is missing.
    """
    logger.info("Setting up RAG chain")
    try:
        llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0.5)
        logger.info("Initialized Gemini 2.0 Flash LLM")
    except Exception as e:
        logger.error(f"Error initializing LLM: {e}. Check GEMINI_API_KEY or quota")
        raise

    store = load_vectorstore()

    # Prompt instructing the model to answer from retrieved context in
    # plain text (no Markdown), using bullet points and exact user values.
    template = """\
You are an expert assistant in EEG-based sleep analysis. Use the provided context from sleep stage definitions, EEG characteristics, and user-specific results to answer the query.
Context: {context}
User Query: {question}
Provide a clear, educational response in plain text, avoiding Markdown symbols (e.g., **, *). Use bullet points (•) for lists and exact values from user data if referenced. Include concise explanations for each metric or finding.
"""
    prompt = PromptTemplate(template=template, input_variables=["context", "question"])

    # "stuff" chain type: all retrieved chunks are inserted into a single prompt.
    chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=store.as_retriever(),
        return_source_documents=True,
        chain_type_kwargs={"prompt": prompt},
    )
    logger.info("RAG chain setup completed")
    return chain

def chat_with_llm(query, qa_chain):
    """Run a user query through the RAG chain and return the answer text.

    Args:
        query: the user's question as a string.
        qa_chain: the RetrievalQA chain returned by setup_rag().

    Returns:
        The chain's "result" string, or an error message string if the
        call fails (the error is logged, not raised, so callers can
        display it directly).
    """
    logger.info(f"Processing query: {query}")
    try:
        # Use .invoke(): calling the chain directly (qa_chain({...})) is
        # the deprecated Chain.__call__ API in modern LangChain.
        result = qa_chain.invoke({"query": query})["result"]
        logger.info("Query processed successfully")
        return result
    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception(f"Error processing query: {e}")
        return f"Error processing query: {e}. Check GEMINI_API_KEY or quota"

# Script entry point: build (or rebuild) the knowledge-base embeddings.
if __name__ == "__main__":
    logger.info("Running rag.py as main script")
    # Ensure the input folder exists; create_knowledge_embeddings still
    # raises FileNotFoundError if it contains no PDFs.
    os.makedirs("knowledge", exist_ok=True)
    create_knowledge_embeddings()