import os
import gradio as gr
from dotenv import load_dotenv
import traceback # For detailed error logging
import torch # Required for Hugging Face transformers
# --- LangChain and Hugging Face Transformers Imports ---
from langchain_neo4j import Neo4jGraph
# from langchain_openai import ChatOpenAI # We will replace this
from langchain_community.llms import HuggingFacePipeline # For using HuggingFace models (newer installs: langchain_huggingface)
from langchain_community.chains.graph_qa.cypher import GraphCypherQAChain
from langchain_core.prompts import PromptTemplate
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
# --- Environment Variable Loading ---
load_dotenv()
print("Environment variables loaded:")
print(f"NEO4J_URI: {'Set' if os.getenv('NEO4J_URI') else 'Not Set'}")
print(f"NEO4J_USER: {'Set' if os.getenv('NEO4J_USER') else 'Not Set'}")
print(f"NEO4J_PASSWORD: {'Set' if os.getenv('NEO4J_PASSWORD') else 'Not Set'}")
# OPENAI_API_KEY is no longer the primary concern if using local/HF models
# print(f"OPENAI_API_KEY: {'Set' if os.getenv('OPENAI_API_KEY') else 'Not Set'}")
print(f"HUGGINGFACE_HUB_TOKEN: {'Set' if os.getenv('HUGGINGFACE_HUB_TOKEN') else 'Not Set (may be needed for certain models)'}")
# --- Global LangChain chain variable ---
chain = None
graph_connection_error = None # To store graph connection error
llm_initialization_error = None # To store LLM setup error
# --- Neo4j, Hugging Face LLM, and LangChain Setup ---
try:
print("Attempting to connect to Neo4j...")
graph = Neo4jGraph(
url=os.getenv("NEO4J_URI"),
username=os.getenv("NEO4J_USER"),
password=os.getenv("NEO4J_PASSWORD"),
)
print("Successfully connected to Neo4j.")

    # --- Hugging Face LLM Setup ---
    print("Initializing Hugging Face LLM...")
    # IMPORTANT: Replace "gpt2" with your desired Hugging Face model.
    # For larger models like Llama-2, ensure you have enough resources (VRAM/RAM)
    # and handle authentication if it is a gated model (e.g., via huggingface-cli login
    # or by passing token=os.getenv("HUGGINGFACE_HUB_TOKEN") where supported).
    model_id = "gpt2"  # REPLACE THIS with your chosen model, e.g., "NousResearch/Llama-2-7b-chat-hf"
    # model_id = "meta-llama/Llama-2-7b-chat-hf"  # Gated model: requires auth and substantial resources
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)  # trust_remote_code may be needed for some models
        # For large models, device_map="auto" and torch_dtype are crucial.
        # For smaller models like gpt2 they are not strictly necessary and can be dropped.
        hf_model = AutoModelForCausalLM.from_pretrained(
            model_id,
            trust_remote_code=True,
            device_map="auto",  # Automatically distributes model layers across available devices (CPU/GPU)
            torch_dtype=torch.float16,  # Use float16 for memory efficiency if the GPU supports it
            # token=os.getenv("HUGGINGFACE_HUB_TOKEN"),  # Uncomment if your model is gated
        )
        hf_model.eval()  # Set the model to evaluation mode

        # Create a text-generation pipeline.
        # Adjust max_new_tokens, do_sample, and top_k as needed for your model and task.
        pipe = pipeline(
            "text-generation",
            model=hf_model,
            tokenizer=tokenizer,
            # torch_dtype=torch.bfloat16,  # Alternative dtype
            # device_map="auto",  # Already set when loading the model
            max_new_tokens=512,  # Budget for the generated Cypher query + answer synthesis
            do_sample=True,
            top_k=30,
            num_return_sequences=1,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,  # Often good to set for open-ended generation
        )
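        # Quick sanity check (illustrative prompt; uncomment to verify generation works):
        # print(pipe("The capital of France is")[0]["generated_text"])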
        # Wrap the pipeline in LangChain's HuggingFacePipeline.
        # Note: model_kwargs is only honored by HuggingFacePipeline.from_model_id();
        # with a pre-built pipeline, call-time options belong in pipeline_kwargs.
        # max_length is omitted because max_new_tokens is already set on the pipeline.
        llm = HuggingFacePipeline(
            pipeline=pipe,
            pipeline_kwargs={"temperature": 0.1},
        )
print(f"Hugging Face LLM ({model_id}) initialized successfully.")
    except Exception as e_llm:
        llm_initialization_error_message = f"Error initializing Hugging Face LLM ({model_id}): {str(e_llm)}\n"
        llm_initialization_error_message += "Full Traceback:\n" + traceback.format_exc()
        print(llm_initialization_error_message)
        llm_initialization_error = llm_initialization_error_message
        llm = None
    if llm:  # Proceed only if the LLM initialized successfully
        # --- GraphCypherQAChain Setup ---
        print("Initializing GraphCypherQAChain...")
        CYPHER_GENERATION_TEMPLATE = """You are an expert Neo4j Cypher translator.
Task: Convert the natural language question into a Cypher query that can retrieve relevant information from a Neo4j graph.
Instructions:
1. Use only the provided schema details. Do not use any other node labels or relationship types.
2. Understand the question and identify the key entities and relationships.
3. Construct a Cypher query that accurately reflects the question's intent.
4. Output ONLY the Cypher query. No explanations, no introductory text, no markdown. Just the query.
Schema:
{schema}
Question: {question}
Cypher Query:"""
        cypher_prompt = PromptTemplate.from_template(CYPHER_GENERATION_TEMPLATE)

        # For the QA part, the default prompt is often okay, but you might want to customize it too.
        # Here's an example if you choose to:
        # QA_TEMPLATE = """You are an assistant that answers questions based on query results from a graph database.
        # Use the provided query result to answer the question.
        # If the result is empty or does not contain the answer, say so.
        # Do not make up information.
        # Question: {question}
        # Cypher Query Result: {context}
        # Answer:"""
        # qa_prompt = PromptTemplate.from_template(QA_TEMPLATE)
        chain = GraphCypherQAChain.from_llm(
            llm=llm,
            graph=graph,
            verbose=True,
            return_intermediate_steps=True,
            cypher_prompt=cypher_prompt,
            # qa_prompt=qa_prompt,  # Uncomment to use the custom QA prompt above
            # allow_dangerous_requests=True,  # Required by newer langchain_community releases
        )
        print("LangChain integration with GraphCypherQAChain initialized successfully.")
    else:
        # This case is handled by the llm_initialization_error check in process_query.
        pass
except Exception as e_graph:
    graph_connection_error_message = f"Error setting up Neo4j connection: {str(e_graph)}\n"
    graph_connection_error_message += "Full Traceback:\n" + traceback.format_exc()
    print(graph_connection_error_message)
    graph_connection_error = graph_connection_error_message
    chain = None
# --- Gradio Interface Function ---
def process_query(message: str, history: list):
    if graph_connection_error:
        return f"Application Initialization Error (Neo4j): {graph_connection_error}"
    if llm_initialization_error:
        return f"Application Initialization Error (LLM): {llm_initialization_error}"
    if not chain:
        return "Error: LangChain QA Chain is not available. Please check server logs for initialization issues."

    print(f"Processing message: {message}")
    try:
        result = chain.invoke({"query": message})
        print(f"Chain result: {result}")

        answer = result.get("result", "No answer found or an error occurred in processing.")
        intermediate_steps = result.get("intermediate_steps", [])

        generated_cypher = "Could not extract Cypher query from intermediate steps."
        if intermediate_steps and isinstance(intermediate_steps, list) and len(intermediate_steps) > 0:
            if isinstance(intermediate_steps[0], dict) and "query" in intermediate_steps[0]:
                generated_cypher = intermediate_steps[0]["query"]
            # Sometimes the Cypher query might be in a different structure or a later step,
            # depending on the chain's verbosity and internal structure.
            # You might need to inspect intermediate_steps more closely if the above doesn't work.

        return f"🔍 Generated Cypher:\n```cypher\n{generated_cypher}\n```\n\n💬 Answer:\n{answer}"
    except Exception as e:
        error_message = f"Error processing query: {str(e)}"
        print(error_message)
        print(traceback.format_exc())
        # Specific error check for Hugging Face model issues (e.g., out of memory)
        if "CUDA out of memory" in str(e):
            return "LLM Error: CUDA out of memory. The model may be too large for your GPU. Try a smaller model or reduce batch size if applicable."
        return error_message
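# Example invocation (illustrative; Gradio normally supplies message and history):
# print(process_query("How many nodes are in the database?", []))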
# --- Gradio Interface Definition ---
print("Setting up Gradio interface...")
demo = gr.ChatInterface(
    fn=process_query,
    chatbot=gr.Chatbot(height=600, type="messages"),
    title="Neo4j Graph Database Assistant (with Hugging Face LLM)",
    description="Ask questions about your Neo4j database. Model responses depend on the chosen Hugging Face LLM.",
    examples=[
        "How many nodes are in the database?",
        "What types of nodes exist?",
        "List all relationship types.",
    ],
    theme=gr.themes.Soft(),
    cache_examples=False,
)
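# Optional: enable request queuing so concurrent users are handled gracefully:
# demo.queue()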
# --- Main Execution ---
if __name__ == "__main__":
print("Launching Gradio interface...")
# To make accessible on the network (e.g., in Docker):
# demo.launch(server_name="0.0.0.0")
demo.launch() |