Spaces:
No application file
No application file
Upload folder src
Browse files
- src/SpeechToText/__init__.py +0 -0
- src/SpeechToText/hamsa.py +142 -0
- src/SpeechToText/sr.py +84 -0
- src/TextToSpeech/__init__.py +0 -0
- src/TextToSpeech/gtts_testing.mp3 +0 -0
- src/TextToSpeech/gtts_tts.py +90 -0
- src/TextToSpeech/hamsa.py +0 -0
- src/__init__.py +0 -0
- src/agenticRAG/__init__.py +0 -0
- src/agenticRAG/components/__init__.py +0 -0
- src/agenticRAG/components/document_parsing.py +214 -0
- src/agenticRAG/components/embeddings.py +62 -0
- src/agenticRAG/components/llm_factory.py +66 -0
- src/agenticRAG/components/search_tools.py +71 -0
- src/agenticRAG/components/vectorstore.py +297 -0
- src/agenticRAG/gpt.py +340 -0
- src/agenticRAG/graph/__init__.py +0 -0
- src/agenticRAG/graph/builder.py +50 -0
- src/agenticRAG/graph/router.py +11 -0
- src/agenticRAG/main.py +105 -0
- src/agenticRAG/models/__init__.py +0 -0
- src/agenticRAG/models/schemas.py +25 -0
- src/agenticRAG/models/state.py +17 -0
- src/agenticRAG/nodes/__init__.py +0 -0
- src/agenticRAG/nodes/direct_llm_node.py +33 -0
- src/agenticRAG/nodes/query_router.py +41 -0
- src/agenticRAG/nodes/query_upgrader.py +40 -0
- src/agenticRAG/nodes/rag_node.py +48 -0
- src/agenticRAG/nodes/web_search_node.py +44 -0
- src/agenticRAG/prompt/__init__.py +0 -0
- src/agenticRAG/prompt/prompts.py +159 -0
- src/config/__init__.py +0 -0
- src/config/settings.py +49 -0
src/SpeechToText/__init__.py
ADDED
|
File without changes
|
src/SpeechToText/hamsa.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import requests
|
| 2 |
+
import base64
|
| 3 |
+
from datetime import datetime
|
| 4 |
+
import json
|
| 5 |
+
from dotenv import load_dotenv
|
| 6 |
+
import os
|
| 7 |
+
load_dotenv()
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def transcribe_audio_hamsa(audio, language, history):
    """
    Transcribe audio using the Hamsa realtime STT API.

    Args:
        audio: Audio file path (str) or raw audio bytes; None means "nothing
            recorded yet" and is returned unchanged.
        language: Language label selected in the UI dropdown.
        history: Previous transcription history (newline-separated string).

    Returns:
        tuple: (updated_history, transcribed_text) on success, or
        (updated_history, error_message) when the request/parsing fails.

    Raises:
        ValueError: If the HAMS_API_KEY environment variable is not set.
    """
    api_key = os.getenv("HAMS_API_KEY")
    if not api_key:
        raise ValueError("HAMS_API_KEY not set in environment variables")

    if audio is None:
        return history, ""

    # UI labels -> Hamsa API language codes. All Arabic dialect labels map to
    # plain "ar" because the API does not distinguish regional variants.
    language_codes = {
        "English": "en",
        "Arabic": "ar",
        "Arabic (Egypt)": "ar",
        "Arabic (UAE)": "ar",
        "Arabic (Lebanon)": "ar",
        "Arabic (Saudi Arabia)": "ar",
        "Arabic (Kuwait)": "ar",
        "Arabic (Qatar)": "ar",
        "Arabic (Jordan)": "ar",
        "Auto-detect": "auto"  # NOTE(review): confirm Hamsa supports auto-detection
    }

    try:
        # Accept either a file path or already-loaded audio bytes.
        if isinstance(audio, str):
            with open(audio, 'rb') as audio_file:
                audio_bytes = audio_file.read()
        else:
            audio_bytes = audio

        audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
        selected_language = language_codes.get(language, "ar")

        url = "https://api.tryhamsa.com/v1/realtime/stt"
        payload = {
            "audioList": [],  # Empty for single audio file
            "audioBase64": audio_base64,
            "language": selected_language,
            "isEosEnabled": False,
            "eosThreshold": 0.3
        }
        headers = {
            "Authorization": api_key,
            "Content-Type": "application/json"
        }

        # timeout keeps the UI from hanging forever on a dead connection.
        response = requests.post(url, json=payload, headers=headers, timeout=60)
        response.raise_for_status()

        result = response.json()
        text = result.get("text", "")

        # Mark auto-detected output so the user can tell it apart.
        if language == "Auto-detect" and text:
            text = f"[Auto-detected] {text}"

        return _append_to_history(history, language, text), text

    except requests.exceptions.RequestException as e:
        error_msg = f"API request failed: {e}"
    except json.JSONDecodeError as e:
        error_msg = f"Failed to parse API response: {e}"
    except Exception as e:
        error_msg = f"Unexpected error: {e}"

    # Shared failure path: record the error in the history and surface it.
    return _append_to_history(history, language, f"ERROR: {error_msg}"), error_msg


def _append_to_history(history, language, text):
    """Append a timestamped '[ts] [language] text' entry to the history string."""
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    new_entry = f"[{timestamp}] [{language}] {text}"
    return history + "\n" + new_entry if history else new_entry
|
| 133 |
+
|
| 134 |
+
def clear_history():
    """Reset both the stored transcription history and the current text."""
    empty = ""
    return empty, empty
|
| 137 |
+
|
| 138 |
+
# Example usage:
|
| 139 |
+
# api_key = "your-hamsa-api-key-here"
|
| 140 |
+
# history, text = transcribe_audio("path/to/audio.wav", "Arabic", "", api_key)
|
| 141 |
+
# print(f"Transcribed text: {text}")
|
| 142 |
+
# print(f"History: {history}")
|
src/SpeechToText/sr.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import speech_recognition as sr
|
| 2 |
+
from datetime import datetime
|
| 3 |
+
|
| 4 |
+
def transcribe_audio(audio, language, history):
    """
    Transcribe an audio file with Google Speech Recognition.

    Args:
        audio: Path to an audio file, or None if nothing was recorded.
        language: Language label from the UI dropdown (mapped below).
        history: Previous transcription history (newline-separated string).

    Returns:
        tuple: (updated_history, transcribed_text) on success, or
        (updated_history, error_message) on recognition/network failure.
    """
    # Nothing to transcribe; bail out before building a recognizer.
    if audio is None:
        return history, ""

    recognizer = sr.Recognizer()

    # UI labels -> Google Speech Recognition locale codes.
    language_codes = {
        "English": "en-US",
        "Arabic": "ar-SA",  # Saudi Arabic
        "Arabic (Egypt)": "ar-EG",
        "Arabic (UAE)": "ar-AE",
        "Arabic (Lebanon)": "ar-LB",
        "Arabic (Saudi Arabia)": "ar-SA",
        "Arabic (Kuwait)": "ar-KW",
        "Arabic (Qatar)": "ar-QA",
        "Arabic (Jordan)": "ar-JO",
        "Auto-detect": None  # Let Google auto-detect
    }

    try:
        with sr.AudioFile(audio) as source:
            # Adjust for ambient noise before capturing the full clip.
            recognizer.adjust_for_ambient_noise(source)
            audio_data = recognizer.record(source)

        selected_language = language_codes.get(language, "en-US")

        if selected_language:
            text = recognizer.recognize_google(audio_data, language=selected_language)
        else:
            # Auto-detect: try Arabic first, fall back to English.
            try:
                text = recognizer.recognize_google(audio_data, language="ar-SA")
                detected_lang = "Arabic"
            except sr.UnknownValueError:
                # Narrowed from a bare `except:` — a RequestError should
                # propagate to the outer handler, not trigger the fallback.
                text = recognizer.recognize_google(audio_data, language="en-US")
                detected_lang = "English"
            text = f"[{detected_lang}] {text}"

        return _append_entry(history, language, text), text

    except sr.UnknownValueError:
        error_msg = "Could not understand audio"
    except sr.RequestError as e:
        error_msg = f"Could not request results; {e}"

    # Shared failure path: record the error in the history and surface it.
    return _append_entry(history, language, f"ERROR: {error_msg}"), error_msg


def _append_entry(history, language, text):
    """Append a timestamped '[ts] [language] text' entry to the history string."""
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    new_entry = f"[{timestamp}] [{language}] {text}"
    return history + "\n" + new_entry if history else new_entry
|
| 82 |
+
|
| 83 |
+
def clear_history():
    """Clear the transcription history and the current transcription text."""
    empty = ""
    return empty, empty
|
src/TextToSpeech/__init__.py
ADDED
|
File without changes
|
src/TextToSpeech/gtts_testing.mp3
ADDED
|
Binary file (49.5 kB). View file
|
|
|
src/TextToSpeech/gtts_tts.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from gtts import gTTS
|
| 3 |
+
|
| 4 |
+
def text_to_speech_with_gtts_old(input_text, output_filepath, language="en"):
    """Synthesize *input_text* with gTTS and write the MP3 to *output_filepath*."""
    speech = gTTS(text=input_text, lang=language, slow=False)
    speech.save(output_filepath)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# Demo: synthesize an Arabic test sentence to gtts_testing.mp3 on import/run.
# input_text="Hi this is AI with Hassan for testing purpose!"
input_text= "مرحبًا، هذا ذكاء اصطناعي مع حسن لغرض الاختبار!"
text_to_speech_with_gtts_old(input_text=input_text, output_filepath="gtts_testing.mp3", language="ar")
|
| 17 |
+
|
| 18 |
+
#Step1b: Setup Text to Speech–TTS–model with ElevenLabs
|
| 19 |
+
import elevenlabs
|
| 20 |
+
from elevenlabs.client import ElevenLabs
|
| 21 |
+
|
| 22 |
+
ELEVENLABS_API_KEY=os.environ.get("ELEVEN_API_KEY")
|
| 23 |
+
|
| 24 |
+
def text_to_speech_with_elevenlabs_old(input_text, output_filepath):
    """Synthesize *input_text* with ElevenLabs ("Aria" voice) and save as MP3."""
    client = ElevenLabs(api_key=ELEVENLABS_API_KEY)
    generated_audio = client.generate(
        text=input_text,
        voice="Aria",
        model="eleven_turbo_v2",
        output_format="mp3_22050_32",
    )
    elevenlabs.save(generated_audio, output_filepath)
