# File size: 9,736 Bytes
# 5bc888a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
import os
import gradio as gr
from dotenv import load_dotenv
import traceback # For detailed error logging
import torch # Required for Hugging Face transformers

# --- LangChain and Hugging Face Transformers Imports ---
from langchain_neo4j import Neo4jGraph
# from langchain_openai import ChatOpenAI # We will replace this
from langchain_community.llms import HuggingFacePipeline # For using HuggingFace models
from langchain_community.chains.graph_qa.cypher import GraphCypherQAChain
from langchain_core.prompts import PromptTemplate

from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# --- Environment Variable Loading ---
# Pull settings from a local .env file (if any) into the process environment,
# then report which ones are present. Only "Set"/"Not Set" is printed so that
# secret values never reach the logs.
load_dotenv()
print("Environment variables loaded:")
print(
    *(
        f"{var_name}: {'Set' if os.getenv(var_name) else 'Not Set'}"
        for var_name in ("NEO4J_URI", "NEO4J_USER", "NEO4J_PASSWORD")
    ),
    sep="\n",
)
# OPENAI_API_KEY is not reported: the app uses a Hugging Face model instead.
hf_token_status = (
    "Set"
    if os.getenv("HUGGINGFACE_HUB_TOKEN")
    else "Not Set (may be needed for certain models)"
)
print(f"HUGGINGFACE_HUB_TOKEN: {hf_token_status}")


# --- Global LangChain chain variable ---
# Startup state read by process_query():
#   chain                    - the QA chain, or None until it has been built
#   graph_connection_error   - error text if the Neo4j connection failed
#   llm_initialization_error - error text if the LLM setup failed
chain = graph_connection_error = llm_initialization_error = None

# --- Neo4j, Hugging Face LLM, and LangChain Setup ---
# Startup runs in three stages: Neo4j connection -> Hugging Face LLM -> QA chain.
# Each stage has its own error handling so that a failure is reported under the
# correct label, and later stages are skipped when an earlier one fails.
# process_query() reads `graph_connection_error`, `llm_initialization_error`
# and `chain` to decide what to tell the user.
graph = None
llm = None

# Stage 1: Neo4j connection.
try:
    print("Attempting to connect to Neo4j...")
    graph = Neo4jGraph(
        url=os.getenv("NEO4J_URI"),
        username=os.getenv("NEO4J_USER"),
        password=os.getenv("NEO4J_PASSWORD"),
    )
    print("Successfully connected to Neo4j.")
except Exception as e_graph:
    graph_connection_error_message = f"Error setting up Neo4j connection: {str(e_graph)}\n"
    graph_connection_error_message += "Full Traceback:\n" + traceback.format_exc()
    print(graph_connection_error_message)
    graph_connection_error = graph_connection_error_message
    chain = None

# Stage 2: Hugging Face LLM (skipped when there is no graph to query).
if graph_connection_error is None:
    print("Initializing Hugging Face LLM...")
    # IMPORTANT: Replace "gpt2" with your desired Hugging Face model.
    # For larger models like Llama-2, ensure you have enough resources (VRAM/RAM)
    # and handle authentication if it's a gated model (e.g., `huggingface-cli login`).
    model_id = "gpt2"  # e.g. "NousResearch/Llama-2-7b-chat-hf"

    try:
        # trust_remote_code might be needed for some models.
        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

        # float16 halves memory use on a GPU, but half precision is only
        # requested when CUDA is actually available; CPU-only hosts load the
        # model in float32 instead.
        model_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        hf_model = AutoModelForCausalLM.from_pretrained(
            model_id,
            trust_remote_code=True,
            device_map='auto',  # Distributes model layers across available devices (CPU/GPU)
            torch_dtype=model_dtype,
        )
        hf_model.eval()  # Set the model to evaluation mode

        # Generation settings are given to the pipeline itself. They used to be
        # passed as `model_kwargs` to HuggingFacePipeline, which does not appear
        # to forward them to an already-built pipeline (see the LangChain
        # HuggingFacePipeline reference). `max_length` is dropped because
        # `max_new_tokens` already bounds the generated text.
        pipe = pipeline(
            "text-generation",
            model=hf_model,
            tokenizer=tokenizer,
            max_new_tokens=512,  # Max tokens for the generated Cypher query + answer synthesis
            do_sample=True,
            top_k=30,
            temperature=0.1,
            num_return_sequences=1,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,  # Often good to set for open-ended generation
        )

        # Wrap the pipeline in LangChain's HuggingFacePipeline
        llm = HuggingFacePipeline(pipeline=pipe)
        print(f"Hugging Face LLM ({model_id}) initialized successfully.")

    except Exception as e_llm:
        llm_initialization_error_message = f"Error initializing Hugging Face LLM ({model_id}): {str(e_llm)}\n"
        llm_initialization_error_message += "Full Traceback:\n" + traceback.format_exc()
        print(llm_initialization_error_message)
        llm_initialization_error = llm_initialization_error_message
        llm = None