|
| 33 |
+
|
| 34 |
+
#text_to_speech_with_elevenlabs_old(input_text, output_filepath="elevenlabs_testing.mp3")
|
| 35 |
+
|
| 36 |
+
#Step2: Use Model for Text output to Voice
|
| 37 |
+
|
| 38 |
+
import subprocess
|
| 39 |
+
import platform
|
| 40 |
+
|
| 41 |
+
def text_to_speech_with_gtts(input_text, output_filepath, language="en"):
    """
    Convert text to speech with gTTS, save it as MP3, and play it back.

    Args:
        input_text: Text to synthesize.
        output_filepath: Destination MP3 path.
        language: gTTS language code (default "en"). Previously hard-coded;
            the default keeps existing callers unchanged.
    """
    audioobj = gTTS(
        text=input_text,
        lang=language,
        slow=False
    )
    audioobj.save(output_filepath)

    # Play the saved file with the platform's native audio player.
    os_name = platform.system()
    try:
        if os_name == "Darwin":  # macOS
            subprocess.run(['afplay', output_filepath])
        elif os_name == "Windows":  # Windows
            subprocess.run(['powershell', '-c', f'(New-Object Media.SoundPlayer "{output_filepath}").PlaySync();'])
        elif os_name == "Linux":  # Linux
            # NOTE(review): 'aplay' plays WAV only; an MP3 will likely fail
            # here. Consider 'mpg123' or 'ffplay' instead — TODO confirm.
            subprocess.run(['aplay', output_filepath])
        else:
            raise OSError("Unsupported operating system")
    except Exception as e:
        print(f"An error occurred while trying to play the audio: {e}")
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
# Demo: autoplay test (call left commented out to avoid side effects on import).
input_text="Hi this is Ai with Hassan, autoplay testing!"
#text_to_speech_with_gtts(input_text=input_text, output_filepath="gtts_testing_autoplay.mp3")
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def text_to_speech_with_elevenlabs(input_text, output_filepath):
    """Synthesize speech with ElevenLabs, save it to *output_filepath*, then play it."""
    client = ElevenLabs(api_key=ELEVENLABS_API_KEY)
    generated_audio = client.generate(
        text=input_text,
        voice="Aria",
        model="eleven_turbo_v2",
        output_format="mp3_22050_32",
    )
    elevenlabs.save(generated_audio, output_filepath)

    # Play the saved file with the platform's native audio player.
    current_os = platform.system()
    try:
        if current_os == "Darwin":  # macOS
            subprocess.run(['afplay', output_filepath])
        elif current_os == "Windows":  # Windows
            subprocess.run(['powershell', '-c', f'(New-Object Media.SoundPlayer "{output_filepath}").PlaySync();'])
        elif current_os == "Linux":  # Linux
            subprocess.run(['aplay', output_filepath])  # Alternative: use 'mpg123' or 'ffplay'
        else:
            raise OSError("Unsupported operating system")
    except Exception as e:
        print(f"An error occurred while trying to play the audio: {e}")
|
| 89 |
+
|
| 90 |
+
#text_to_speech_with_elevenlabs(input_text, output_filepath="elevenlabs_testing_autoplay.mp3")
|
src/TextToSpeech/hamsa.py
ADDED
|
File without changes
|
src/__init__.py
ADDED
|
File without changes
|
src/agenticRAG/__init__.py
ADDED
|
File without changes
|
src/agenticRAG/components/__init__.py
ADDED
|
File without changes
|
src/agenticRAG/components/document_parsing.py
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import List, Union
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
# LangChain imports
|
| 6 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
| 7 |
+
from langchain_community.document_loaders import (
|
| 8 |
+
PyPDFLoader,
|
| 9 |
+
Docx2txtLoader,
|
| 10 |
+
TextLoader,
|
| 11 |
+
UnstructuredMarkdownLoader
|
| 12 |
+
)
|
| 13 |
+
from langchain.schema import Document
|
| 14 |
+
|
| 15 |
+
class DocumentChunker:
    """
    Read various document types (PDF, DOCX, TXT, MD) and split them into
    text chunks using LangChain's RecursiveCharacterTextSplitter.
    """

    # Extension -> (loader class, human-readable label, loader kwargs).
    # Single source of truth for both the read_* methods and load_document,
    # replacing four copy-pasted try/except loaders and an if/elif chain.
    _LOADERS = {
        '.pdf': (PyPDFLoader, 'PDF', {}),
        '.docx': (Docx2txtLoader, 'DOCX', {}),
        '.txt': (TextLoader, 'TXT', {'encoding': 'utf-8'}),
        '.md': (UnstructuredMarkdownLoader, 'MD', {}),
    }

    def __init__(self, chunk_size: int = 1000, chunk_overlap: int = 200):
        """
        Initialize the DocumentChunker.

        Args:
            chunk_size (int): Size of each chunk in characters
            chunk_overlap (int): Number of characters to overlap between chunks
        """
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=len,
            separators=["\n\n", "\n", " ", ""]
        )

    def _load_with(self, loader_cls, label: str, file_path: str, **loader_kwargs) -> List[Document]:
        """Load *file_path* with *loader_cls*; print and return [] on any error."""
        try:
            loader = loader_cls(file_path, **loader_kwargs)
            return loader.load()
        except Exception as e:
            print(f"Error reading {label} file {file_path}: {e}")
            return []

    def read_pdf(self, file_path: str) -> List[Document]:
        """Read PDF file and return documents"""
        return self._load_with(PyPDFLoader, 'PDF', file_path)

    def read_docx(self, file_path: str) -> List[Document]:
        """Read DOCX file and return documents"""
        return self._load_with(Docx2txtLoader, 'DOCX', file_path)

    def read_txt(self, file_path: str) -> List[Document]:
        """Read TXT file and return documents"""
        return self._load_with(TextLoader, 'TXT', file_path, encoding='utf-8')

    def read_md(self, file_path: str) -> List[Document]:
        """Read Markdown file and return documents"""
        return self._load_with(UnstructuredMarkdownLoader, 'MD', file_path)

    def load_document(self, file_path: str) -> List[Document]:
        """
        Load document based on file extension.

        Args:
            file_path (str): Path to the document file

        Returns:
            List[Document]: List of loaded documents ([] if unsupported or on error)
        """
        file_extension = Path(file_path).suffix.lower()
        entry = self._LOADERS.get(file_extension)
        if entry is None:
            print(f"Unsupported file type: {file_extension}")
            return []
        loader_cls, label, loader_kwargs = entry
        return self._load_with(loader_cls, label, file_path, **loader_kwargs)

    def chunk_documents(self, documents: List[Document]) -> List[str]:
        """
        Chunk documents and return list of strings.

        Args:
            documents (List[Document]): List of documents to chunk

        Returns:
            List[str]: List of chunked text strings
        """
        if not documents:
            return []

        chunks = self.text_splitter.split_documents(documents)
        # Only the raw text content is returned; chunk metadata is dropped.
        return [chunk.page_content for chunk in chunks]

    def process_file(self, file_path: str) -> List[str]:
        """
        Process a single file: load and chunk it.

        Args:
            file_path (str): Path to the file to process

        Returns:
            List[str]: List of chunked text strings ([] if missing/empty/unsupported)
        """
        if not os.path.exists(file_path):
            print(f"File not found: {file_path}")
            return []

        documents = self.load_document(file_path)
        if not documents:
            print(f"No content loaded from {file_path}")
            return []

        chunks = self.chunk_documents(documents)
        print(f"Successfully processed {file_path}: {len(chunks)} chunks created")
        return chunks

    def process_multiple_files(self, file_paths: List[str]) -> List[str]:
        """
        Process multiple files and return combined chunks.

        Args:
            file_paths (List[str]): List of file paths to process

        Returns:
            List[str]: Combined list of chunked text strings
        """
        all_chunks = []
        for file_path in file_paths:
            all_chunks.extend(self.process_file(file_path))
        return all_chunks
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
# Example usage and utility functions
|
| 169 |
+
def main():
    """Example usage of the DocumentChunker class"""

    # Initialize chunker with custom parameters
    chunker = DocumentChunker(chunk_size=800, chunk_overlap=100)

    # Example: Process a single file
    file_path = "example.pdf"  # Replace with your file path
    chunks = chunker.process_file(file_path)

    if chunks:
        print(f"Total chunks: {len(chunks)}")
        print("\nFirst chunk preview:")
        # Conditional expression: truncate the preview to 200 chars only when
        # the first chunk is longer than that.
        print(chunks[0][:200] + "..." if len(chunks[0]) > 200 else chunks[0])

    # Example: Process multiple files
    # NOTE(review): these paths are placeholders; missing files are reported
    # and skipped by process_file, so this demo still runs.
    file_paths = [
        "document1.pdf",
        "document2.docx",
        "document3.txt",
        "document4.md"
    ]

    all_chunks = chunker.process_multiple_files(file_paths)
    print(f"\nTotal chunks from all files: {len(all_chunks)}")

    return all_chunks
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def create_chunker_with_custom_settings(chunk_size: int = 1000,
                                        chunk_overlap: int = 200) -> DocumentChunker:
    """Build and return a DocumentChunker configured with the given settings.

    Args:
        chunk_size (int): Size of each chunk.
        chunk_overlap (int): Overlap between chunks.

    Returns:
        DocumentChunker: Configured chunker instance.
    """
    options = {"chunk_size": chunk_size, "chunk_overlap": chunk_overlap}
    return DocumentChunker(**options)
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
# Script entry point: run the example workflow when executed directly.
if __name__ == "__main__":
    main()
|
src/agenticRAG/components/embeddings.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_huggingface import HuggingFaceEmbeddings
|
| 2 |
+
from langchain_openai import OpenAIEmbeddings
|
| 3 |
+
from src.config.settings import settings
|
| 4 |
+
from typing import Union, Literal
|
| 5 |
+
|
| 6 |
+
class EmbeddingFactory:
    """Factory for creating embedding instances.

    Provides process-wide singletons per provider via get_embeddings(), and
    fresh, independently-configured instances via create_new_embeddings().
    """

    # Lazily-created shared instances, one per provider.
    # NOTE(review): lazy init is not thread-safe; fine for single-threaded
    # startup — confirm if used from multiple threads.
    _huggingface_instance = None
    _openai_instance = None

    @classmethod
    def get_embeddings(cls, provider: Literal["huggingface", "openai"] = "huggingface") -> Union[HuggingFaceEmbeddings, OpenAIEmbeddings]:
        """Get or create embeddings instance (singleton pattern).

        Raises:
            ValueError: If *provider* is not "huggingface" or "openai".
        """
        if provider == "huggingface":
            if cls._huggingface_instance is None:
                cls._huggingface_instance = HuggingFaceEmbeddings(
                    model_name=settings.EMBEDDING_MODEL
                )
            return cls._huggingface_instance
        elif provider == "openai":
            if cls._openai_instance is None:
                cls._openai_instance = OpenAIEmbeddings(
                    model=settings.OPENAI_EMBEDDING_MODEL,
                    openai_api_key=settings.OPENAI_API_KEY
                )
            return cls._openai_instance
        else:
            raise ValueError(f"Unsupported provider: {provider}")

    @classmethod
    def create_new_embeddings(cls, provider: Literal["huggingface", "openai"] = "huggingface", **kwargs) -> Union[HuggingFaceEmbeddings, OpenAIEmbeddings]:
        """Create a new embeddings instance with custom parameters.

        Known keys ("model_name" / "model", "api_key") override the settings
        defaults; all remaining kwargs are forwarded to the constructor.
        """
        if provider == "huggingface":
            return HuggingFaceEmbeddings(
                model_name=kwargs.get("model_name", settings.EMBEDDING_MODEL),
                **{k: v for k, v in kwargs.items() if k != "model_name"}
            )
        elif provider == "openai":
            return OpenAIEmbeddings(
                model=kwargs.get("model", settings.OPENAI_EMBEDDING_MODEL),
                openai_api_key=kwargs.get("api_key", settings.OPENAI_API_KEY),
                **{k: v for k, v in kwargs.items() if k not in ["model", "api_key"]}
            )
        else:
            raise ValueError(f"Unsupported provider: {provider}")

    @classmethod
    def get_huggingface_embeddings(cls) -> HuggingFaceEmbeddings:
        """Convenience method to get HuggingFace embeddings"""
        return cls.get_embeddings("huggingface")

    @classmethod
    def get_openai_embeddings(cls) -> OpenAIEmbeddings:
        """Convenience method to get OpenAI embeddings"""
        return cls.get_embeddings("openai")

    @classmethod
    def reset_instances(cls):
        """Reset singleton instances (useful for testing)"""
        cls._huggingface_instance = None
        cls._openai_instance = None
|
src/agenticRAG/components/llm_factory.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_groq import ChatGroq
|
| 2 |
+
from langchain_openai import ChatOpenAI
|
| 3 |
+
from src.config.settings import settings
|
| 4 |
+
from typing import Union, Literal
|
| 5 |
+
|
| 6 |
+
class LLMFactory:
    """Factory for creating LLM instances.

    Provides process-wide singletons per provider via get_llm(), and fresh,
    independently-configured instances via create_new_llm().
    """

    # Lazily-created shared instances, one per provider.
    # NOTE(review): lazy init is not thread-safe; fine for single-threaded
    # startup — confirm if used from multiple threads.
    _groq_instance = None
    _openai_instance = None

    @classmethod
    def get_llm(cls, provider: Literal["groq", "openai"] = "groq") -> Union[ChatGroq, ChatOpenAI]:
        """Get or create LLM instance (singleton pattern).

        Raises:
            ValueError: If *provider* is not "groq" or "openai".
        """
        if provider == "groq":
            if cls._groq_instance is None:
                cls._groq_instance = ChatGroq(
                    model=settings.GROQ_MODEL,
                    temperature=settings.GROQ_TEMPERATURE,
                    groq_api_key=settings.GROQ_API_KEY
                )
            return cls._groq_instance
        elif provider == "openai":
            if cls._openai_instance is None:
                cls._openai_instance = ChatOpenAI(
                    model=settings.OPENAI_MODEL,
                    temperature=settings.OPENAI_TEMPERATURE,
                    openai_api_key=settings.OPENAI_API_KEY
                )
            return cls._openai_instance
        else:
            raise ValueError(f"Unsupported provider: {provider}")

    @classmethod
    def create_new_llm(cls, provider: Literal["groq", "openai"] = "groq", **kwargs) -> Union[ChatGroq, ChatOpenAI]:
        """Create a new LLM instance with custom parameters.

        "model", "temperature", and "api_key" override the settings defaults;
        unlike the singleton from get_llm(), the result is not cached.
        """
        if provider == "groq":
            return ChatGroq(
                model=kwargs.get("model", settings.GROQ_MODEL),
                temperature=kwargs.get("temperature", settings.GROQ_TEMPERATURE),
                groq_api_key=kwargs.get("api_key", settings.GROQ_API_KEY)
            )
        elif provider == "openai":
            return ChatOpenAI(
                model=kwargs.get("model", settings.OPENAI_MODEL),
                temperature=kwargs.get("temperature", settings.OPENAI_TEMPERATURE),
                openai_api_key=kwargs.get("api_key", settings.OPENAI_API_KEY)
            )
        else:
            raise ValueError(f"Unsupported provider: {provider}")

    @classmethod
    def get_groq_llm(cls) -> ChatGroq:
        """Convenience method to get Groq LLM"""
        return cls.get_llm("groq")

    @classmethod
    def get_openai_llm(cls) -> ChatOpenAI:
        """Convenience method to get OpenAI LLM"""
        return cls.get_llm("openai")

    @classmethod
    def reset_instances(cls):
        """Reset singleton instances (useful for testing)"""
        cls._groq_instance = None
        cls._openai_instance = None
|
src/agenticRAG/components/search_tools.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_tavily import TavilySearch
|
| 2 |
+
from langchain_community.utilities import GoogleSerperAPIWrapper
|
| 3 |
+
from langchain_community.tools import GoogleSerperRun
|
| 4 |
+
from src.config.settings import settings
|
| 5 |
+
from typing import Union, Literal
|
| 6 |
+
|
| 7 |
+
class SearchToolFactory:
    """Factory for creating Tavily / Serper web-search tools.

    ``get_search_tool`` hands out process-wide singletons built from settings;
    ``create_new_search_tool`` always builds a fresh instance and lets callers
    override individual constructor parameters via keyword arguments.
    """

    _tavily_instance = None
    _serper_instance = None

    @classmethod
    def get_search_tool(cls, provider: Literal["tavily", "serper"] = "tavily") -> Union[TavilySearch, GoogleSerperRun]:
        """Get or create a search tool instance (singleton pattern)."""
        if provider == "tavily":
            # Lazily build the shared Tavily client on first use.
            if cls._tavily_instance is None:
                cls._tavily_instance = TavilySearch(api_key=settings.TAVILY_API_KEY)
            return cls._tavily_instance
        if provider == "serper":
            # Serper needs an API wrapper object wrapped in a Run tool.
            if cls._serper_instance is None:
                wrapper = GoogleSerperAPIWrapper(serper_api_key=settings.SERPER_API_KEY)
                cls._serper_instance = GoogleSerperRun(api_wrapper=wrapper)
            return cls._serper_instance
        raise ValueError(f"Unsupported provider: {provider}")

    @classmethod
    def create_new_search_tool(cls, provider: Literal["tavily", "serper"] = "tavily", **kwargs) -> Union[TavilySearch, GoogleSerperRun]:
        """Create a new search tool instance with custom parameters.

        Known constructor arguments fall back to settings when not supplied;
        any remaining keyword arguments are forwarded untouched.
        """
        if provider == "tavily":
            reserved = ("api_key", "max_results", "search_depth", "include_answer", "include_raw_content")
            extra = {key: value for key, value in kwargs.items() if key not in reserved}
            return TavilySearch(
                api_key=kwargs.get("api_key", settings.TAVILY_API_KEY),
                max_results=kwargs.get("max_results", settings.SEARCH_RESULTS_COUNT),
                search_depth=kwargs.get("search_depth", settings.TAVILY_SEARCH_DEPTH),
                include_answer=kwargs.get("include_answer", settings.TAVILY_INCLUDE_ANSWER),
                include_raw_content=kwargs.get("include_raw_content", settings.TAVILY_INCLUDE_RAW_CONTENT),
                **extra,
            )
        if provider == "serper":
            reserved = ("api_key", "k", "type", "country", "location")
            extra = {key: value for key, value in kwargs.items() if key not in reserved}
            wrapper = GoogleSerperAPIWrapper(
                serper_api_key=kwargs.get("api_key", settings.SERPER_API_KEY),
                k=kwargs.get("k", settings.SEARCH_RESULTS_COUNT),
                type=kwargs.get("type", settings.SERPER_SEARCH_TYPE),
                country=kwargs.get("country", settings.SERPER_COUNTRY),
                location=kwargs.get("location", settings.SERPER_LOCATION),
                **extra,
            )
            return GoogleSerperRun(api_wrapper=wrapper)
        raise ValueError(f"Unsupported provider: {provider}")

    @classmethod
    def get_tavily_search(cls) -> TavilySearch:
        """Convenience method to get the shared Tavily search tool."""
        return cls.get_search_tool("tavily")

    @classmethod
    def get_serper_search(cls) -> GoogleSerperRun:
        """Convenience method to get the shared Serper search tool."""
        return cls.get_search_tool("serper")

    @classmethod
    def reset_instances(cls):
        """Reset singleton instances (useful for testing)."""
        cls._tavily_instance = None
        cls._serper_instance = None