# Stage 3: GraphCypherQAChain (only when both the graph and the LLM are ready).
if llm:
    print("Initializing GraphCypherQAChain...")
    CYPHER_GENERATION_TEMPLATE = """You are an expert Neo4j Cypher translator.
        Task: Convert the natural language question into a Cypher query that can retrieve relevant information from a Neo4j graph.
        Instructions:
        1. Use only the provided schema details. Do not use any other node labels or relationship types.
        2. Understand the question and identify the key entities and relationships.
        3. Construct a Cypher query that accurately reflects the question's intent.
        4. Output ONLY the Cypher query. No explanations, no introductory text, no markdown. Just the query.

        Schema:
        {schema}

        Question: {question}
        Cypher Query:"""

    try:
        cypher_prompt = PromptTemplate.from_template(CYPHER_GENERATION_TEMPLATE)

        # The answer-synthesis step uses the chain's default QA prompt; pass a
        # `qa_prompt=` PromptTemplate to from_llm() to customize it.
        #
        # NOTE(review): recent langchain_community releases are believed to
        # refuse to build this chain unless `allow_dangerous_requests=True` is
        # passed. Confirm against the installed version, and scope the Neo4j
        # credentials (ideally read-only) before opting in.
        chain = GraphCypherQAChain.from_llm(
            llm=llm,
            graph=graph,
            verbose=True,
            return_intermediate_steps=True,
            cypher_prompt=cypher_prompt,
        )
        print("LangChain integration with GraphCypherQAChain initialized successfully.")
    except Exception as e_chain:
        # Not a Neo4j or LLM failure: leave both error variables untouched so
        # process_query() falls through to its "chain is not available" reply.
        chain_initialization_error_message = f"Error initializing GraphCypherQAChain: {str(e_chain)}\n"
        chain_initialization_error_message += "Full Traceback:\n" + traceback.format_exc()
        print(chain_initialization_error_message)
        chain = None

# --- Gradio Interface Function ---
def _extract_generated_cypher(intermediate_steps):
    """Return the Cypher query recorded in the chain's intermediate steps.

    Every step is inspected (not only the first) and the first dict that
    carries a "query" key wins. A fallback notice is returned when no step
    holds a query.
    """
    if isinstance(intermediate_steps, list):
        for step in intermediate_steps:
            if isinstance(step, dict) and "query" in step:
                return step["query"]
    return "Could not extract Cypher query from intermediate steps."


def process_query(message: str, history: list):
    """Answer one chat message by running it through the graph QA chain.

    Args:
        message: The question typed by the user.
        history: Previous conversation turns passed in by the chat UI; unused.

    Returns:
        A Markdown string containing the generated Cypher and the answer, or
        a human-readable error message when startup failed, the input is
        blank, or the chain raised. Exceptions from the chain are caught and
        reported as text rather than propagated.
    """
    # Startup failures are reported first, most fundamental cause first.
    if graph_connection_error:
        return f"Application Initialization Error (Neo4j): {graph_connection_error}"
    if llm_initialization_error:
        return f"Application Initialization Error (LLM): {llm_initialization_error}"
    if not chain:
        return "Error: LangChain QA Chain is not available. Please check server logs for initialization issues."

    # Don't spend an LLM call (and a database query) on blank input.
    if not message or not message.strip():
        return "Please enter a question about the graph database."

    print(f"Processing message: {message}")
    try:
        result = chain.invoke({"query": message})
        print(f"Chain result: {result}")

        answer = result.get("result", "No answer found or an error occurred in processing.")
        generated_cypher = _extract_generated_cypher(result.get("intermediate_steps", []))

        # Emoji are written as escapes (U+1F4DD memo, U+1F4AC speech balloon)
        # so they survive any re-encoding of this source file.
        return (
            f"\U0001F4DD Generated Cypher:\n```cypher\n{generated_cypher}\n```\n\n"
            f"\U0001F4AC Answer:\n{answer}"
        )

    except Exception as e:
        error_message = f"Error processing query: {str(e)}"
        print(error_message)
        print(traceback.format_exc())
        # Specific error check for Hugging Face model issues (e.g. out of memory)
        if "CUDA out of memory" in str(e):
            return "LLM Error: CUDA out of memory. The model may be too large for your GPU. Try a smaller model or reduce batch size if applicable."
        return error_message

# --- Gradio Interface Definition ---
print("Setting up Gradio interface...")

# Starter questions offered as examples in the chat UI.
_EXAMPLE_QUESTIONS = [
    "How many nodes are in the database?",
    "What types of nodes exist?",
    "List all relationship types.",
]

# Chat front-end: each user message is handed to process_query() and the
# returned string is shown as the assistant's reply.
demo = gr.ChatInterface(
    fn=process_query,
    title="Neo4j Graph Database Assistant (with Hugging Face LLM)",
    description="Ask questions about your Neo4j database. Model responses depend on the chosen Hugging Face LLM.",
    examples=_EXAMPLE_QUESTIONS,
    cache_examples=False,
    chatbot=gr.Chatbot(type="messages", height=600),
    theme=gr.themes.Soft(),
)

# --- Main Execution ---
# Launch the web UI only when this file is executed as a script; importing the
# module still runs the setup above and builds `demo`, but does not launch it.
if __name__ == "__main__":
    print("Launching Gradio interface...")
    # To make accessible on the network (e.g., in Docker):
    # demo.launch(server_name="0.0.0.0")
    demo.launch()