|
src/agenticRAG/components/vectorstore.py
ADDED
|
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_community.vectorstores import FAISS
|
| 2 |
+
from langchain_huggingface import HuggingFaceEmbeddings
|
| 3 |
+
from typing import List, Optional
|
| 4 |
+
from src.config.settings import settings
|
| 5 |
+
from src.agenticRAG.components.embeddings import EmbeddingFactory
|
| 6 |
+
import os
|
| 7 |
+
from typing import Dict, Any, List, Optional
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from src.agenticRAG.components.document_parsing import DocumentChunker
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class VectorStoreManager:
    """Manager for FAISS vector store operations (load, search, add, save)."""

    def __init__(self):
        # Shared embedding model used both for indexing and for queries.
        self.embeddings = EmbeddingFactory.get_embeddings()
        self.vectorstore = None

    def load_vectorstore(self, path: Optional[str] = None) -> bool:
        """Load a vector store from *path* (or the configured default).

        Returns:
            bool: True when an index was found and loaded, False otherwise.
        """
        try:
            target = path or settings.VECTORSTORE_PATH
            if not os.path.exists(target):
                return False
            # allow_dangerous_deserialization is required by FAISS.load_local
            # because the index metadata is pickled — only load trusted files.
            self.vectorstore = FAISS.load_local(
                target, self.embeddings, allow_dangerous_deserialization=True
            )
            return True
        except Exception as e:
            print(f"Error loading vectorstore: {e}")
            return False

    def search_documents(self, query: str, k: int = 3) -> List[str]:
        """Return the page contents of the *k* most similar documents."""
        if not self.vectorstore:
            # No index loaded yet — nothing to search.
            return []

        try:
            matches = self.vectorstore.similarity_search(query, k=k)
            return [match.page_content for match in matches]
        except Exception as e:
            print(f"Error searching documents: {e}")
            return []

    def add_documents(self, texts: List[str], metadatas: Optional[List[dict]] = None):
        """Add texts (with optional per-text metadata) to the vector store."""
        if not self.vectorstore:
            # First batch: build a brand-new index from the texts.
            self.vectorstore = FAISS.from_texts(texts, self.embeddings, metadatas=metadatas)
        else:
            self.vectorstore.add_texts(texts, metadatas=metadatas)

    def save_vectorstore(self, path: Optional[str] = None):
        """Persist the vector store to *path* (no-op when nothing is loaded)."""
        if self.vectorstore:
            self.vectorstore.save_local(path or settings.VECTORSTORE_PATH)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def store_documents_in_vectorstore(
    file_paths: List[str],
    vectorstore_manager: Optional[VectorStoreManager] = None,
    chunk_size: int = 1000,
    chunk_overlap: int = 200,
    save_path: Optional[str] = None,
    include_metadata: bool = True
) -> Dict[str, Any]:
    """
    Process documents and store them in the vector store.

    Chunks each file with DocumentChunker, appends the chunks (plus optional
    per-chunk source metadata) to the manager's index, and persists the index
    once at the end. Per-file failures are recorded in the result, not raised.

    Args:
        file_paths (List[str]): List of file paths to process
        vectorstore_manager (VectorStoreManager, optional): Existing manager instance
        chunk_size (int): Size of each chunk
        chunk_overlap (int): Overlap between chunks
        save_path (str, optional): Path to load from / save the vector store to
        include_metadata (bool): Whether to include file metadata per chunk

    Returns:
        Dict[str, Any]: Processing statistics (total_files, processed_files,
        failed_files, total_chunks, chunks_by_file; plus "error" on a fatal error).
    """
    # Initialize components
    if vectorstore_manager is None:
        vectorstore_manager = VectorStoreManager()

    chunker = DocumentChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap)

    # Load existing vectorstore if available (best effort — a missing store is fine,
    # add_documents will create a fresh index in that case)
    vectorstore_manager.load_vectorstore(save_path)

    # Track processing statistics
    results = {
        "total_files": len(file_paths),
        "processed_files": 0,
        "failed_files": [],
        "total_chunks": 0,
        "chunks_by_file": {}
    }

    try:
        for file_path in file_paths:
            try:
                print(f"Processing file: {file_path}")

                # Process file into chunks
                chunks = chunker.process_file(file_path)

                if chunks:
                    # Prepare metadata if requested: one dict per chunk so each
                    # chunk can be traced back to its source file and position.
                    metadatas = None
                    if include_metadata:
                        file_name = Path(file_path).name
                        file_extension = Path(file_path).suffix
                        metadatas = [
                            {
                                "source": file_path,
                                "file_name": file_name,
                                "file_extension": file_extension,
                                "chunk_index": i
                            }
                            for i in range(len(chunks))
                        ]

                    # Add documents to vector store
                    vectorstore_manager.add_documents(chunks, metadatas)

                    # Update statistics
                    results["processed_files"] += 1
                    results["total_chunks"] += len(chunks)
                    results["chunks_by_file"][file_path] = len(chunks)

                    print(f"Successfully processed {file_path}: {len(chunks)} chunks")

                else:
                    # Empty extraction counts as a failure for reporting purposes.
                    print(f"No chunks extracted from {file_path}")
                    results["failed_files"].append(file_path)

            except Exception as e:
                # Per-file errors are swallowed so one bad file doesn't abort the batch.
                print(f"Error processing file {file_path}: {e}")
                results["failed_files"].append(file_path)

        # Save the vector store (only when something was actually indexed)
        if results["total_chunks"] > 0:
            vectorstore_manager.save_vectorstore(save_path)
            print(f"Vector store saved with {results['total_chunks']} total chunks")

        return results

    except Exception as e:
        # Fatal error outside the per-file loop — return partial stats with the error.
        print(f"Error in store_documents_in_vectorstore: {e}")
        results["error"] = str(e)
        return results
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def store_single_document_in_vectorstore(
    file_path: str,
    vectorstore_manager: Optional[VectorStoreManager] = None,
    chunk_size: int = 1000,
    chunk_overlap: int = 200,
    save_path: Optional[str] = None
) -> bool:
    """
    Process and store a single document in the vector store.

    Thin wrapper over store_documents_in_vectorstore for the one-file case.

    Args:
        file_path (str): Path to the file to process
        vectorstore_manager (VectorStoreManager, optional): Existing manager instance
        chunk_size (int): Size of each chunk
        chunk_overlap (int): Overlap between chunks
        save_path (str, optional): Path to save the vector store

    Returns:
        bool: True when the file produced at least one stored chunk
    """
    outcome = store_documents_in_vectorstore(
        file_paths=[file_path],
        vectorstore_manager=vectorstore_manager,
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        save_path=save_path,
    )
    return outcome["processed_files"] > 0
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def batch_store_documents(
    directory_path: str,
    file_extensions: Optional[List[str]] = None,
    vectorstore_manager: Optional[VectorStoreManager] = None,
    chunk_size: int = 1000,
    chunk_overlap: int = 200,
    save_path: Optional[str] = None
) -> Dict[str, Any]:
    """
    Process and store all documents from a directory.

    Args:
        directory_path (str): Path to directory containing documents
        file_extensions (List[str], optional): File extensions to process.
            Defaults to [".pdf", ".docx", ".txt", ".md"].
        vectorstore_manager (VectorStoreManager, optional): Existing manager instance
        chunk_size (int): Size of each chunk
        chunk_overlap (int): Overlap between chunks
        save_path (str, optional): Path to save the vector store

    Returns:
        Dict[str, Any]: Processing results (same shape as
        store_documents_in_vectorstore)
    """
    # None-sentinel instead of a mutable list default (shared across calls).
    if file_extensions is None:
        file_extensions = [".pdf", ".docx", ".txt", ".md"]

    # Find all files with the specified extensions.
    # NOTE: glob("*{ext}") only matches the top level, not subdirectories,
    # and is case-sensitive on case-sensitive filesystems.
    directory = Path(directory_path)
    file_paths = []

    for extension in file_extensions:
        file_paths.extend(directory.glob(f"*{extension}"))

    # Convert Path objects to string paths for downstream processing
    file_paths = [str(path) for path in file_paths]

    if not file_paths:
        print(f"No files found in {directory_path} with extensions {file_extensions}")
        # Keep the result shape consistent with store_documents_in_vectorstore.
        return {"total_files": 0, "processed_files": 0, "failed_files": [], "total_chunks": 0, "chunks_by_file": {}}

    print(f"Found {len(file_paths)} files to process")

    return store_documents_in_vectorstore(
        file_paths=file_paths,
        vectorstore_manager=vectorstore_manager,
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        save_path=save_path
    )
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
# Example usage
|
| 233 |
+
def main():
    """Demo: index a single PDF into the vector store, then run a sample query."""

    # Shared manager so the search below hits the freshly built index.
    vs_manager = VectorStoreManager()

    # --- Index a single document ---
    print("=== Storing Single Document ===")
    file_path = "/home/ubuntu/OMANI-Therapist-Voice-ChatBot/KnowledgebaseFile/SuicideGuard_An_NLP-Based_Chrome_Extension_for_Detecting_Suicidal_Thoughts_in_Bengali.pdf"
    stored = store_single_document_in_vectorstore(
        file_path=file_path,
        vectorstore_manager=vs_manager,
        chunk_size=1000,
        chunk_overlap=150
    )
    print(f"Single document processing: {'Success' if stored else 'Failed'}")

    # See store_documents_in_vectorstore for multi-file indexing and
    # batch_store_documents for whole-directory indexing variants.

    # --- Query the vector store ---
    print("\n=== Searching Vector Store ===")
    query = "suicide prevention techniques"
    search_results = vs_manager.search_documents(query, k=3)

    print(f"Search results for '{query}':")
    for i, result in enumerate(search_results):
        print(f" Result {i+1}: {result[:200]}...")


if __name__ == "__main__":
    main()
|
src/agenticRAG/gpt.py
ADDED
|
@@ -0,0 +1,340 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from openai import OpenAI
|
| 2 |
+
import json
|
| 3 |
+
import re
|
| 4 |
+
from typing import Dict, List, Optional, Tuple
|
| 5 |
+
from datetime import datetime
|
| 6 |
+
import os
|
| 7 |
+
from enum import Enum
|
| 8 |
+
from loguru import logger
|
| 9 |
+
from dotenv import load_dotenv
|
| 10 |
+
load_dotenv()
|
| 11 |
+
|
| 12 |
+
class EmotionalState(Enum):
    """Coarse emotional categories inferred via keyword matching on user input."""
    CALM = "calm"
    ANXIOUS = "anxious"
    DEPRESSED = "depressed"
    ANGRY = "angry"
    DISTRESSED = "distressed"
|
| 18 |
+
|
| 19 |
+
class OmaniTherapistAI:
    """Bilingual (Omani Arabic / English) therapeutic chat assistant backed by OpenAI.

    Keeps an in-memory conversation history per instance, detects the user's
    language and a coarse emotional state via keyword matching, and generates
    culturally-aware responses through the OpenAI Responses API.
    """

    def __init__(self, api_key: Optional[str] = None):
        """
        Initialize the OMANI Therapist AI system.

        Args:
            api_key: OpenAI API key (if not provided, will use the
                OPENAI_API_KEY environment variable)

        Raises:
            ValueError: If no API key is supplied and none is in the environment.
        """
        self.api_key = api_key or os.getenv('OPENAI_API_KEY')
        if not self.api_key:
            raise ValueError("OpenAI API key is required")

        self.client = OpenAI(api_key=self.api_key)

        # Session management (per-instance, in-memory only)
        self.conversation_history = []          # list of {"role": ..., "content": ...}
        self.user_profile = {}                  # currently unused placeholder
        self.emotional_state = EmotionalState.CALM

        # System prompt for therapeutic conversations
        self.system_prompt = self._create_system_prompt()

    def _create_system_prompt(self) -> str:
        """Create the comprehensive system prompt for bilingual therapeutic conversations."""
        return """You are a specialized mental health counselor for the Omani community. You are fluent in both Arabic (Omani dialect) and English, and you understand Gulf culture and Islamic values deeply.

## Your Identity & Characteristics:
- Omani Mental Health Counselor
- Bilingual: Fluent in Omani Arabic and English
- Culturally competent in Gulf and Islamic traditions
- Understand family dynamics and Gulf society
- Integrate Islamic concepts in therapy when appropriate
- Handle code-switching naturally between Arabic and English

## Your Therapeutic Skills:
- Cognitive Behavioral Therapy (CBT) adapted for Omani culture
- Active listening and empathy
- Anxiety and stress management techniques
- Family and relationship therapy
- Trauma-informed approaches
- Spiritual therapy compatible with Islam

## Language Guidelines:
**CRITICAL: Always respond in the SAME language the user uses:**
- If user writes in Arabic → respond in Omani Arabic
- If user writes in English → respond in English
- If user mixes languages → mirror their code-switching pattern
- Maintain cultural sensitivity in both languages

## Response Instructions:
- Start with warm greeting and check emotional state
- Ask open-ended questions to understand situation
- Use reframing and summarization techniques
- Offer practical coping strategies
- End with summary and follow-up suggestions
- Keep responses 100-200 words
- Show empathy and understanding

## Cultural Sensitivity:
- Respect Islamic values and Omani traditions
- Avoid taboo or controversial topics
- Consider family/community role in mental health
- Use religious references wisely when appropriate
- Address mental health stigma sensitively

Remember: You are a supportive assistant, not a replacement for professional specialized therapy.
"""

    def detect_language(self, text: str) -> str:
        """
        Detect whether text is primarily Arabic or English.

        Args:
            text: Input text to analyze

        Returns:
            'arabic', 'english', or 'mixed'
        """
        # Count Arabic vs English characters. U+0600–U+06FF is the main
        # Arabic Unicode block; ASCII alphabetic characters count as English.
        arabic_chars = sum(1 for char in text if '\u0600' <= char <= '\u06FF')
        english_chars = sum(1 for char in text if char.isalpha() and char.isascii())

        if arabic_chars > english_chars:
            return 'arabic'
        elif english_chars > arabic_chars:
            return 'english'
        else:
            # Ties — including text with no letters at all — are reported as mixed.
            return 'mixed'

    def analyze_emotional_state(self, user_input: str) -> Tuple[EmotionalState, str]:
        """
        Analyze the user's emotional state from their input via keyword matching.

        Precedence when multiple categories match: anxiety > depression >
        anger > stress; no match yields CALM.

        Args:
            user_input: User's message in Arabic or English

        Returns:
            Tuple of (emotional_state, detected_language)
        """
        user_input_lower = user_input.lower()
        detected_language = self.detect_language(user_input)

        # Emotional state analysis using substring keyword matching in both
        # languages. NOTE(review): substring matching can over-trigger
        # (e.g. 'mad' inside 'made') — acceptable for coarse classification.
        anxiety_keywords = [
            # Arabic
            'قلق', 'خوف', 'توتر', 'قلقان', 'مضطرب', 'خايف', 'متوتر', 'مهموم',
            'أشعر بالقلق', 'أخاف', 'عندي قلق', 'مش مرتاح', 'مو مرتاح',
            # English
            'anxiety', 'worried', 'nervous', 'anxious', 'panic', 'scared', 'fearful',
            'feel anxious', 'feeling worried', 'i\'m scared', 'i\'m nervous'
        ]

        depression_keywords = [
            # Arabic
            'حزن', 'اكتئاب', 'مكتئب', 'حزين', 'يائس', 'زعلان', 'مش راضي',
            'أشعر بالحزن', 'مو مبسوط', 'تعبان نفسياً', 'مش عارف شنو أسوي',
            # English
            'depressed', 'sad', 'hopeless', 'down', 'blue', 'miserable', 'unhappy',
            'feeling down', 'feel sad', 'i\'m depressed', 'feeling hopeless'
        ]

        anger_keywords = [
            # Arabic
            'غضب', 'غاضب', 'زعلان', 'مستاء', 'عصبي', 'متضايق', 'مش راضي',
            'أشعر بالغضب', 'مزعوج', 'معصب', 'متنرفز',
            # English
            'angry', 'mad', 'frustrated', 'irritated', 'annoyed', 'upset', 'furious',
            'feel angry', 'i\'m mad', 'feeling frustrated', 'really upset'
        ]

        stress_keywords = [
            # Arabic
            'ضغط', 'ضغوط', 'تعب', 'مرهق', 'تعبان', 'مش قادر', 'صعب عليّ',
            'أشعر بالضغط', 'مرهق نفسياً', 'ما أقدر أكمل',
            # English
            'stress', 'stressed', 'pressure', 'overwhelmed', 'exhausted', 'burned out',
            'feeling stressed', 'under pressure', 'can\'t cope', 'too much pressure'
        ]

        if any(keyword in user_input_lower for keyword in anxiety_keywords):
            return EmotionalState.ANXIOUS, detected_language
        elif any(keyword in user_input_lower for keyword in depression_keywords):
            return EmotionalState.DEPRESSED, detected_language
        elif any(keyword in user_input_lower for keyword in anger_keywords):
            return EmotionalState.ANGRY, detected_language
        elif any(keyword in user_input_lower for keyword in stress_keywords):
            return EmotionalState.DISTRESSED, detected_language

        return EmotionalState.CALM, detected_language

    def generate_therapeutic_response(self, user_input: str, include_history: bool = True) -> Dict:
        """
        Generate a therapeutic response using the OpenAI API.

        Also updates self.emotional_state and appends the exchange to
        self.conversation_history (trimmed to the last 10 messages).

        Args:
            user_input: User's message
            include_history: Whether to include recent conversation history

        Returns:
            Dictionary with keys: response, emotional_state, detected_language,
            timestamp — plus "error" when generation failed.
        """
        try:
            # Analyze emotional state and detect language
            emotional_state, detected_language = self.analyze_emotional_state(user_input)
            self.emotional_state = emotional_state

            # Prepare messages for the API, system prompt first
            messages = [{"role": "system", "content": self.system_prompt}]

            # Add language context to the system prompt
            language_instruction = f"\n\nIMPORTANT: The user is communicating in {detected_language}. Please respond in the same language they used."
            messages[0]["content"] += language_instruction

            # Add conversation history if requested
            if include_history and self.conversation_history:
                messages.extend(self.conversation_history[-6:])  # Last 6 messages for context

            # Add current user message
            messages.append({"role": "user", "content": user_input})

            # Generate response using OpenAI.
            # NOTE(review): uses the OpenAI Responses API (client.responses.create /
            # response.output_text) — requires a recent openai SDK; confirm the
            # pinned model name is still available.
            response = self.client.responses.create(
                model="gpt-4.1-nano-2025-04-14",
                input=messages,
                temperature=0.7,
            )
            logger.info(f"Generated response: {response.output_text}")

            ai_response = (response.output_text)

            # Update conversation history
            self.conversation_history.append({"role": "user", "content": user_input})
            self.conversation_history.append({"role": "assistant", "content": ai_response})

            # Keep only last 10 messages to manage context length
            if len(self.conversation_history) > 10:
                self.conversation_history = self.conversation_history[-10:]

            return {
                "response": ai_response,
                "emotional_state": emotional_state.value,
                "detected_language": detected_language,
                "timestamp": datetime.now().isoformat(),
            }

        except Exception as e:
            logger.error(f"Error generating response: {str(e)}")

            # Fall back to a canned error message in the detected language
            # (Arabic message for both 'arabic' and 'mixed' inputs).
            detected_language = self.detect_language(user_input)

            if detected_language == 'english':
                error_message = "Sorry, a technical error occurred. Please try again or contact a specialist."
            else:
                error_message = "آسف، حدث خطأ تقني. يرجى المحاولة مرة أخرى أو التواصل مع المختص."

            return {
                "response": error_message,
                "emotional_state": "unknown",
                "detected_language": detected_language,
                "timestamp": datetime.now().isoformat(),
                "error": str(e)
            }

    def get_conversation_summary(self) -> Dict:
        """Get a summary of the current conversation session."""
        return {
            "total_messages": len(self.conversation_history),
            "current_emotional_state": self.emotional_state.value,
            # NOTE(review): history entries only carry "role"/"content", so this
            # .get("timestamp") is always None — confirm whether timestamps
            # should be stored per message.
            "session_start": self.conversation_history[0].get("timestamp") if self.conversation_history else None,
            "last_interaction": datetime.now().isoformat()
        }

    def clear_conversation(self):
        """Clear conversation history and reset the emotional state."""
        self.conversation_history = []
        self.emotional_state = EmotionalState.CALM
        logger.info("Conversation cleared")

    def export_conversation(self, filename: Optional[str] = None) -> str:
        """Export the conversation to a JSON file.

        Args:
            filename: Target file name; defaults to a timestamped name.

        Returns:
            The file name the session was written to.
        """
        if not filename:
            filename = f"therapy_session_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"

        session_data = {
            "session_metadata": self.get_conversation_summary(),
            "conversation_history": self.conversation_history,
            "export_timestamp": datetime.now().isoformat()
        }

        # ensure_ascii=False keeps the Arabic text readable in the exported file.
        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(session_data, f, ensure_ascii=False, indent=2)

        return filename
|
| 273 |
+
|
| 274 |
+
# Helper function for easy integration
def get_therapy_response(user_input: str, api_key: str = None) -> Dict:
    """Create a one-off OmaniTherapistAI session and return its reply.

    Note: a fresh therapist (with empty history) is built per call.

    Args:
        user_input: User's message
        api_key: OpenAI API key

    Returns:
        Dictionary with response and metadata
    """
    return OmaniTherapistAI(api_key).generate_therapeutic_response(user_input)
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def gpt_response(query):
    """Answer *query* with a fresh therapist instance, echoing the
    response, detected emotional state and language to stdout."""
    result = OmaniTherapistAI().generate_therapeutic_response(query)
    for label, key in (
        ("AI Response", "response"),
        ("Emotional State", "emotional_state"),
        ("Detected Language", "detected_language"),
    ):
        print(f"{label}: {result[key]}")
    return result
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
# Example usage and testing
|
| 300 |
+
if __name__ == "__main__":
    # Smoke-test the system across Arabic, English and code-switching inputs.
    therapist = OmaniTherapistAI()

    test_scenarios = [
        # Arabic scenarios
        "السلام عليكم، أشعر بالقلق الشديد هذه الأيام",
        "أواجه مشاكل في العمل وأشعر بالضغط",
        "لا أستطيع النوم جيداً ومزاجي متقلب",
        "أريد أن أتحدث عن مشاكلي مع زوجتي",
        "أشعر بالاكتئاب ولا أعرف ماذا أفعل",

        # English scenarios
        "Hello, I'm feeling very anxious these days",
        "I'm having problems at work and feeling stressed",
        "I can't sleep well and my mood is unstable",
        "I want to talk about my problems with my wife",
        "I feel depressed and don't know what to do",

        # Code-switching scenarios
        "السلام عليكم، I'm feeling very stressed lately",
        "Hello, أشعر بالقلق and I don't know what to do",
        "My work is مرهق جداً and I can't cope"
    ]

    print("=== OMANI Therapist AI Test ===")
    for i, scenario in enumerate(test_scenarios, 1):
        print(f"\n--- Test Scenario {i} ---")
        print(f"User: {scenario}")

        result = therapist.generate_therapeutic_response(scenario)
        for label, key in (
            ("AI Response", "response"),
            ("Emotional State", "emotional_state"),
            ("Detected Language", "detected_language"),
        ):
            print(f"{label}: {result[key]}")
        print("-" * 50)

    # Session-level recap after all scenarios ran.
    print("\n=== Session Summary ===")
    print(json.dumps(therapist.get_conversation_summary(), indent=2, ensure_ascii=False))
|
src/agenticRAG/graph/__init__.py
ADDED
|
File without changes
|
src/agenticRAG/graph/builder.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langgraph.graph import StateGraph, END
|
| 2 |
+
from src.agenticRAG.models.state import AgentState
|
| 3 |
+
from src.agenticRAG.nodes.query_upgrader import query_upgrader_node
|
| 4 |
+
from src.agenticRAG.nodes.query_router import query_router_node
|
| 5 |
+
from src.agenticRAG.nodes.rag_node import rag_node
|
| 6 |
+
from src.agenticRAG.nodes.web_search_node import web_search_node
|
| 7 |
+
from src.agenticRAG.nodes.direct_llm_node import direct_llm_node
|
| 8 |
+
from src.agenticRAG.graph.router import route_query
|
| 9 |
+
|
| 10 |
+
class GraphBuilder:
    """Assembles the AgenticRAG LangGraph workflow."""

    @staticmethod
    def create_graph():
        """Build and compile the routing workflow.

        The graph runs the query upgrader, then the router, which fans
        out to exactly one of three answer paths (RAG, web search,
        direct LLM); every path terminates the run.
        """
        workflow = StateGraph(AgentState)

        # Register all processing nodes (order matches original wiring).
        for node_name, node_fn in (
            ("query_upgrader", query_upgrader_node),
            ("query_router", query_router_node),
            ("rag_path", rag_node),
            ("web_search", web_search_node),
            ("direct_llm", direct_llm_node),
        ):
            workflow.add_node(node_name, node_fn)

        # Linear prefix: always upgrade first, then decide the route.
        workflow.set_entry_point("query_upgrader")
        workflow.add_edge("query_upgrader", "query_router")

        # The router's decision selects exactly one terminal path.
        terminal_nodes = ("rag_path", "web_search", "direct_llm")
        workflow.add_conditional_edges(
            "query_router",
            route_query,
            {name: name for name in terminal_nodes},
        )
        for terminal in terminal_nodes:
            workflow.add_edge(terminal, END)

        return workflow.compile()
|
src/agenticRAG/graph/router.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Literal
|
| 2 |
+
from src.agenticRAG.models.state import AgentState
|
| 3 |
+
|
| 4 |
+
def route_query(state: AgentState) -> Literal["rag_path", "web_search", "direct_llm"]:
    """Map the router's decision onto a graph edge name.

    Unknown or missing decisions fall back to the direct-LLM path,
    matching the original dict-lookup default.
    """
    decision = state.route_decision
    if decision == "RAG":
        return "rag_path"
    if decision == "WEB":
        return "web_search"
    # "DIRECT" and anything unexpected both land here.
    return "direct_llm"
|
src/agenticRAG/main.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
from typing import List
|
| 3 |
+
from src.config.settings import settings
|
| 4 |
+
from src.agenticRAG.models.state import AgentState
|
| 5 |
+
from src.agenticRAG.models.schemas import QueryRequest, QueryResponse
|
| 6 |
+
from src.agenticRAG.graph.builder import GraphBuilder
|
| 7 |
+
from loguru import logger
|
| 8 |
+
|
| 9 |
+
class AgenticRAGSystem:
    """Main AgenticRAG system.

    Validates configuration and compiles the LangGraph workflow once at
    construction; queries are then run through ``process_query``.
    """

    def __init__(self):
        # Fail fast if required API keys (e.g. GROQ_API_KEY) are missing.
        settings.validate()

        # Compile the workflow once per system instance.
        self.app = GraphBuilder.create_graph()

        logger.info("AgenticRAG system initialized successfully")

    def process_query(self, query: str) -> QueryResponse:
        """Run one query through the graph and package the result.

        Args:
            query: Raw user question.

        Returns:
            QueryResponse carrying the upgraded query, the route taken,
            the final answer, per-node metadata, and wall-clock time.

        Raises:
            Exception: re-raises whatever the graph run failed with,
                after logging it.
        """

        start_time = time.time()

        try:
            # Initialize state
            initial_state = AgentState(user_query=query)

            # Run the graph.
            # NOTE(review): recent LangGraph versions return a plain dict
            # from ``invoke`` rather than the state model; the attribute
            # accesses below assume a model-like object — confirm against
            # the pinned langgraph version.
            final_state = self.app.invoke(initial_state)

            # Calculate processing time
            processing_time = time.time() - start_time

            # Create response
            response = QueryResponse(
                query=final_state.user_query,
                upgraded_query=final_state.upgraded_query,
                route_taken=final_state.route_decision,
                response=final_state.final_response,
                metadata=final_state.metadata,
                processing_time=processing_time
            )

            logger.info(f"Query processed successfully in {processing_time:.2f}s")
            return response

        except Exception as e:
            logger.error(f"Error processing query: {e}")
            raise

    def process_batch(self, queries: List[str]) -> List[QueryResponse]:
        """Process multiple queries sequentially.

        Failed queries are logged and skipped, so the returned list may
        be shorter than the input list.
        """

        responses = []
        for query in queries:
            try:
                response = self.process_query(query)
                responses.append(response)
            except Exception as e:
                logger.error(f"Error processing query '{query}': {e}")

        return responses
|
| 65 |
+
|
| 66 |
+
def agenticRAGResponse(query: str) -> QueryResponse:
    """Convenience wrapper: build a fresh system and answer one query.

    NOTE: constructs (and compiles) a new graph per call; reuse a single
    AgenticRAGSystem instance when answering many queries.
    """
    return AgenticRAGSystem().process_query(query)
|
| 71 |
+
|
| 72 |
+
def main():
    """Smoke-test the system against one query per routing path."""

    system = AgenticRAGSystem()

    # One representative query per expected route (RAG / WEB / DIRECT).
    test_queries = [
        "What is machine learning?",
        "Latest news about AI",
        "Write a poem about spring"
    ]

    separator = "=" * 50
    for query in test_queries:
        print(f"\n{separator}")
        print(f"Query: {query}")
        print(separator)

        try:
            result = system.process_query(query)

            print(f"Original Query: {result.query}")
            print(f"Upgraded Query: {result.upgraded_query}")
            print(f"Route Taken: {result.route_taken}")
            print(f"Response: {result.response}")
            print(f"Processing Time: {result.processing_time:.2f}s")
            print(f"Metadata: {result.metadata}")

        except Exception as e:
            print(f"Error: {e}")

if __name__ == "__main__":
    main()
|
src/agenticRAG/models/__init__.py
ADDED
|
File without changes
|
src/agenticRAG/models/schemas.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Dict, Any, Optional
|
| 2 |
+
from pydantic import BaseModel
|
| 3 |
+
|
| 4 |
+
class QueryRequest(BaseModel):
    """Request schema for query processing"""
    query: str  # raw user question
    session_id: Optional[str] = None  # optional conversation/session identifier
    metadata: Optional[Dict[str, Any]] = None  # caller-supplied extras

class QueryResponse(BaseModel):
    """Response schema for query processing"""
    query: str  # original user question
    upgraded_query: str  # enhanced form produced by the query upgrader
    route_taken: str  # routing decision, e.g. "RAG", "WEB" or "DIRECT"
    response: str  # final answer text
    metadata: Dict[str, Any]  # per-node success flags and error details
    processing_time: float  # wall-clock seconds for the whole run

class ProcessingMetadata(BaseModel):
    """Metadata for processing steps"""
    upgrade_success: bool = False
    routing_success: bool = False
    path_success: bool = False
    # NOTE: pydantic copies mutable defaults per instance, so sharing a
    # list literal here is safe (unlike on a plain class attribute).
    errors: List[str] = []
    processing_time: float = 0.0
|
src/agenticRAG/models/state.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List, Any
|
| 2 |
+
from pydantic import BaseModel, Field
|
| 3 |
+
|
| 4 |
+
class AgentState(BaseModel):
    """State schema for the AgenticRAG workflow.

    A single AgentState flows through every graph node; each node fills
    in its own fields and records success/error flags in ``metadata``.
    """

    user_query: str = Field(description="Original user query")
    upgraded_query: str = Field(default="", description="Enhanced query")
    route_decision: str = Field(default="", description="Routing decision")
    retrieved_docs: List[str] = Field(default_factory=list, description="Retrieved documents")
    search_results: List[str] = Field(default_factory=list, description="Web search results")
    final_response: str = Field(default="", description="Final response")
    metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata")

    class Config:
        """Pydantic configuration"""
        # Allow non-pydantic objects (e.g. framework types) in fields.
        arbitrary_types_allowed = True
|
src/agenticRAG/nodes/__init__.py
ADDED
|
File without changes
|
src/agenticRAG/nodes/direct_llm_node.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from src.agenticRAG.models.state import AgentState
|
| 2 |
+
from src.agenticRAG.components.llm_factory import LLMFactory
|
| 3 |
+
from src.agenticRAG.prompt.prompts import Prompts
|
| 4 |
+
|
| 5 |
+
class DirectLLMNode:
    """Answers a query straight from the LLM, without retrieval."""

    def __init__(self):
        self.llm = LLMFactory.get_llm()
        self.prompt = Prompts.DIRECT_RESPONSE

    def process_direct_llm(self, state: AgentState) -> AgentState:
        """Generate a direct answer for the upgraded query.

        On failure the node degrades gracefully: a canned apology goes
        into the state and the error is recorded in metadata, so the
        graph run never aborts on this node.
        """
        try:
            result = (self.prompt | self.llm).invoke({"query": state.upgraded_query})
        except Exception as e:
            state.final_response = "Sorry, I couldn't process your request at the moment."
            state.metadata["direct_llm_success"] = False
            state.metadata["direct_llm_error"] = str(e)
        else:
            state.final_response = result.content
            state.metadata["direct_llm_success"] = True

        return state
|
| 28 |
+
|
| 29 |
+
# Node function for LangGraph
|
| 30 |
+
def direct_llm_node(state: AgentState) -> AgentState:
    """LangGraph entry point for the direct-LLM path."""
    # A fresh node (and LLM client) is built per invocation.
    return DirectLLMNode().process_direct_llm(state)
|
src/agenticRAG/nodes/query_router.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Literal
|
| 2 |
+
from src.agenticRAG.models.state import AgentState
|
| 3 |
+
from src.agenticRAG.components.llm_factory import LLMFactory
|
| 4 |
+
from src.agenticRAG.prompt.prompts import Prompts
|
| 5 |
+
from src.config.settings import settings
|
| 6 |
+
|
| 7 |
+
class QueryRouter:
    """Node that decides which answer path (RAG/WEB/DIRECT) a query takes."""

    def __init__(self):
        self.llm = LLMFactory.get_llm()
        # BUG FIX: ``Prompts`` exposes the router prompt via the
        # ``query_router()`` classmethod (which injects the current
        # knowledge-base description); the former ``Prompts.QUERY_ROUTER``
        # attribute is no longer defined and raised AttributeError here.
        self.prompt = Prompts.query_router()

    def route_query(self, state: AgentState) -> AgentState:
        """Ask the LLM for a route and store a validated decision.

        Args:
            state: Workflow state; reads ``upgraded_query`` and writes
                ``route_decision`` plus routing metadata.

        Returns:
            The same state, with ``route_decision`` set to one of
            "RAG", "WEB", "DIRECT" (falling back to
            ``settings.DEFAULT_ROUTE`` on malformed output or errors).
        """

        chain = self.prompt | self.llm

        try:
            response = chain.invoke({"query": state.upgraded_query})
            route_decision = response.content.strip().upper()

            # Tolerate decorated answers such as "ROUTE: RAG" or "RAG."
            # by scanning for a known keyword before falling back.
            if route_decision not in ("RAG", "WEB", "DIRECT"):
                route_decision = next(
                    (kw for kw in ("RAG", "WEB", "DIRECT") if kw in route_decision),
                    settings.DEFAULT_ROUTE,
                )

            state.route_decision = route_decision
            state.metadata["routing_success"] = True

        except Exception as e:
            # Any LLM/runtime failure routes to the safe default path.
            state.route_decision = settings.DEFAULT_ROUTE
            state.metadata["routing_success"] = False
            state.metadata["routing_error"] = str(e)

        return state
|
| 36 |
+
|
| 37 |
+
# Node function for LangGraph
|
| 38 |
+
def query_router_node(state: AgentState) -> AgentState:
    """LangGraph entry point for the routing step."""
    # Builds a fresh router (and LLM client) for every invocation.
    return QueryRouter().route_query(state)
|
src/agenticRAG/nodes/query_upgrader.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from src.agenticRAG.models.state import AgentState
|
| 2 |
+
from src.agenticRAG.components.llm_factory import LLMFactory
|
| 3 |
+
from src.agenticRAG.prompt.prompts import Prompts
|
| 4 |
+
from src.config.settings import settings
|
| 5 |
+
|
| 6 |
+
class QueryUpgrader:
    """Node that rewrites the raw user query into a retrieval-friendly form."""

    def __init__(self):
        self.llm = LLMFactory.get_llm()
        self.prompt = Prompts.QUERY_UPGRADER

    def upgrade_query(self, state: AgentState) -> AgentState:
        """Enhance the query, keeping the original on any problem.

        The original query is retained whenever the LLM returns an
        empty string, an over-long rewrite (beyond
        ``settings.MAX_QUERY_LENGTH``), or raises.
        """

        chain = self.prompt | self.llm

        try:
            candidate = chain.invoke({"query": state.user_query}).content.strip()

            acceptable = bool(candidate) and len(candidate) <= settings.MAX_QUERY_LENGTH
            state.upgraded_query = candidate if acceptable else state.user_query
            state.metadata["upgrade_success"] = True

        except Exception as e:
            state.upgraded_query = state.user_query
            state.metadata["upgrade_success"] = False
            state.metadata["upgrade_error"] = str(e)

        return state
|
| 35 |
+
|
| 36 |
+
# Node function for LangGraph
|
| 37 |
+
def query_upgrader_node(state: AgentState) -> AgentState:
    """LangGraph entry point for the query-upgrade step."""
    # Builds a fresh upgrader (and LLM client) for every invocation.
    return QueryUpgrader().upgrade_query(state)
|
src/agenticRAG/nodes/rag_node.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from src.agenticRAG.models.state import AgentState
|
| 2 |
+
from src.agenticRAG.components.llm_factory import LLMFactory
|
| 3 |
+
from src.agenticRAG.components.vectorstore import VectorStoreManager
|
| 4 |
+
from src.agenticRAG.prompt.prompts import Prompts
|
| 5 |
+
|
| 6 |
+
class RAGNode:
    """Node that answers from the local vector store (RAG path)."""

    def __init__(self):
        self.llm = LLMFactory.get_llm()
        self.vectorstore_manager = VectorStoreManager()
        self.prompt = Prompts.RAG_RESPONSE

        # The store must be loaded before any retrieval can happen.
        self.vectorstore_manager.load_vectorstore()

    def process_rag(self, state: AgentState) -> AgentState:
        """Retrieve supporting documents and answer from them.

        Errors degrade to a canned apology plus metadata describing the
        failure, so the graph run itself never aborts on this node.
        """
        try:
            # Top-3 nearest documents for the upgraded query.
            docs = self.vectorstore_manager.search_documents(state.upgraded_query, k=3)
            state.retrieved_docs = docs

            context = "\n".join(docs) if docs else "No relevant documents found."
            answer = (self.prompt | self.llm).invoke({
                "query": state.upgraded_query,
                "context": context
            })

            state.final_response = answer.content
            state.metadata["rag_success"] = True

        except Exception as e:
            state.final_response = "Sorry, I couldn't retrieve information from the knowledge base."
            state.metadata["rag_success"] = False
            state.metadata["rag_error"] = str(e)

        return state
|
| 43 |
+
|
| 44 |
+
# Node function for LangGraph
|
| 45 |
+
def rag_node(state: AgentState) -> AgentState:
    """LangGraph entry point for the RAG path."""
    # Builds a fresh node (reloading the vector store) per invocation.
    return RAGNode().process_rag(state)
|
src/agenticRAG/nodes/web_search_node.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from src.agenticRAG.models.state import AgentState
|
| 2 |
+
from src.agenticRAG.components.llm_factory import LLMFactory
|
| 3 |
+
from src.agenticRAG.components.search_tools import SearchToolFactory
|
| 4 |
+
from src.agenticRAG.prompt.prompts import Prompts
|
| 5 |
+
|
| 6 |
+
class WebSearchNode:
    """Node that answers from live web search results."""

    def __init__(self):
        self.llm = LLMFactory.get_llm()
        self.search_tool = SearchToolFactory.get_search_tool()
        self.prompt = Prompts.WEB_RESPONSE

    def process_web_search(self, state: AgentState) -> AgentState:
        """Search the web and synthesize an answer from the results.

        Errors degrade to a canned apology plus metadata describing the
        failure, so the graph run itself never aborts on this node.
        """
        try:
            raw_results = self.search_tool.run(state.upgraded_query)
            # The tool returns a single blob; store it as a one-item list
            # to match the state schema.
            state.search_results = [raw_results]

            answer = (self.prompt | self.llm).invoke({
                "query": state.upgraded_query,
                "search_results": raw_results
            })

            state.final_response = answer.content
            state.metadata["web_search_success"] = True

        except Exception as e:
            state.final_response = "Sorry, I couldn't perform web search at the moment."
            state.metadata["web_search_success"] = False
            state.metadata["web_search_error"] = str(e)

        return state
|
| 39 |
+
|
| 40 |
+
# Node function for LangGraph
|
| 41 |
+
def web_search_node(state: AgentState) -> AgentState:
    """LangGraph entry point for the web-search path."""
    # Builds a fresh node (LLM client and search tool) per invocation.
    return WebSearchNode().process_web_search(state)
|
src/agenticRAG/prompt/__init__.py
ADDED
|
File without changes
|
src/agenticRAG/prompt/prompts.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# from langchain_core.prompts import ChatPromptTemplate
|
| 2 |
+
# import json
|
| 3 |
+
|
| 4 |
+
# class Prompts:
|
| 5 |
+
# """Centralized prompt templates"""
|
| 6 |
+
|
| 7 |
+
# QUERY_UPGRADER = ChatPromptTemplate.from_messages([
|
| 8 |
+
# ("system", """You are a query enhancement specialist. Your task is to improve user queries for better information retrieval.
|
| 9 |
+
|
| 10 |
+
# Enhancement guidelines:
|
| 11 |
+
# 1. Add relevant keywords and synonyms
|
| 12 |
+
# 2. Clarify ambiguous terms
|
| 13 |
+
# 3. Expand abbreviations and acronyms
|
| 14 |
+
# 4. Add context when missing
|
| 15 |
+
# 5. Maintain original intent
|
| 16 |
+
# 6. Keep enhanced query concise (under 200 characters)
|
| 17 |
+
|
| 18 |
+
# Return only the enhanced query, nothing else."""),
|
| 19 |
+
# ("human", "Original query: {query}")
|
| 20 |
+
# ])
|
| 21 |
+
|
| 22 |
+
# QUERY_ROUTER = ChatPromptTemplate.from_messages([
|
| 23 |
+
# ("system", """You are a query router. Analyze the query and decide which path to take:
|
| 24 |
+
|
| 25 |
+
# PATHS:
|
| 26 |
+
# 1. "RAG" - For queries about specific knowledge base content, documents, or domain expertise
|
| 27 |
+
# 2. "WEB" - For current events, real-time information, recent news, or trending topics
|
| 28 |
+
# 3. "DIRECT" - For general conversation, creative tasks, opinions, or reasoning without specific facts
|
| 29 |
+
|
| 30 |
+
# DECISION CRITERIA:
|
| 31 |
+
# - RAG: Domain-specific questions, technical documentation, specific facts from knowledge base
|
| 32 |
+
# - WEB: Questions with temporal keywords (latest, current, recent, today), current events, real-time data
|
| 33 |
+
# - DIRECT: General chat, creative writing, opinions, mathematical reasoning, casual conversation
|
| 34 |
+
|
| 35 |
+
# Respond with only one word: RAG, WEB, or DIRECT"""),
|
| 36 |
+
# ("human", "Query: {query}")
|
| 37 |
+
# ])
|
| 38 |
+
|
| 39 |
+
# RAG_RESPONSE = ChatPromptTemplate.from_messages([
|
| 40 |
+
# ("system", """You are a helpful assistant. Answer the user's question based on the provided context from the knowledge base.
|
| 41 |
+
|
| 42 |
+
# Context: {context}
|
| 43 |
+
|
| 44 |
+
# If the context doesn't contain relevant information, say so clearly."""),
|
| 45 |
+
# ("human", "{query}")
|
| 46 |
+
# ])
|
| 47 |
+
|
| 48 |
+
# WEB_RESPONSE = ChatPromptTemplate.from_messages([
|
| 49 |
+
# ("system", """You are a helpful assistant. Answer the user's question based on the provided web search results.
|
| 50 |
+
|
| 51 |
+
# Search Results: {search_results}
|
| 52 |
+
|
| 53 |
+
# Provide a comprehensive answer based on the search results. If the results don't contain relevant information, say so clearly."""),
|
| 54 |
+
# ("human", "{query}")
|
| 55 |
+
# ])
|
| 56 |
+
|
| 57 |
+
# DIRECT_RESPONSE = ChatPromptTemplate.from_messages([
|
| 58 |
+
# ("system", """You are a helpful AI assistant. Answer the user's question directly using your knowledge and reasoning capabilities.
|
| 59 |
+
|
| 60 |
+
# Be conversational, accurate, and helpful. If you're unsure about something, acknowledge the uncertainty."""),
|
| 61 |
+
# ("human", "{query}")
|
| 62 |
+
# ])
|
| 63 |
+
|
| 64 |
+
# def load_data_relative():
|
| 65 |
+
# """Load data.json using relative path"""
|
| 66 |
+
# try:
|
| 67 |
+
# with open("knowledge_base_metadata.json", 'r', encoding='utf-8') as f:
|
| 68 |
+
# data = json.load(f)
|
| 69 |
+
# description = ""
|
| 70 |
+
# for key in data:
|
| 71 |
+
# description +=f"{key['description']}\n"
|
| 72 |
+
# return description
|
| 73 |
+
# except FileNotFoundError:
|
| 74 |
+
# print("data.json not found in current directory")
|
| 75 |
+
# return None
|
| 76 |
+
# except json.JSONDecodeError as e:
|
| 77 |
+
# print(f"Error decoding JSON: {e}")
|
| 78 |
+
# return None
|
| 79 |
+
|
| 80 |
+
# if __name__=="__main__":
|
| 81 |
+
# print(load_data_relative())
|
| 82 |
+
|
| 83 |
+
from langchain_core.prompts import ChatPromptTemplate
|
| 84 |
+
import json
|
| 85 |
+
|
| 86 |
+
class Prompts:
    """Centralized prompt templates.

    Static templates are class attributes; the router prompt can also be
    built dynamically via :meth:`query_router`, which injects a live
    description of the knowledge base.

    FIXES: restores the full QUERY_UPGRADER / WEB_RESPONSE /
    DIRECT_RESPONSE texts (previously truncated with "..."), and
    reinstates the ``QUERY_ROUTER`` class attribute that other modules
    reference (e.g. ``QueryRouter`` uses ``Prompts.QUERY_ROUTER``).
    """

    @staticmethod
    def load_kb_description():
        """Load knowledge-base descriptions from metadata, or a fallback note."""
        try:
            with open("knowledge_base_metadata.json", 'r', encoding='utf-8') as f:
                data = json.load(f)
            description = ""
            for item in data:
                description += f"- {item.get('description', '').strip()}\n"
            return description.strip()
        except FileNotFoundError:
            return "No knowledge base found."
        except json.JSONDecodeError as e:
            return f"Error decoding knowledge base: {e}"

    @classmethod
    def query_router(cls):
        """Return the router prompt with the current KB description injected."""
        kb_description = cls.load_kb_description()
        return ChatPromptTemplate.from_messages([
            ("system", f"""You are a query router. Analyze the query and decide which path to take:

PATHS:
1. "RAG" - For queries about specific knowledge base content, documents, or domain expertise
2. "WEB" - For current events, real-time information, recent news, or trending topics
3. "DIRECT" - For general conversation, creative tasks, opinions, or reasoning without specific facts

DECISION CRITERIA:
- RAG: Domain-specific questions, technical documentation, specific facts from knowledge base
- WEB: Questions with temporal keywords (latest, current, recent, today), current events, real-time data
- DIRECT: General chat, creative writing, opinions, mathematical reasoning, casual conversation

Knowledge Base contains:
{kb_description}

Respond with only one word: RAG, WEB, or DIRECT"""),
            ("human", "Query: {{query}}")
        ])

    # Static router prompt kept for callers that reference the attribute
    # directly; it omits the dynamic KB description that query_router()
    # injects but is otherwise equivalent.
    QUERY_ROUTER = ChatPromptTemplate.from_messages([
        ("system", """You are a query router. Analyze the query and decide which path to take:

PATHS:
1. "RAG" - For queries about specific knowledge base content, documents, or domain expertise
2. "WEB" - For current events, real-time information, recent news, or trending topics
3. "DIRECT" - For general conversation, creative tasks, opinions, or reasoning without specific facts

DECISION CRITERIA:
- RAG: Domain-specific questions, technical documentation, specific facts from knowledge base
- WEB: Questions with temporal keywords (latest, current, recent, today), current events, real-time data
- DIRECT: General chat, creative writing, opinions, mathematical reasoning, casual conversation

Respond with only one word: RAG, WEB, or DIRECT"""),
        ("human", "Query: {query}")
    ])

    QUERY_UPGRADER = ChatPromptTemplate.from_messages([
        ("system", """You are a query enhancement specialist. Your task is to improve user queries for better information retrieval.

Enhancement guidelines:
1. Add relevant keywords and synonyms
2. Clarify ambiguous terms
3. Expand abbreviations and acronyms
4. Add context when missing
5. Maintain original intent
6. Keep enhanced query concise (under 200 characters)

Return only the enhanced query, nothing else."""),
        ("human", "Original query: {query}")
    ])

    RAG_RESPONSE = ChatPromptTemplate.from_messages([
        ("system", """You are a helpful assistant. Answer the user's question based on the provided context from the knowledge base.

Context: {context}

If the context doesn't contain relevant information, say so clearly."""),
        ("human", "{query}")
    ])

    WEB_RESPONSE = ChatPromptTemplate.from_messages([
        ("system", """You are a helpful assistant. Answer the user's question based on the provided web search results.

Search Results: {search_results}

Provide a comprehensive answer based on the search results. If the results don't contain relevant information, say so clearly."""),
        ("human", "{query}")
    ])

    DIRECT_RESPONSE = ChatPromptTemplate.from_messages([
        ("system", """You are a helpful AI assistant. Answer the user's question directly using your knowledge and reasoning capabilities.

Be conversational, accurate, and helpful. If you're unsure about something, acknowledge the uncertainty."""),
        ("human", "{query}")
    ])
| 156 |
+
|
| 157 |
+
if __name__ == "__main__":
    # Smoke check: render the dynamic router prompt for a sample query.
    rendered = Prompts.query_router().format(
        query="What is the architecture of the Omani AI system?"
    )
    print(rendered)
|
src/config/__init__.py
ADDED
|
File without changes
|
src/config/settings.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import Dict, Any
|
| 3 |
+
from dotenv import load_dotenv
|
| 4 |
+
load_dotenv()
|
| 5 |
+
|
| 6 |
+
class Settings:
    """Configuration settings for the AgenticRAG system.

    Values are read once from the environment at import time; a ``.env``
    file is loaded by the module-level ``load_dotenv()`` call above.
    """

    # API Keys
    GROQ_API_KEY: str = os.getenv("GROQ_API_KEY", "")
    GOOGLE_API_KEY: str = os.getenv("GOOGLE_API_KEY", "")
    GOOGLE_CSE_ID: str = os.getenv("GOOGLE_CSE_ID", "")

    # Model Configuration
    GROQ_MODEL: str = "llama3-8b-8192"
    GROQ_TEMPERATURE: float = 0.1

    OPENAI_MODEL: str = "gpt-4.1-nano-2025-04-14"
    OPENAI_TEMPERATURE: float = 0.3
    OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "")

    # Embedding Models
    EMBEDDING_MODEL: str = "sentence-transformers/all-MiniLM-L6-v2"
    OPENAI_EMBEDDING_MODEL: str = "text-embedding-3-large"
    # Vector Store
    VECTORSTORE_PATH: str = "data/vectorstore"

    # Search Configuration
    SEARCH_RESULTS_COUNT: int = 5

    SERPER_API_KEY: str = os.getenv("SERPER_API_KEY", "")
    TAVILY_API_KEY: str = os.getenv("TAVILY_API_KEY", "")

    # Query Enhancement
    MAX_QUERY_LENGTH: int = 200

    # Routing Configuration
    DEFAULT_ROUTE: str = "DIRECT"

    @classmethod
    def validate(cls) -> bool:
        """Validate required settings.

        Returns:
            True when every required key is non-empty.

        Raises:
            ValueError: if a required setting is missing or empty.
        """
        required_keys = ["GROQ_API_KEY"]
        for key in required_keys:
            if not getattr(cls, key):
                raise ValueError(f"Missing required setting: {key}")
        return True

# Shared singleton imported across the package.
settings = Settings()
|