diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000000000000000000000000000000000000..a2a54cfecb0bbe48713239ba8c861576b937c160 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,17 @@ +.git +.gitignore +.env +__pycache__ +*.pyc +*.pyo +venv/ +logs/ +data/database/backups/ +tests/ +docs/ +README.md +README_UPDATES.md +*.md +.github/ +.pytest_cache/ +htmlcov/ diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..c3a1619cc02885420b7abd58300ccaea32e4755d --- /dev/null +++ b/.gitignore @@ -0,0 +1,69 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# Virtual environment +.env +.venv/ +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Environment variables +.env + +# VS Code settings +.vscode/ + +# MacOS system files +.DS_Store + +# Jupyter Notebook checkpoints +.ipynb_checkpoints/ + +# Logs +*.log + +# Cache and temp files +*.tmp +*.swp +*.bak +.cache/ +*.sqlite3 +*.db + +# Data files +*.pdf +*.json +*.jsonl + +# Output folders +dist/ +build/ +*.egg-info/ + +# Output data +data/ + +# Pyright config may differ from one platform to another +pyrightconfig.json + +# Pycharm +.idea/ + +# OS junk +.Trashes.env +.env +.env + +#idk +--source-branch +--source-repo +/.gradio/certificate.pem + +#feedback I just uploaded into the same file to check for accuracy +chatbot emba x.docx +IEBMA Test Cards 1_2.docx diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..0c9413d1f3579e2d88bff4f724ff5d3eabd15221 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,37 @@ +# ============================== Initial Building ============================= +FROM python:3.11.14-slim-bookworm AS builder + +WORKDIR /app + +# CPU-only PyTorch +RUN pip install --no-cache-dir torch torchvision torchaudio \ + --index-url https://download.pytorch.org/whl/cpu + +# Python dependencies +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt +# ============================== Size Reduction =============================== +FROM python:3.11.14-slim-bookworm + +WORKDIR /app + +# Only necessary dependencies from builder +COPY --from=builder /usr/local/lib/python3.11/site-packages /usr/local/lib/python3.11/site-packages +COPY --from=builder /usr/local/bin /usr/local/bin + +# System dependencies for runtime +RUN apt-get update && apt-get install -y --no-install-recommends \ + libmagic1 \ + poppler-utils \ + curl \ + && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* + +# ============================ Final Compilation ============================== +COPY . . + +EXPOSE 7860 + +HEALTHCHECK --interval=60s --timeout=10s --retries=3 \ + CMD curl -f http://localhost:7860/health || exit 1 + +CMD ["python", "main.py", "--app", "de"] diff --git a/config.py b/config.py new file mode 100644 index 0000000000000000000000000000000000000000..29962c3127075afe473916005881dcb6bd472792 --- /dev/null +++ b/config.py @@ -0,0 +1,186 @@ +""" +Configuration settings for the Executive Education RAG Chatbot. +PLEASE CONSIDER READING THE 'docs/configuration_system_documentation.md' TO PROPERLY USE THE NEW CONFIGURATION SYSTEM. +""" +# ========================================= General Configuration =========================================== + +# A list of ISO 639 language codes. Defines a list of languages in which +# the application can operate. Defaults to ['en', 'de']. +AVAILABLE_LANGUAGES = ['en', 'de'] + +# A string representing a path (relative to the project root or absolute) to the directory +# where the data output files such as scraping or document processing outputs will be stored. +DATA_PATH = 'data' + +# A string representing a path (relative to the project root or absolute) to the directory +# where the loging files will be stored. +LOGS_PATH = 'logs' + +# =================================== Conversation State Configuration ====================================== + +# A boolean; either True or False. Enables the collection of user preferences +# during conversation to avoid repetetive questions. Defaults to True. +TRACK_USER_PROFILE = True + +# An integer. Defines the amount of user messages after which the language +# of the conversation will be locked. If set to 0, the language will not be locked. +LOCK_LANGUAGE_AFTER_N_MESSAGES = 3 + +# An integer. Sets the maximum amount of conversation turns as the sum of user queries +# and agent responses. The conversation ends after the maximum turns amount is reached. +MAX_CONVERSATION_TURNS = 20 + +# ============================================ LLM Configuration ============================================ + +# A string, either 'openai', 'groq', 'open_router' or 'ollama' (local). +# Defines the main model provider for the application. +LLM_PROVIDER = 'openai' + +# A string. Defines the model that will be used by the application agents. +OPENAI_MODEL = 'gpt-5.1' +# GROQ_MODEL = +# OLLAMA_MODEL = +# OPEN_ROUTER_MODEL = + +# ==================================== Weaviate Database Configuration ====================================== + +# A boolean; either True or False. +# Defines whether the database is set as a local instance (via Docker container), +# or as a cloud service. More information on https://docs.weaviate.io/weaviate. +WEAVIATE_IS_LOCAL = False + +# A string. Defines the name of the colletions stored in the database. +# For each available language a new collection will be created +# with set name _. +WEAVIATE_COLLECTION_BASENAME = 'hsg_rag_content' + +# A string; either 'manual', 'filesystem' (local instance), 's3' (AWS). +# Defines the service for storing the database backups. +# More information on https://docs.weaviate.io/deploy/configuration/backups. +WEAVIATE_BACKUP_METHOD = 'manual' + +# A string representing a path in the system where backups will be stored +# only if WEAVIATE_BACKUP_METHOD is set to 'manual'. +BACKUPS_PATH = 'data/database/backups' + +# A string representing a system path where collection properties will be stored. +PROPERTIES_PATH = 'data/database' + +# A string representing a system path where property strategies will be stored. +# More information on property strategies in the documentation. +STRATEGIES_PATH = 'data/database/strategies' + +# An integer. Defines a connection timeout to the cloud weaviate service (in seconds). +# Defaults to 90. +WEAVIATE_INIT_TIMEOUT = 90 + +# An integer. Defines the query response time limit upon querying the database (in seconds). +# Defaults to 60. +WEAVIATE_QUERY_TIMEOUT = 60 + +# An integer. Defines the chunk insertion time limit when importing new chunks to database (in seconds). +# Defaults to 600 +WEAVIATE_INSERT_TIMEOUT = 600 + +# ========================================== Cache Configuration ============================================ + +# A string; either 'local', 'cloud' (Redis) or 'dict'. Defaults to 'cloud'. +# Sets the default cache mode. More information on cache modes in documentation. +CACHE_MODE = 'cloud' + +# An integer. Sets the reset time (time to live) in seconds for the cache storage. +# The cache storage will be cleared upon reset time exceedance. +# Defaults to 86400 seconds (24 hours). +CACHE_TTL = 86400 + +# An integer. Maximum amount of cached messages that will be held in the cache storage. +# Defaults to 1000. +CACHE_MAX_SIZE = 1000 + +# A string. Defines the IP adress to access the local cache storage. Defaults to 'localhost'. +CACHE_LOCAL_HOST = 'localhost' + +# An integer. Defines the port for accessing the local cache storage. Defaults to 6379. +CACHE_LOCAL_PORT = 6379 + +# ===================================== Data Processing Configuration ======================================= + +# A string representing the name of an embeding model for embedding generation. +# The parameter MAX_TOKENS must match this model's maximum token amount. +EMBEDDING_MODEL = 'sentence-transformers/multi-qa-mpnet-base-dot-v1' + +# A float in range from 0 to 1. Sets the threshold for english language in the language detector. +# If the language detection certanty is lower than the threshold, the English language will be returned. +LANG_AMBIGUITY_THRESHOLD = 0.6 + +# An integer. Defines the maximum amount of tokens pro single chunk. +MAX_TOKENS = 512 + +# An integer. Defines the amount of overlapping tokens between chunks to keep the context. +CHUNK_OVERLAP = 100 + + +# An integer representing seconds. Defines the maximum waiting time for the target server +# responses during the scraping procedures. +SCRAPING_TIMEOUT = 30 + +# An integer. Defines the maximum amount of additional tries that will be performed +# if the initial request to the server failed. +SCRAPING_MAX_RETRIES = 3 + +# An integer representing seconds. Defines the waiting interval between two server calls. +# This value might be overwritten by the delay set by the server. +SCRAPING_CRAWL_DELAY = 1 + +# An integer. Defines the backoff base value for retries with exponential backoff. +# The higher is the number, the longer is the waiting interval between subsequent retries going to be. +SCRAPING_BACKOFF_RATE = 1.25 + +# A list of string URLs. Defines the starting points for the website scraping. +SCRAPING_TARGET_URLS = [ + # 'https://emba.unisg.ch/', # EMBA HSG root + 'https://embax.ch/', # emba X root +] + +# Scraping Priority Interval in days +SCRAPING_PRIO_INTERVAL = { + "high": 1, + "medium": 7, + "low": 30 +} + +# ======================================== Agent Chain Configuration ======================================== + +# A boolean; either True or False. Activates the response quality evaluation procedure +# for agentic responses. Defaults to True. +ENABLE_EVALUATE_RESPONSE_QUALITY = True + +# A float in range from 0 to 1. Sets the treshold value for the quality evaluation. +# The fallback mechanism will be activated if the quality of the agentic response +# is lower than the confidence threshold. +CONFIDENCE_THRESHOLD = 0.6 + +# An integer. Defines the amount of chunks that should be retrieved from the database +# upon querying by subagents during conversation. Defaults to 4. +TOP_K_RETRIEVAL = 4 + +# An integer. Sets the amount of model invocation retries after which the fallback model +# will be invoked. Defaults to 3. +MODEL_MAX_RETRIES = 3 + +# An integer. Sets the maximum amount of words in the response from the lead agent. +MAX_RESPONSE_WORDS_LEAD = 100 + +# An integer. Sets the maximum amount of words in the response for subagents. +MAX_RESPONSE_WORDS_SUBAGENT = 200 + +# A boolean; either True or False. If response chunking is enabled, long responses +# from the lead agent will be split and retuned through multiple conversation turns. +ENABLE_RESPONSE_CHUNKING = True + +# ========================================== Notification Configuration ===================================== + +NOTIFY_ENABLE_EMAIL_ALERTS= True +NOTIFY_ENABLE_SLACK_ALERTS = True + +# =========================================================================================================== diff --git a/main.py b/main.py new file mode 100644 index 0000000000000000000000000000000000000000..c05557c90073d39cd41cc8d180195a047a0b1f11 --- /dev/null +++ b/main.py @@ -0,0 +1,163 @@ +""" +Main entry point for the Executive Education RAG Chatbot. +""" +import argparse +import langsmith +from langsmith import traceable +from src.utils.logging import init_logging, get_logger +from config import AVAILABLE_LANGUAGES +from src.cache.cache import Cache +from src.config import config + + +# Initialize logging +def logging_startup(): + init_logging() + return get_logger('main_module') + + +def run_scraper() -> None: + """ + Run the scraper to collect program data. + + Args: + use_selenium: Whether to use Selenium for scraping. + """ + from src.pipeline.pipeline import ImportPipeline + logger = logging_startup() + + logger.info("Running scraper...") + ImportPipeline().scrape_website() + logger.info("Scraping completed.") + + +def run_importer(sources: list[str]) -> None: + """Run the data import pipeline.""" + from src.pipeline.pipeline import ImportPipeline + logger = logging_startup() + + logger.info("Running data import pipeline...") + ImportPipeline().import_many_documents(sources) + logger.info("Data processing completed.") + + +def run_weaviate_command(command: str, backup_id: str = None): + """Run commands to manipulate the database contents.""" + from src.database.weavservice import WeaviateService + logger = logging_startup() + + logger.info(f"Running database command {command}") + if command == 'restore' and not backup_id: + logger.error("Backup ID is required to initalize the restore process.") + + service = WeaviateService() + if command == 'backup': + service._create_backup() + + if command == 'restore': + service._restore_backup(backup_id) + + if command == 'delete' or command == 'redo': + service._delete_collections() + + if command == 'init' or command == 'redo': + service._create_collections() + + if command == 'checkhealth' or command == 'init' or command == 'redo': + service._checkhealth() + + +def clear_cache(): + cache = Cache.get_cache() + if cache: + cache.clear_cache() + + +def run_application(lang: str, cache_mode, cache) -> None: + """Run the chatbot web application.""" + from src.apps.chat.app import ChatbotApplication + logger = logging_startup() + + Cache.configure(cache_mode, cache) + + logger.info("Starting chatbot web application...") + app = ChatbotApplication(language=lang) + app.run() + + +def run_dbapp() -> None: + """Run the database application.""" + from src.apps.dbapp.app import DatabaseApplication + logger = logging_startup() + logger.info("Starting database application...") + app = DatabaseApplication() + app.run() + + +def parse_args(): + """Parse command-line arguments.""" + parser = argparse.ArgumentParser(description="University of St. Gallen Executive Education RAG Chatbot") + + # Add arguments + parser.add_argument("--scrape", action="store_true", + help="Scrapes the data from the HSG website and imports it into the database") + parser.add_argument("--imports", nargs="+", help="Runs the data importing pipeline for the provided files") + + parser.add_argument("--weaviate", type=str, choices=['init', 'delete', 'redo', 'checkhealth', 'backup', 'restore'], + help="Runs different database actions") + parser.add_argument("--backup-id", type=str, help="Required when calling the --weaviate restore command!") + + parser.add_argument("--cache-mode", type=str, choices=['local', 'cloud', 'dict'], default=config.cache.CACHE_MODE, + help="Defines whether to use the local or cloud Redis database or the special python dict as cache") + + parser.add_argument("--cache", action="store_false", help="(De-)activates the caching mechanism") + + parser.add_argument("--clear-cache", action="store_true", + help="Clears the cache") + + parser.add_argument("--cli", action="store_true", help="Run the chatbot CLI") + parser.add_argument("--app", type=str, choices=AVAILABLE_LANGUAGES, help="Run the chatbot web application") + parser.add_argument("--dbapp", action="store_true", help="Run the database management application") + + return parser.parse_args() + + +def main(): + """Main entry point for the application.""" + args = parse_args() + + # Load cache settings with the cache args + must_clear_cache = False + + # Check if any argument is provided + if not any([args.scrape, args.imports, args.weaviate, args.cli, args.cache, args.app, args.dbapp]): + # If no argument is provided, run the chatbot by default + run_application(cache_mode=args.cache_mode, cache=args.cache) + return + + # Run the specified components + if args.scrape: + must_clear_cache = True + run_scraper() + + if args.imports: + must_clear_cache = True + run_importer(args.imports) + + if args.weaviate: + if args.weaviate in ["init", "redo", "restore"]: + must_clear_cache = True + run_weaviate_command(command=args.weaviate, backup_id=args.backup_id) + + if args.clear_cache or must_clear_cache: + clear_cache() + + if args.app: + run_application(lang=args.app, cache_mode=args.cache_mode, cache=args.cache) + + if args.dbapp: + run_dbapp() + + +if __name__ == "__main__": + main() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..62e9eb2b318d0f0a8ec6e63459bec1765ed63aa6 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,39 @@ +# Core dependencies +langchain>=1.0.2 +langchain-core>=1.0.1 +langchain-deepseek>=1.0.0 +langchain-groq>=1.0.0 +langchain-ollama>=0.3.10 +langchain-openai>=1.0.1 +langsmith>=0.4.0 + +requests>=2.31.0 +openai>=1.3.0 +python-dotenv>=1.0.0 +colorama>=0.4.6 + +# Language detection +langdetect>=1.0.9 + +# Transformers for tokenization +transformers>=4.34.0 + +# Web applications +gradio>=5.49.1 + +# Processing pipeline +docling>=2.55.0 +ultimate-sitemap-parser>=1.8.0 +beautifulsoup4>=4.14.3 +fake-useragent>=1.5.1 + +# Weaviate Vector DB +weaviate-client>=4.16.9 +PyYAML>=6.0 + +# Cache +cachetools>=5.0.0 +redis>=4.5.5 + +# Scheduling +apscheduler diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/apps/__init__.py b/src/apps/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/apps/chat/__init__.py b/src/apps/chat/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/apps/chat/app.py b/src/apps/chat/app.py new file mode 100644 index 0000000000000000000000000000000000000000..d061033bbb84e5fc93a1afe94ffa15dabd63a346 --- /dev/null +++ b/src/apps/chat/app.py @@ -0,0 +1,324 @@ +import uuid +import gradio as gr +from fastapi import FastAPI +from datetime import datetime + + +from src.const.agent_response_constants import * +from src.const.data_consent_constants import * +from src.rag.agent_chain import ExecutiveAgentChain +from src.utils.logging import get_logger, ConsentLogger + +logger = get_logger("chatbot_app") + +def init_fastapi_app(language): + fastapi_app = FastAPI() + + @fastapi_app.get('/health') + def healthcheck(): + from src.database.weavservice import WeaviateService + from fastapi.responses import JSONResponse + + status = 200 + message = { 'timestamp': datetime.now().isoformat() } + try: + message |= { + 'status': 'ok', + 'weaviate': True, + } + response = WeaviateService().ping(language) + if response['status'] != 'OK': + status = 503 + message |= { + 'status': 'degraded', + 'weaviate': False, + 'error': str(response['error']), + } + except Exception as e: + status = 503 + message |= { + 'status': 'down', + 'weaviate': False, + 'error': str(e), + } + + return JSONResponse( + status_code = status, + content = message, + ) + + return fastapi_app + + +class ChatbotApplication: + def __init__(self, language: str = "de") -> None: + self._fastapi_app = init_fastapi_app(language) + self._gradio_app = gr.Blocks() + self._app = gr.mount_gradio_app(self._fastapi_app, self._gradio_app, path='/') + self._language = language + self._consentLogger = ConsentLogger() + + with self._gradio_app: + agent_state = gr.State(None) + lang_state = gr.State(language) + consent_state = gr.State(False) + session_id_state = gr.State(str(uuid.uuid4())) # for consent logging later + + with gr.Row(): + lang_selector = gr.Radio( + choices=["Deutsch", "English"], + value="English" if language == "en" else "Deutsch", + label="Selected Language", + interactive=True, + ) + reset_button = gr.Button("Reset Conversation", visible=False) + + # ---- Consent Screen (Page 1) ---- + with gr.Column(visible=True) as consent_screen: + data_policy = gr.Markdown(PRIVACY_NOTICE[language]) + with gr.Row(): + decline_btn = gr.Button(DECLINE[language]) + accept_btn = gr.Button(ACCEPT[language]) + + decline_info = gr.Markdown("", visible=False) + + # ---- Chat Screen (Page 2) ---- + with gr.Column(visible=False) as chat_screen: + chat = gr.ChatInterface( + fn=lambda msg, history, agent: self._chat( + message=msg, history=history, agent=agent + ), + additional_inputs=[agent_state], + title="Executive Education Adviser", + ) + + with gr.Row(): + withdraw_button = gr.Button(WITHDRAW_TEXT[language], visible=False, variant="stop") + + def create_session_id() -> str: + return str(uuid.uuid4()) + + def initialize_agent(lang: str, session_id: str): + agent = ExecutiveAgentChain(language=lang, session_id=session_id) + greeting = agent.generate_greeting() + + disclaimer_html = get_disclaimer_widget(lang) + + full_content = f"{disclaimer_html}{greeting}" + + return agent, [{"role": "assistant", "content": full_content}] + + def label_to_lang_code(label: str) -> str: + return "en" if label == "English" else "de" + + # Language change: before consent => only update consent UI text. + # After consent: keep chat running (or optionally re-init agent on language change). + def on_language_change( + language_label: str, + consent_given: bool, + agent, + session_id: str, + ): + lang_code = label_to_lang_code(language_label) + + # Before consent: update consent screen text to selected language + if not consent_given: + return ( + lang_code, + gr.update(value=PRIVACY_NOTICE[lang_code]), + gr.update(value=DECLINE[lang_code]), + gr.update(value=ACCEPT[lang_code]), + gr.update(visible=False, value=""), + None, # agent_state stays None + None, # chat stays as it is + gr.update(value=WITHDRAW_TEXT[lang_code], visible=False), + ) + + # After consent + new_agent, greeting = initialize_agent(lang_code, session_id=session_id) + return ( + lang_code, + gr.update(value=PRIVACY_NOTICE[lang_code]), + gr.update(value=DECLINE[lang_code]), + gr.update(value=ACCEPT[lang_code]), + gr.update(visible=False, value=""), + new_agent, + greeting, + gr.update(value=WITHDRAW_TEXT[lang_code], visible=True), + ) + + def on_accept(lang: str, session_id: str): + agent, greeting = initialize_agent(lang, session_id=session_id) + self._consentLogger.log(session_id, "accepted", policy_version="1.0") + self._language = lang + return ( + gr.update(visible=False), # consent_screen hide + gr.update(visible=True), # chat_screen show + True, # consent_state + agent, # agent_state + greeting, # chat initial history + gr.update(visible=False, value=""), # decline_info hide + gr.update(visible=True), # show reset_button + gr.update(value=WITHDRAW_TEXT[lang], visible=True), + ) + + def on_decline(lang: str, session_id: str): + self._language = lang + self._consentLogger.log(session_id, "declined", policy_version="1.0") + return ( + gr.update(visible=True), # consent_screen stays + gr.update(visible=False), # chat_screen stays hidden + False, # consent_state + None, # agent_state + [], # chat history empty + gr.update(visible=True, value=DECLINE_MESSAGE[lang]), + ) + + def on_reset_chat(lang: str, session_id: str): + agent, greeting = initialize_agent(lang, session_id=session_id) + self._language = lang + return ( + agent, + greeting, + ) + + def on_withdraw(lang: str, agent, session_id: str): + self._consentLogger.log(session_id, "withdrawn", policy_version="1.0") + + # 1) wipe server-side + if agent is not None: + try: + agent.wipe_session_data() + logger.info("wipe_session_data executed") + except Exception as e: + logger.error(f"wipe_session_data failed: {e}", exc_info=True) + + # 2) lock chat again (back to consent screen) + new_session_id = create_session_id() + return ( + gr.update(visible=True), # consent_screen + gr.update(value=PRIVACY_NOTICE[lang]), # data_policy + gr.update(value=DECLINE[lang]), # decline_btn + gr.update(value=ACCEPT[lang]), # accept_btn + gr.update(visible=False), # chat_screen + gr.update(visible=True, value=WITHDRAW_CONFIRMATION_MESSAGE[lang]), # decline_info + False, # consent_state + None, # agent_state + [], # chat.chatbot_value (history) + gr.update(visible=False), # reset_button + gr.update(visible=False), # withdraw_button + new_session_id, # session_id_state + ) + + # Language switch updates consent UI if consent not given + lang_selector.change( + fn=on_language_change, + inputs=[lang_selector, consent_state, agent_state, session_id_state], + outputs=[lang_state, + data_policy, + decline_btn, + accept_btn, + decline_info, + agent_state, + chat.chatbot_value, + withdraw_button, + ], + queue=True, + ) + + # Accept/Decline data consent + accept_btn.click( + fn=on_accept, + inputs=[lang_state, session_id_state], + outputs=[ + consent_screen, + chat_screen, + consent_state, + agent_state, + chat.chatbot_value, + decline_info, + reset_button, + withdraw_button, + ], + queue=True, + ) + + decline_btn.click( + fn=on_decline, + inputs=[lang_state, session_id_state], + outputs=[consent_screen, chat_screen, consent_state, agent_state, chat.chatbot_value, decline_info], + queue=True, + ) + + # Reset + reset_button.click( + fn=on_reset_chat, + inputs=[lang_state, session_id_state], + outputs=[ + agent_state, + chat.chatbot_value, + ], + queue=True, + ) + + # Withdraw consent + withdraw_button.click( + fn=on_withdraw, + inputs=[lang_state, agent_state, session_id_state], + outputs=[ + consent_screen, + data_policy, + decline_btn, + accept_btn, + chat_screen, + decline_info, + consent_state, + agent_state, + chat.chatbot_value, + reset_button, + withdraw_button, + session_id_state, + ], + queue=True, + ) + + + @property + def app(self) -> gr.Blocks: + """Expose underlying Gradio Blocks for external runners (e.g., HF Spaces).""" + return self._app + + def _chat(self, message: str, history: list[dict], agent: ExecutiveAgentChain): + if agent is None: + logger.error("Agent not initialized") + return ["I apologize, but the chatbot is not properly initialized."] + + answers = [] + try: + logger.info(f"Processing user query: {message[:100]}...") + response = agent.query(message) + answers.append(response.response) + self._language = response.language + + if response.show_booking_widget: + html_code = get_booking_widget(language=self._language, programs=response.relevant_programs) + answers.append(gr.HTML(value=html_code)) + except Exception as e: + logger.error(f"Error processing query: {e}", exc_info=True) + error_message = ( + "I apologize, but I encountered an error processing your request. " + "Please try rephrasing your question or contact our admissions team for assistance." + ) + answers.append(error_message) + + return answers + + + def run(self): + import uvicorn + uvicorn.run( + self._app, + host='0.0.0.0', + port=7860, + log_config=None + ) diff --git a/src/apps/dbapp/app.py b/src/apps/dbapp/app.py new file mode 100644 index 0000000000000000000000000000000000000000..e44133cb2e7370785d39db550a2b32eab2f0753c --- /dev/null +++ b/src/apps/dbapp/app.py @@ -0,0 +1,44 @@ +from tkinter import * +from tkinter import ttk +from src.database.weavservice import WeaviateService + +from src.apps.dbapp.mainframe import MainFrame +from src.apps.dbapp.query import QueryFrame +from src.apps.dbapp.imports import ImportFrame +from src.apps.dbapp.backup import BackupsFrame +from src.apps.dbapp.collections import CollectionsFrame +from src.apps.dbapp.config import SchemaConfigurationFrame + +from src.utils.logging import get_logger + +logger = get_logger("db_inter ") + +class DatabaseApplication: + def __init__(self) -> None: + self._root = Tk() + self._service = WeaviateService() + + self._root.title("Database Interface") + self._root.geometry("810x500") + + notebook = ttk.Notebook(self._root) + notebook.pack(fill=BOTH, expand=True) + + main_frame = MainFrame(notebook, self._service).init() + import_frame = ImportFrame(notebook, self._service).init() + config_frame = SchemaConfigurationFrame(notebook, self._service).init() + collections_frame = CollectionsFrame(notebook, self._service).init() + query_frame = QueryFrame(notebook, self._service).init() + backups_frame = BackupsFrame(notebook, self._service).init() + + notebook.add(main_frame, text='Main') + notebook.add(import_frame, text='Import') + notebook.add(config_frame, text='Schemas') + notebook.add(collections_frame, text='Collections') + notebook.add(query_frame, text='Query') + notebook.add(backups_frame, text='Backups') + + logger.info("Application initialization finished") + + def run(self): + self._root.mainloop() diff --git a/src/apps/dbapp/backup.py b/src/apps/dbapp/backup.py new file mode 100644 index 0000000000000000000000000000000000000000..62a2d08ed23f991d4ac4bdc78d85b394c233dc0e --- /dev/null +++ b/src/apps/dbapp/backup.py @@ -0,0 +1,191 @@ +import os, shutil +from datetime import datetime + +from tkinter import * +from tkinter import ttk +from src.database.weavservice import WeaviateService +from src.apps.dbapp.framebase import CustomFrameBase +from src.apps.dbapp.utilclasses import BackupData +from src.config import config + +def _load_backup_files(): + backups = [] + os.makedirs(config.weaviate.BACKUP_PATH, exist_ok=True) + + for backup_id in os.listdir(config.weaviate.BACKUP_PATH): + backups.append(BackupData(backup_id)) + + return backups + +class BackupsFrame(CustomFrameBase): + def __init__(self, parent, service: WeaviateService): + super().__init__(parent, service) + self._backups = _load_backup_files() + + def init(self) -> ttk.Frame: + self._backups = _load_backup_files() + + main_frame = ttk.Frame(self._parent) + main_frame.pack(fill=BOTH, expand=True) + + tree_frame = ttk.Frame(main_frame) + tree_frame.pack(fill=BOTH, expand=True, padx=10, pady=10) + + label_frame = ttk.Frame(main_frame) + label_frame.pack(fill=X, expand=True, padx=10, pady=10) + + button_frame = ttk.Frame(main_frame) + button_frame.pack(fill=X, padx=10, pady=10) + + date_reverse_sort = True + columns = ('date', 'size') + + info_label = ttk.Label(label_frame, text="", padding=8) + + def _print_label(msg, backc, forc): + info_label.configure(text=msg, foreground=forc, background=backc) + info_label.update_idletasks() + + def print_failure(msg: str): + _print_label(msg, "#FFCDD2", "#B71C1C") + + def print_info(msg: str): + _print_label(msg, "#cdedff", "#1c31b7") + + def print_success(msg: str): + _print_label(msg, "#d7ffcd", "#4db71c") + + + tree = ttk.Treeview( + tree_frame, + columns=columns, + show='tree headings', + selectmode='browse', + ) + + def sort_by_date(): + nonlocal date_reverse_sort + + parents = tree.get_children("") + data = [] + + for p in parents: + value = tree.set(p, 'date') + try: + value = datetime.strptime(value, "%d.%m.%Y %H:%M:%S") + except Exception: + pass + data.append((value, p)) + + data.sort(reverse=date_reverse_sort) + date_reverse_sort = not date_reverse_sort + + for index, (_, p) in enumerate(data): + tree.move(p, "", index) + + tree.heading( + 'date', + text='Created at ' + ('▾' if date_reverse_sort else '▴'), + command=lambda: sort_by_date() + ) + + tree.heading('#0', text='Backup ID') + tree.heading('date', text='Created at ▾', command=lambda: sort_by_date()) + tree.heading('size', text='Embeddings amount') + + tree.column("#0", width=100) + tree.column("date", width=60) + tree.column("size", width=30) + + def insert_backup(backup): + nonlocal date_reverse_sort + bk = backup.to_treeformat() + parent = tree.insert('', 0 if not date_reverse_sort else END, + text=bk['id'], + values=bk['date'] + ) + for collection in bk['collections']: + tree.insert(parent, END, + text=collection['name'], + values=collection['size'], + ) + + for backup in self._backups: + insert_backup(backup) + sort_by_date() + + def create_backup(): + print_info(f"Creating new backup...") + backup_id = self._service._create_backup() + + backup = BackupData(backup_id) + self._backups.append(backup) + insert_backup(backup) + print_success(f"Successfully created new backup {backup._backup_id}!") + + def restore_backup(): + item_id = tree.selection()[0] + backup = tree.item(item_id) + + print_info(f"Restoring backup {backup['text']}...") + self._service._restore_backup('backup_' + backup['text']) + print_success(f"Successfully restored backup {backup['text']}!") + + def delete_backup(): + item_id = tree.selection()[0] + backup = tree.item(item_id) + + backup_path = os.path.join(config.weaviate.BACKUP_PATH, 'backup_' + backup['text']) + shutil.rmtree(backup_path, ignore_errors=True) + + tree.delete(item_id) + print_success(f"Deleted backup {backup['text']}.") + + + create_bkp_btn = ttk.Button( + button_frame, + text="Create Backup", + command=create_backup + ) + + restore_bkp_btn = ttk.Button( + button_frame, + text="Restore Backup", + command=restore_backup, + state=['disabled'] + ) + + delete_bkp_btn = ttk.Button( + button_frame, + text="Delete Backup", + command=delete_backup, + state=['disabled'] + ) + + def on_item_selection(event): + selected = tree.selection() + if not selected: + restore_bkp_btn.state(['disabled']) + delete_bkp_btn.state(['disabled']) + return + + item_id = selected[0] + is_parent = tree.parent(item_id) == '' + restore_bkp_btn.state(['!disabled' if is_parent else 'disabled']) + delete_bkp_btn.state(['!disabled' if is_parent else 'disabled']) + + tree.bind("<>", on_item_selection) + + scrollbar = ttk.Scrollbar(tree_frame, orient="vertical", command=tree.yview) + tree.configure(yscrollcommand=scrollbar.set) + + info_label.pack() + + tree.pack(side=LEFT, fill=BOTH, expand=True) + scrollbar.pack(side=RIGHT, fill=Y) + + create_bkp_btn.pack(side=LEFT, padx=5) + restore_bkp_btn.pack(side=RIGHT, padx=5) + delete_bkp_btn.pack(side=RIGHT, padx=5) + + return main_frame diff --git a/src/apps/dbapp/collections.py b/src/apps/dbapp/collections.py new file mode 100644 index 0000000000000000000000000000000000000000..6ae4a55bed2f00121c019be10b7006ae0b44f990 --- /dev/null +++ b/src/apps/dbapp/collections.py @@ -0,0 +1,8 @@ +from tkinter import * +from tkinter import ttk +from src.apps.dbapp.framebase import CustomFrameBase +from src.database.weavservice import WeaviateService + +class CollectionsFrame(CustomFrameBase): + def __init__(self, parent, service: WeaviateService) -> None: + super().__init__(parent, service) diff --git a/src/apps/dbapp/config.py b/src/apps/dbapp/config.py new file mode 100644 index 0000000000000000000000000000000000000000..759d93c56f3dd1ba36b71af0cfcc06644ae733d0 --- /dev/null +++ b/src/apps/dbapp/config.py @@ -0,0 +1,350 @@ +import os, json + +from tkinter import * +from tkinter import ttk +from src.apps.dbapp.framebase import CustomFrameBase +from src.utils.stratutils.generator import generate_strategy +from src.database.weavservice import WeaviateService +from src.config import config + +def _dump_schema(schema): + os.makedirs(config.weaviate.PROPERTIES_PATH, exist_ok=True) + properties_file_path = os.path.join(config.weaviate.PROPERTIES_PATH, 'properties.json') + with open(properties_file_path, 'w', encoding='utf-8') as f: + json.dump(schema, f, indent=2, default=str) + + +class SchemaConfigurationFrame(CustomFrameBase): + def __init__(self, parent, service: WeaviateService) -> None: + super().__init__(parent, service) + self._schema = self._load_schema_data() + self._strategies = self._load_strategies() + + + def _load_strategies(self) -> dict: + os.makedirs(config.weaviate.STRATEGIES_PATH, exist_ok=True) + loaded_strats = os.listdir(config.weaviate.STRATEGIES_PATH) + strategies = {} + + for name, prop in self._schema.items(): + strategy_file = f"strat_{name}.py" + file_path = os.path.join(config.weaviate.STRATEGIES_PATH, strategy_file) + strategy_content = "" + + if strategy_file not in loaded_strats: + strategy_content = generate_strategy(name, prop) + with open(file_path, 'w', encoding='utf-8') as f: + f.write(strategy_content) + else: + with open(file_path) as f: + strategy_content = f.read() + + strategies[name] = strategy_content + + return strategies + + + def _save_strategy(self, name, strategy) -> None: + os.makedirs(config.weaviate.STRATEGIES_PATH, exist_ok=True) + self._strategies[name] = strategy + + file_path = os.path.join(config.weaviate.STRATEGIES_PATH, f"strat_{name}.py") + with open(file_path, 'w', encoding='utf-8') as f: + f.write(strategy) + + + def _load_schema_data(self) -> dict: + schema_data = {} + + schema = self._service._extract_data()['schema'] + if not schema: + return schema_data + + for prop in schema[0]['properties']: + data_property = { + 'description': prop.get('description', ''), + 'data_type': prop['dataType'][0], + 'filterable': prop['indexFilterable'], + 'searchable': prop['indexSearchable'], + 'skip_vectorization': prop['moduleConfig']['text2vec-huggingface']['skip'], + } + schema_data[prop['name']] = data_property + + _dump_schema(schema_data) + + return schema_data + + + def _update_schema_property(self, old_name: str, new_name: str, prop: dict) -> None: + del self._schema[old_name] + self._schema[new_name] = prop + _dump_schema(self._schema) + + + def _add_schema_property(self, name, prop: dict) -> None: + self._schema[name] = prop + _dump_schema(self._schema) + + + def _delete_schema_property(self, name) -> None: + del self._schema[name] + _dump_schema(self._schema) + + + def init(self) -> ttk.Frame: + main_frame = ttk.Frame(self._parent) + main_frame.pack(fill=BOTH, expand=True) + + schema_frame = ttk.Frame(main_frame) + schema_frame.pack(fill=BOTH, expand=True) + + add_button = ttk.Button(schema_frame, text='Add property', + command=lambda: self._add_property(refresh_table)) + add_button.pack(anchor=NW, padx=5, pady=5) + + canvas = Canvas(schema_frame) + scrollbar = ttk.Scrollbar(schema_frame, orient="vertical", command=canvas.yview) + scrollable_frame = ttk.Frame(canvas) + + scrollable_frame.bind("", lambda _: canvas.configure(scrollregion=canvas.bbox("all"))) + canvas.create_window((0, 0), window=scrollable_frame, anchor="nw") + canvas.configure(yscrollcommand=scrollbar.set) + canvas.pack(side=LEFT, fill=BOTH, expand=True) + scrollbar.pack(side=RIGHT, fill=Y) + + def refresh_table(): + for widget in scrollable_frame.winfo_children(): + widget.destroy() + + self._build_table(scrollable_frame, refresh_table) + + refresh_table() + return main_frame + + + def _build_table(self, parent_frame, refresh_callback): + style = ttk.Style() + style.configure('Header.TLabel', font=('Helvetica', 10, 'bold'), background='#e0e0e0') + style.configure('EvenRow.TLabel', background='#f0f0f0') + style.configure('OddRow.TLabel', background='white') + + table_frame = ttk.Frame(parent_frame) + table_frame.pack(fill=X, padx=5, pady=5) + + for i in range(5): + table_frame.grid_columnconfigure(i, minsize=100, weight=1) + + headers = ['Name', 'Data Type', 'Filterable', 'Searchable', 'Skip Vectorize'] + for col, text in enumerate(headers): + label = ttk.Label(table_frame, text=text, borderwidth=1, relief=SOLID, anchor='center', style='Header.TLabel') + label.grid(row=0, column=col, sticky='ew') + + for idx, (name, prop) in enumerate(self._schema.items(), start=1): + row_style = 'EvenRow.TLabel' if idx % 2 == 0 else 'OddRow.TLabel' + + row_name_label = ttk.Label(table_frame, text=name, style=row_style) + row_type_label = ttk.Label(table_frame, text=prop['data_type'].upper(), style=row_style) + row_filterable_label = ttk.Label(table_frame, text='Yes' if prop['filterable'] else 'No', style=row_style) + row_searchable_label = ttk.Label(table_frame, text='Yes' if prop['searchable'] else 'No', style=row_style) + row_vectorize_label = ttk.Label(table_frame, text='Yes' if prop['skip_vectorization'] else 'No', style=row_style) + + row_edit_button = ttk.Button(table_frame, text='Edit', + command=lambda n=name, p=prop: self._edit_property(n, p, refresh_callback)) + row_delete_button = ttk.Button(table_frame, text='Delete', + command=lambda n=name: self._delete_property(n, refresh_callback)) + row_strategy_button = ttk.Button(table_frame, text='Strategy', + command=lambda n=name: self._handle_strategy(n)) + + row_name_label.grid(row=idx, column=0, sticky='ew', ipadx=25) + row_type_label.grid(row=idx, column=1, sticky='ew', ipadx=25) + row_filterable_label.grid(row=idx, column=2, sticky='ew', ipadx=25) + row_searchable_label.grid(row=idx, column=3, sticky='ew') + row_vectorize_label.grid(row=idx, column=4, sticky='ew') + row_edit_button.grid(row=idx, column=5, sticky='ew') + row_delete_button.grid(row=idx, column=6, sticky='ew') + row_strategy_button.grid(row=idx, column=7, sticky='ew') + + + def _handle_strategy(self, n): + dialog = Toplevel() + dialog.title(f"Property {n} strategy") + dialog.geometry("700x400") + + field_frame = ttk.Frame(dialog) + field_frame.pack(fill=BOTH, expand=True, padx=10, pady=10) + + scrollbar = Scrollbar(field_frame, orient=VERTICAL) + scrollbar.pack(side=RIGHT, fill=Y) + + strategy = self._strategies[n] + edit_field = Text(field_frame, width=80, height=15, wrap=WORD, yscrollcommand=scrollbar.set) + edit_field.insert(END, strategy) + edit_field.pack(side=LEFT, fill=BOTH, expand=True) + + scrollbar.config(command=edit_field.yview) + + def commit(): + new_strategy = edit_field.get("1.0", END).strip() + self._save_strategy(n, new_strategy) + dialog.destroy() + + + ttk.Button(dialog, text="Save", command=commit).pack(side=BOTTOM, anchor=S, pady=10) + + + def _delete_property(self, name, refresh_callback): + msg = f"Do you want to delete property '{name}'?" + dialog = Toplevel() + dialog.title('Warning!') + dialog.geometry(f"{len(msg)*5+120}x50") + dialog.grab_set() + + ttk.Label(dialog, text=msg).pack() + + def submit(): + self._delete_schema_property(name) + refresh_callback() + dialog.destroy() + + button_frame = ttk.Frame(dialog) + button_frame.pack(fill=X, expand=True) + + ttk.Button(button_frame, text='Delete', command=submit).pack(side=LEFT, padx=15) + ttk.Button(button_frame, text='Cancel', command=dialog.destroy).pack(side=RIGHT, padx=15) + + + def _add_property(self, refresh_callback): + dialog = Toplevel() + dialog.title(f"New property") + dialog.geometry("280x300") + dialog.grab_set() + + texts_frame = ttk.Frame(dialog) + texts_frame.pack(fill=X, expand=True) + + ttk.Label(texts_frame, text="Name:").grid(row=0, column=0, padx=5, pady=5, sticky='e') + name_entry = ttk.Entry(texts_frame) + name_entry.grid(row=0, column=1, padx=5, pady=5, sticky='w') + + ttk.Label(texts_frame, text="Description:").grid(row=1, column=0, padx=5, pady=5, sticky='e') + desc_entry = ttk.Entry(texts_frame) + desc_entry.insert(0, '') + desc_entry.grid(row=1, column=1, padx=5, pady=5, sticky='w') + + ttk.Label(texts_frame, text="Data Type:").grid(row=2, column=0, padx=5, pady=5, sticky='e') + type_var = StringVar(value='text') + type_combo = ttk.Combobox(texts_frame, textvariable=type_var, + values=["text", "int", "number", "boolean", "date", "text[]", "int[]", "number[]", "boolean[]", "date[]", "object"] + ) + type_combo.grid(row=2, column=1, padx=5, pady=5, sticky='w') + + checks_frame = ttk.Frame(dialog) + checks_frame.pack(fill=X, expand=True) + + filterable_var = BooleanVar(value=True) + searchable_var = BooleanVar(value=True) + skip_vec_var = BooleanVar(value=False) + + ttk.Checkbutton(checks_frame, text="Filterable ", variable=filterable_var).pack(anchor=W, padx=15) + ttk.Checkbutton(checks_frame, text="Searchable ", variable=searchable_var).pack(anchor=W, padx=15) + ttk.Checkbutton(checks_frame, text="Skip Vectorization", variable=skip_vec_var).pack(anchor=W, padx=15) + + def submit(): + name = name_entry.get() + if not name: + self._show_messagebox("Parameter 'name' is required!") + return + if name in self._schema.keys(): + self._show_messagebox(f"Property with name '{name}' already exists!") + return + + prop = { + 'description': desc_entry.get().strip(), + 'data_type': type_var.get(), + 'filterable': filterable_var.get(), + 'searchable': searchable_var.get(), + 'skip_vectorization': skip_vec_var.get(), + } + + self._add_schema_property(name, prop) + refresh_callback() + dialog.destroy() + + buttons_frame = ttk.Frame(dialog) + buttons_frame.pack(fill=X, expand=True) + + ttk.Button(buttons_frame, text="Save", command=submit).pack(side=LEFT, padx=15) + ttk.Button(buttons_frame, text="Cancel", command=dialog.destroy).pack(side=RIGHT, padx=15) + + + def _edit_property(self, name: str, prop: dict, refresh_callback): + dialog = Toplevel() + dialog.title(f"Edit Property: {name}") + dialog.geometry("280x300") + dialog.grab_set() + + texts_frame = ttk.Frame(dialog) + texts_frame.pack(fill=X, expand=True) + + ttk.Label(texts_frame, text="Name:").grid(row=0, column=0, padx=5, pady=5, sticky='e') + name_entry = ttk.Entry(texts_frame) + name_entry.insert(0, name) + name_entry.grid(row=0, column=1, padx=5, pady=5, sticky='w') + + ttk.Label(texts_frame, text="Description:").grid(row=1, column=0, padx=5, pady=5, sticky='e') + desc_entry = ttk.Entry(texts_frame) + desc_entry.insert(0, prop.get('description', '')) + desc_entry.grid(row=1, column=1, padx=5, pady=5, sticky='w') + + ttk.Label(texts_frame, text="Data Type:").grid(row=2, column=0, padx=5, pady=5, sticky='e') + type_var = StringVar(value=prop['data_type']) + type_combo = ttk.Combobox(texts_frame, textvariable=type_var, + values=["text", "int", "number", "boolean", "date", "text[]", "int[]", "number[]", "boolean[]", "date[]", "object"] + ) + type_combo.grid(row=2, column=1, padx=5, pady=5, sticky='w') + + checks_frame = ttk.Frame(dialog) + checks_frame.pack(fill=X, expand=True) + + filterable_var = BooleanVar(value=prop['filterable']) + searchable_var = BooleanVar(value=prop['searchable']) + skip_vec_var = BooleanVar(value=prop['skip_vectorization']) + + ttk.Checkbutton(checks_frame, text="Filterable ", variable=filterable_var).pack(anchor=W, padx=15) + ttk.Checkbutton(checks_frame, text="Searchable ", variable=searchable_var).pack(anchor=W, padx=15) + ttk.Checkbutton(checks_frame, text="Skip Vectorization", variable=skip_vec_var).pack(anchor=W, padx=15) + + def submit(): + new_name = name_entry.get().strip() + if not new_name: + self._show_messagebox("Parameter 'name' is required!") + return + + updated_prop = { + 'description': desc_entry.get().strip(), + 'data_type': type_var.get(), + 'filterable': filterable_var.get(), + 'searchable': searchable_var.get(), + 'skip_vectorization': skip_vec_var.get(), + } + + self._update_schema_property(name, new_name, updated_prop) + refresh_callback() + dialog.destroy() + + buttons_frame = ttk.Frame(dialog) + buttons_frame.pack(fill=X, expand=True) + + ttk.Button(buttons_frame, text="Save", command=submit).pack(side=LEFT, padx=15) + ttk.Button(buttons_frame, text="Cancel", command=dialog.destroy).pack(side=RIGHT, padx=15) + + + @staticmethod + def _show_messagebox(msg): + dialog = Toplevel() + dialog.title('Warning!') + dialog.geometry(f"{len(msg)*5+120}x50") + dialog.grab_set() + + ttk.Label(dialog, text=msg).pack() + ttk.Button(dialog, text='OK', command=dialog.destroy).pack(padx=15) diff --git a/src/apps/dbapp/framebase.py b/src/apps/dbapp/framebase.py new file mode 100644 index 0000000000000000000000000000000000000000..d350ce4a42ad4d7b75eec6f96d79ab9cc417799c --- /dev/null +++ b/src/apps/dbapp/framebase.py @@ -0,0 +1,15 @@ +from tkinter import * +from tkinter import ttk +from src.database.weavservice import WeaviateService + +class CustomFrameBase: + def __init__(self, parent, service: WeaviateService) -> None: + self._parent = parent + self._service = service + + + def init(self) -> ttk.Frame: + main_frame = ttk.Frame(self._parent) + main_frame.pack() + + return main_frame diff --git a/src/apps/dbapp/imports.py b/src/apps/dbapp/imports.py new file mode 100644 index 0000000000000000000000000000000000000000..922b9e426dd702814edeafb5419cb12c3ef4f96e --- /dev/null +++ b/src/apps/dbapp/imports.py @@ -0,0 +1,244 @@ +import os +import threading +from tkinter import * +from tkinter import ttk +from tkinter import filedialog +from queue import Queue + +from .framebase import CustomFrameBase + +from src.pipeline.pipeline import ImportPipeline +from src.pipeline.utils import ProcessingResult + +from src.database.weavservice import WeaviateService +from src.utils.lang import get_language_name +from src.config import config + +class ImportFrame(CustomFrameBase): + def __init__(self, parent, service: WeaviateService) -> None: + super().__init__(parent, service) + self._import_paths = dict() + + def init(self) -> ttk.Frame: + main_frame = ttk.Frame(self._parent) + main_frame.pack(fill=BOTH, expand=True) + + # ====================== Helper functions ====================== + def update_treeview(): + for item in self.files_treeview.get_children(): + self.files_treeview.delete(item) + for filename in self._import_paths: + self.files_treeview.insert("", 0, text=filename) + + def open_file_dialog(): + filepaths = filedialog.askopenfilenames( + title="Select files to import", + filetypes=(("PDF", "*.pdf"), ("Text files", "*.txt"), ("All files", "*.*")) + ) + for path in filepaths: + filename = os.path.basename(path) + self._import_paths[filename] = path + update_treeview() + + def remove_files(): + selection = self.files_treeview.selection() + if not selection: + return + for item in selection: + filename = self.files_treeview.item(item)["text"] + self._import_paths.pop(filename, None) + update_treeview() + + def change_button_state(state): + add_button.config(state=state) + remove_button.config(state=state) + import_button.config(state=state) + + # Configure grid for 50/50 split + main_frame.grid_rowconfigure(0, weight=1) + main_frame.grid_columnconfigure(0, weight=1) + main_frame.grid_columnconfigure(1, weight=1) + + # ====================== LEFT SIDE ====================== + left_frame = ttk.Frame(main_frame) + left_frame.grid(row=0, column=0, sticky='nsew', padx=(10, 5), pady=10) + + # Button row for add/remove + btn_row = ttk.Frame(left_frame) + btn_row.pack(fill=X, pady=(0, 8)) + + add_button = ttk.Button(btn_row, text="Add files", command=open_file_dialog) + add_button.pack(side=LEFT, padx=8) + + remove_button = ttk.Button(btn_row, text="Remove files", command=remove_files) + remove_button.pack(side=LEFT, padx=8) + + # Controls row for checkbox and import button + controls_row = ttk.Frame(left_frame) + controls_row.pack(fill=X, pady=(0, 8)) + + import_button = ttk.Button( + controls_row, + text="Begin Import", + command=lambda: self._import_callback(change_button_state) + ) + import_button.pack(side=LEFT, padx=10) + + self.reset_cd_var = BooleanVar(value=False) + reset_cb = ttk.Checkbutton( + controls_row, + text="Reset database", + variable=self.reset_cd_var + ) + reset_cb.pack(side=LEFT, padx=8, pady=6) + + # Files treeview + self.files_treeview = ttk.Treeview( + left_frame, + columns=[], + show="tree headings", + selectmode="extended", + height=18 + ) + self.files_treeview.heading("#0", text="File name") + self.files_treeview.column("#0", width=260) + self.files_treeview.pack(fill=BOTH, expand=True, pady=8) + + # ====================== RIGHT SIDE ====================== + right_frame = ttk.Frame(main_frame) + right_frame.grid(row=0, column=1, sticky='nsew', padx=(5, 10), pady=10) + + ttk.Label(right_frame, text="Enter URLs (one per line):").pack(anchor=W, padx=5, pady=(0, 6)) + + self.url_text = Text(right_frame, width=28, height=22, undo=True, wrap="word", font=("Segoe UI", 10)) + self.url_text.pack(side=LEFT, fill=BOTH, expand=True, padx=5, pady=5) + + self.url_text.insert(END, '\n'.join(config.get('SCRAPING_TARGET_URLS'))) + + # Scrollbar + scrollbar = ttk.Scrollbar(right_frame, orient="vertical", command=self.url_text.yview) + scrollbar.pack(side=RIGHT, fill=Y) + self.url_text.config(yscrollcommand=scrollbar.set) + + return main_frame + + + def _deduplication_callback(self, source: str, amount: int): + result_queue = Queue() + + def show_dialog(): + dialog = Toplevel() + dialog.title("Duplicated content!") + dialog.bell() + + wrap_width = 360 + + info_label = ttk.Label( + dialog, + text=f'{amount} duplicated chunks found in database for {source}!', + wraplength=wrap_width, + justify=LEFT + ) + info_label2 = ttk.Label( + dialog, + text='Would you like to reimport them with updated properties?', + wraplength=wrap_width, + justify=LEFT + ) + + info_label.pack(fill=X, anchor=W, padx=15, pady=15) + info_label2.pack(fill=X, anchor=W, padx=15, pady=15) + + def reimport_callback(): + result_queue.put(True) + dialog.destroy() + + def dispose_callback(): + result_queue.put(False) + dialog.destroy() + + reimport_button = ttk.Button(dialog, text='Reimport', command=reimport_callback) + dispose_button = ttk.Button(dialog, text='Dispose', command=dispose_callback) + + reimport_button.pack(side=LEFT, padx=15, pady=15) + dispose_button.pack(side=RIGHT, padx=15, pady=15) + + dialog.update_idletasks() + width = dialog.winfo_reqwidth() + 20 + height = dialog.winfo_reqheight() + 20 + dialog.geometry(f"{width}x{height}") + + dialog.protocol("WM_DELETE_WINDOW", dispose_callback) + + dialog.wait_visibility() + dialog.grab_set() + + self._parent.after(0, show_dialog) + return result_queue.get() + + + def _import_callback(self, button_state_callback): + dialog = Toplevel() + dialog.title("Import status") + dialog.geometry("600x400") + + current_import_label = ttk.Label(dialog, text='Initiating the import pipeline...') + current_import_label.pack(side=TOP, padx=15, pady=15) + + progress_bar = ttk.Progressbar(dialog, length=200, value=0, maximum=100) + progress_bar.pack(side=TOP, padx=15, pady=15) + + chunks_treeview = ttk.Treeview( + dialog, + columns=['chunks', 'lang'], + show='tree headings', + selectmode='extended', + ) + chunks_treeview.heading('#0', text='File name') + chunks_treeview.heading('chunks', text='Collected chunks') + chunks_treeview.heading('lang', text='Language') + + chunks_treeview.column('#0', width=100) + chunks_treeview.column('chunks', width=60) + chunks_treeview.column('lang', width=40) + + chunks_treeview.pack(side=TOP, fill=X, padx=15, pady=15, expand=True) + + def logging_callback( + msg: str, + progress: int, + result: ProcessingResult = None, + failed: bool = False, + ): + current_import_label.config(text=msg) + progress_bar.config(value=progress) + + if result: + chunks_treeview.insert('', index=0, + text=result.source, + values=( + 'Failure!' if failed else len(result.chunks), + get_language_name(result.lang) + ) + ) + config.dbapp['logging_callback'] = logging_callback + + def import_task(): + button_state_callback(DISABLED) + filepaths = self._import_paths.values() + urls = self.url_text.get('1.0', END).strip().split('\n') + try: + ImportPipeline( + logging_callback=logging_callback, + deduplication_callback=self._deduplication_callback, + ).import_all( + paths=filepaths, + urls=urls, + reset_collections=self.reset_cd_var.get() + ) + dialog.bell() + finally: + button_state_callback(NORMAL) + + import_thread = threading.Thread(target=import_task) + import_thread.start() diff --git a/src/apps/dbapp/mainframe.py b/src/apps/dbapp/mainframe.py new file mode 100644 index 0000000000000000000000000000000000000000..75ef608ab384ffb6e328f788472892d83736c602 --- /dev/null +++ b/src/apps/dbapp/mainframe.py @@ -0,0 +1,8 @@ +from tkinter import * +from tkinter import ttk +from src.apps.dbapp.framebase import CustomFrameBase +from src.database.weavservice import WeaviateService + +class MainFrame(CustomFrameBase): + def __init__(self, parent, service: WeaviateService) -> None: + super().__init__(parent, service) diff --git a/src/apps/dbapp/query.py b/src/apps/dbapp/query.py new file mode 100644 index 0000000000000000000000000000000000000000..c021cd9b5f4f7de020b3754223d20b09b532d1b3 --- /dev/null +++ b/src/apps/dbapp/query.py @@ -0,0 +1,108 @@ +from tkinter import * +from tkinter import ttk +from src.apps.dbapp.framebase import CustomFrameBase +from src.database.weavservice import WeaviateService + +class QueryFrame(CustomFrameBase): + def __init__(self, parent, service: WeaviateService) -> None: + super().__init__(parent, service) + + def init(self) -> ttk.Frame: + main_frame = ttk.Frame(self._parent) + main_frame.pack(fill=BOTH, expand=True) + + input_frame = ttk.Frame(main_frame) + input_frame.pack(fill=X, padx=10, pady=(5, 10)) + + self.language_var = StringVar(value="de") + + self.filters_button = ttk.Button(input_frame, text="Filters...", command=self.open_filters) + self.filters_button.pack(side=LEFT, padx=(0, 10)) + + lang_frame = ttk.Frame(input_frame) + lang_frame.pack(side=LEFT, padx=(0, 15)) + + ttk.Radiobutton( + lang_frame, + text="EN", + variable=self.language_var, + value="en" + ).pack(side=LEFT, padx=(0, 8)) + + ttk.Radiobutton( + lang_frame, + text="DE", + variable=self.language_var, + value="de" + ).pack(side=LEFT) + + self.query_entry = ttk.Entry(input_frame) + self.query_entry.pack(side=LEFT, fill=X, expand=True, padx=(0, 10)) + + self.send_button = ttk.Button(input_frame, text="Send", command=self.send_query) + self.send_button.pack(side=RIGHT) + + self.query_entry.bind("", lambda _: self.send_query()) + + results_frame = ttk.Frame(main_frame) + results_frame.pack(fill=BOTH, expand=True, padx=10, pady=(10, 5)) + + self.results_text = Text(results_frame, wrap=WORD, font=("TkDefaultFont", 10)) + y_scrollbar = ttk.Scrollbar(results_frame, orient=VERTICAL, command=self.results_text.yview) + self.results_text.configure(yscrollcommand=y_scrollbar.set) + + self.results_text.pack(side=LEFT, fill=BOTH, expand=True) + y_scrollbar.pack(side=RIGHT, fill=Y) + + self.results_text.config(state=NORMAL) + self.results_text.insert(END, "Enter your query below and click Send (or press Enter) to see results.\n") + self.results_text.config(state=DISABLED) + + return main_frame + + + def send_query(self): + query_text = self.query_entry.get().strip() + if not query_text: + return + + self.query_entry.delete(0, END) + + try: + response, _ = self._service.query( + lang=self.language_var.get(), + query=query_text, + ) + result_str = ''.join([f""" +---------------------- Result {idx} ---------------------- +SOURCE: {obj.properties['source']} +INSERTION DATE: {obj.properties['date']} +RELEVANT PROGRAMS: {', '.join(obj.properties['programs'])} + +CONTENT: +{obj.properties['body']} + +VECTOR: +{obj.vector} +""" for idx, obj in enumerate(response.objects, start=1)]) + + result_str = f"Query: {query_text}\n{result_str}" + + self.display_result(result_str) + except Exception as e: + self.display_result(f"Error:\n{str(e)}") + + + def display_result(self, result_text: str): + self.results_text.config(state=NORMAL) + self.results_text.delete(1.0, END) + self.results_text.insert(END, result_text + "\n") + self.results_text.config(state=DISABLED) + self.results_text.see(1.0) + + + def open_filters(self): + dialog = Toplevel(self._parent) + dialog.title("Query Filters") + dialog.geometry("400x300") + dialog.grab_set() diff --git a/src/apps/dbapp/utilclasses.py b/src/apps/dbapp/utilclasses.py new file mode 100644 index 0000000000000000000000000000000000000000..1df23a5097c23314dc0f17b85b7f24da9d9b32a9 --- /dev/null +++ b/src/apps/dbapp/utilclasses.py @@ -0,0 +1,38 @@ +import os, json +from datetime import datetime +from src.config import config + +class BackupData: + def __init__(self, backup_id: str) -> None: + self._backup_id = backup_id + self._creation_date = "" + self._collections = [] + + backup_path = os.path.join(config.weaviate.BACKUP_PATH, backup_id) + files = os.listdir(backup_path) + + if 'data.json' in files: + data_path = os.path.join(backup_path, 'data.json') + with open(data_path) as f: + data = json.load(f) + + date = datetime.fromisoformat(data['creation_date']) + self._creation_date = date.strftime("%d.%m.%Y %H:%M:%S") + + if 'objects.json' in files: + objects_path = os.path.join(backup_path, 'objects.json') + with open(objects_path) as f: + data = json.load(f) + for name, objs in data.items(): + self._collections.append({ + 'name': name.lower(), + 'size': ('', len(objs)) + }) + + + def to_treeformat(self): + return { + 'id': self._backup_id.replace('backup_', ''), + 'date': (self._creation_date, ''), + 'collections': self._collections, + } diff --git a/src/cache/__init__.py b/src/cache/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/cache/cache.py b/src/cache/cache.py new file mode 100644 index 0000000000000000000000000000000000000000..19637485432b09988edcbd7f72cb6c0ecd6b4763 --- /dev/null +++ b/src/cache/cache.py @@ -0,0 +1,75 @@ +from threading import Lock +from src.cache.cache_metrics import CacheMetrics +from src.cache.cache_strategies import RedisCache, LocalCache + +from src.utils.logging import get_logger +from src.config import config + +logger = get_logger("cache ") + +class Cache: + _instance = None + _settings = None + _lock = Lock() + _cache_metrics = None + + @staticmethod + def configure(mode: str, cache: bool): + logger.info(f"Cache configured with parameters: mode={mode}, cache={cache}") + config.cache.ENABLED = cache + Cache._settings = { + "mode": mode, + "enabled": cache + } + + @staticmethod + def get_cache(): + if Cache._instance is not None: + return Cache._instance + + with Cache._lock: + if Cache._instance is not None: + return Cache._instance + + settings = Cache._settings or {"mode": 'local', "enabled": True} + + if not settings.get("enabled", True): + Cache._instance = None + return None + + if Cache._cache_metrics is None: + Cache._cache_metrics = CacheMetrics() + + mode = settings.get("mode", 'local') + + if mode == 'cloud': + cache_obj = RedisCache( + host=config.cache.CLOUD_HOST, + port=config.cache.CLOUD_PORT, + password=config.cache.CLOUD_PASS, + mode=mode, + metrics=Cache._cache_metrics + ) + elif mode == 'local': + cache_obj = RedisCache( + host=config.cache.LOCAL_HOST, + port=config.cache.LOCAL_PORT, + password=config.cache.LOCAL_PASS, + mode=mode, + metrics=Cache._cache_metrics + ) + elif mode == 'dict': + Cache._instance = LocalCache(metrics=Cache._cache_metrics) + return Cache._instance + else: + logger.error("FALLBACK to dict cache. Unknown cache mode") + Cache._instance = LocalCache(metrics=Cache._cache_metrics) + return Cache._instance + + if cache_obj.client is None: + logger.error("FALLBACK to dict cache. Redis connection failed") + Cache._instance = LocalCache(metrics=Cache._cache_metrics) + else: + Cache._instance = cache_obj + + return Cache._instance diff --git a/src/cache/cache_base.py b/src/cache/cache_base.py new file mode 100644 index 0000000000000000000000000000000000000000..7ac6e4aa9efa3badc7aa2419bef0325d7eb73169 --- /dev/null +++ b/src/cache/cache_base.py @@ -0,0 +1,19 @@ +from abc import ABC, abstractmethod +from typing import Any + +class CacheStrategy(ABC): + """ + Defines the interface for the different cache system strategies (Local or Redis). + """ + + @abstractmethod + def set(self, key: str, value: Any, language: str, session_id: str): + pass + + @abstractmethod + def get(self, key: str, language: str, session_id: str): + pass + + @abstractmethod + def clear_cache(self): + pass diff --git a/src/cache/cache_metrics.py b/src/cache/cache_metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..1b6d169f7e67a91a67d652e0e9d4e180b14c4cc8 --- /dev/null +++ b/src/cache/cache_metrics.py @@ -0,0 +1,28 @@ +from dataclasses import dataclass +from threading import Lock + + +@dataclass +class CacheStatistics: + hits: int + misses: int + hits_ratio: float + +class CacheMetrics: + def __init__(self) -> None: + self.cache_stats = CacheStatistics(0, 0, 0.0) + self._lock = Lock() + + def increment_hit(self): + with self._lock: + self.cache_stats.hits += 1 + self._calc_hit_ratio() + + def increment_miss(self): + with self._lock: + self.cache_stats.misses += 1 + self._calc_hit_ratio() + + def _calc_hit_ratio(self): + total = self.cache_stats.hits + self.cache_stats.misses + self.cache_stats.hits_ratio = (self.cache_stats.hits / total) if total else 0.0 \ No newline at end of file diff --git a/src/cache/cache_strategies.py b/src/cache/cache_strategies.py new file mode 100644 index 0000000000000000000000000000000000000000..2d41e9651d5bab23bcb7c9f01c8c193cb4a10dc7 --- /dev/null +++ b/src/cache/cache_strategies.py @@ -0,0 +1,88 @@ +import json +from typing import Any +from cachetools import TTLCache + +from .utils import get_cache_key +from src.cache.cache_base import CacheStrategy +from src.database.redisservice import RedisService +from src.utils.logging import get_logger +from src.config import config + +logger = get_logger('cache_strat') + +class RedisCache(CacheStrategy): + def __init__(self, host, port, password, mode, metrics): + service = RedisService(host, port, password, mode) + self.client = service.get_client() + self.metrics = metrics + + + def set(self, key: str, value: Any, language: str, session_id: str): + if not self.client: return + + try: + json_str = json.dumps(value) + cache_key = get_cache_key(key, language, session_id) + self.client.set(cache_key, json_str, ex=config.cache.TTL_CACHE) + logger.info(f"Cached response with key {cache_key[:20]}... to Redis") + except Exception as e: + logger.error(f"Could not write to Redis: {e}") + + + def get(self, key: str, language: str, session_id: str): + if not self.client: return None + + try: + cache_key = get_cache_key(key, language, session_id) + val = self.client.get(cache_key) + if val is not None: + self.metrics.increment_hit() + logger.info(f"Found cached data with key {cache_key}") + logger.debug(f"Cache statistics: Hit cache {self.metrics.cache_stats.hits} times, ratio[{self.metrics.cache_stats.hits_ratio}]") + return json.loads(val) + + self.metrics.increment_miss() + logger.debug(f"Cache statistics: Missed cache {self.metrics.cache_stats.misses} times, ratio[{self.metrics.cache_stats.hits_ratio}]") + return None + except Exception as e: + logger.error(f"Could not read from Redis: {e}") + return None + + + def clear_cache(self): + if not self.client: return + + try: + self.client.flushdb() + logger.info(f"Redis Cache cleared.") + except Exception as e: + logger.error(f"Could not clear Redis cache: {e}") + + +class LocalCache(CacheStrategy): + def __init__(self, metrics): + self.cache = TTLCache(maxsize=config.cache.MAX_SIZE_CACHE, ttl=config.cache.TTL_CACHE) + self.metrics = metrics + + + def set(self, key: str, value: Any, language: str, session_id: str): + normalized_key = get_cache_key(key, language, session_id) + self.cache[normalized_key] = value + logger.info("Response cached") + + + def get(self, key: str, language: str, session_id: str): + normalized_key = get_cache_key(key, language, session_id) + res = self.cache.get(normalized_key, None) + if res is not None: + self.metrics.increment_hit() + logger.debug(f"Cache statistics: Hit cache {self.metrics.cache_stats.hits} times, ratio[{self.metrics.cache_stats.hits_ratio}]") + else: + self.metrics.increment_miss() + logger.debug(f"Cache statistics: Missed cache {self.metrics.cache_stats.misses} times, ratio[{self.metrics.cache_stats.hits_ratio}]") + return res + + + def clear_cache(self): + self.cache.clear() + logger.info("Local Cache cleared.") diff --git a/src/cache/utils.py b/src/cache/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c0de02fe69592627d5deeb77b6eebc8296a1fcef --- /dev/null +++ b/src/cache/utils.py @@ -0,0 +1,5 @@ +import re + +def get_cache_key(key: str, language: str, session_id: str) -> str: + normalized_key = re.sub(r'[^a-z0-9]', '', key.lower()) + return f"cache:{session_id}:{language}:{normalized_key}" diff --git a/src/config/__init__.py b/src/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..90aca7bd7b110cb458cb8afc82a5900390b6289e --- /dev/null +++ b/src/config/__init__.py @@ -0,0 +1,39 @@ +from src.config.configs import * +from functools import lru_cache +from typing import Any +import config as c + +class AppConfig: + # ===================== INITIALIZE YOUR SUBCONFIGS HERE ===================== + + convstate: ConversationStateConfig = ConversationStateConfig() + processing: ProcessingConfig = ProcessingConfig() + weaviate: WeaviateConfig = WeaviateConfig() + scraping: ScrapingConfig = ScrapingConfig() + chain: ChainConfig = ChainConfig() + cache: CacheConfig = CacheConfig() + paths: PathsConfig = PathsConfig() + dbapp: DatabaseAppConfig = DatabaseAppConfig() + llm: LLMProviderConfig = LLMProviderConfig() + + # =========================================================================== + + def get(self, key: str, default: Any = None) -> Any: + """ + Retrieves an extra parameter from config.py by name. + + Raises: + AttributeError if not found and no default provided. + """ + try: + return getattr(c, key) + except AttributeError: + if default is not None: + return default + raise AttributeError(f"Config parameter '{key}' is not defined!") + +@lru_cache(maxsize=1) +def get_config() -> AppConfig: + return AppConfig() + +config = get_config() diff --git a/src/config/configs.py b/src/config/configs.py new file mode 100644 index 0000000000000000000000000000000000000000..38be4bfd68b156dda90fefaa618b3200d5d51580 --- /dev/null +++ b/src/config/configs.py @@ -0,0 +1,249 @@ +from typing import Literal +from dotenv import load_dotenv + +import config, os + +load_dotenv() + +def _get(param: str, default=None, type_=None): + value = getattr(config, param, default) + + if value is None: + value = os.getenv(param) + + if value is None: + return default + + if not type_: return value + + try: + return type_(value) + except (ValueError, TypeError): + raise ValueError(f"Failed to cast '{param}' value '{value}' to {type_.__name__}") + + +class ConfigBase: + PARAMS: dict = dict() + + @classmethod + def __getitem__(cls, key): + return cls.PARAMS.get(key, None) + + @classmethod + def __setitem__(cls, key, value): + cls.PARAMS[key] = value + + +class DatabaseAppConfig(ConfigBase): + pass + + +class PathsConfig(ConfigBase): + DATA: str = _get('DATA_PATH') + LOGS: str = _get('LOGS_PATH') + URLS_OUTPUT: str = os.path.join(_get('DATA_PATH'), 'urls') + CHUNKS_OUTPUT: str = os.path.join(_get('DATA_PATH'), 'chunks') + TEMP_CHUNKS_OUTPUT: str = os.path.join(_get('DATA_PATH'), 'temp_chunks') + SCRAPING_OUTPUT: str = os.path.join(_get('DATA_PATH'), 'scraping') + RAW_TEXT_OUTPUT: str = os.path.join(_get('DATA_PATH'), 'raw_text') + RAW_HTML_OUTPUT: str = os.path.join(_get('DATA_PATH'), 'raw_html') + METADATA_OUTPUT: str = os.path.join(_get('DATA_PATH'), 'metadata') + EXTRACTED_TEXT_OUTPUT: str = os.path.join(_get('DATA_PATH'), 'extracted_text') + + +class ScrapingConfig(ConfigBase): + TIMEOUT: int = _get('SCRAPING_SCRAPING_TIMEOUT', 30) + MAX_RETRIES: int = _get('SCRAPING_MAX_RETRIES', 3) + CRAWL_DELAY: int = _get('SCRAPING_CRAWL_DELAY', 1) + BACKOFF_RATE: int = _get('SCRAPING_BACKOFF_RATE', 2) + TARGET_URLS: int = _get('SCRAPING_TARGET_URLS', None) + INTERVALS: dict = _get('SCRAPING_PRIO_INTERVAL', dict()) + + +class ConversationStateConfig(ConfigBase): + TRACK_USER_PROFILE = _get('TRACK_USER_PROFILE') + LOCK_LANGUAGE_AFTER_N_MESSAGES = _get('LOCK_LANGUAGE_AFTER_N_MESSAGES') + MAX_CONVERSATION_TURNS = _get('MAX_CONVERSATION_TURNS') + + +class ProcessingConfig(ConfigBase): + LANG_AMBIGUITY_THRESHOLD: float = _get('LANG_AMBIGUITY_THRESHOLD') + EMBEDDING_MODEL: float = _get('EMBEDDING_MODEL') + MAX_TOKENS: int = _get('MAX_TOKENS') + CHUNK_OVERLAP: int = _get('CHUNK_OVERLAP') + + +class ChainConfig(ConfigBase): + ENABLE_RESPONSE_CHUNKING: bool = _get('ENABLE_RESPONSE_CHUNKING', True) + EVALUATE_RESPONSE_QUALITY: bool = _get('ENABLE_EVALUATE_RESPONSE_QUALITY', True) + CONFIDENCE_THRESHOLD: float = _get('CONFIDENCE_THRESHOLD') + + TOP_K_RETRIEVAL: int = _get('TOP_K_RETRIEVAL', 4) + MAX_RETRIES: int = _get('MODEL_MAX_RETRIES', 3) + MAX_RESPONSE_WORDS_LEAD: int = _get('MAX_RESPONSE_WORDS_LEAD', 100) + MAX_RESPONSE_WORDS_SUBAGENT: int = _get('MAX_RESPONSE_WORDS_SUBAGENT', 200) + + +class CacheConfig(ConfigBase): + ENABLED: bool = _get('CACHE_ENABLED', False) + CACHE_MODE: Literal['local', 'cloud', 'dict'] = _get('CACHE_MODE') + + LOCAL_HOST: str = _get('CACHE_LOCAL_HOST', 'localhost') + LOCAL_PORT: int = _get('CACHE_LOCAL_PORT', 6379) + LOCAL_PASS: str = _get('CACHE_LOCAL_PASSWORD', '') + + CLOUD_HOST: str = _get('REDIS_CLOUD_HOST') + CLOUD_PORT: int = _get('REDIS_CLOUD_PORT', type_=int) + CLOUD_PASS: str = _get('REDIS_CLOUD_PASSWORD') + + TTL_CACHE: int = _get('CACHE_TTL', 86400) + MAX_SIZE_CACHE: int = _get('CACHE_MAX_SIZE', 1000) + + +class WeaviateConfig(ConfigBase): + LOCAL_DATABASE: bool = _get('WEAVIATE_IS_LOCAL') + WEAVIATE_COLLECTION_BASENAME: str = _get('WEAVIATE_COLLECTION_BASENAME') + + BACKUP_METHODS: list[str] = ['manual', 'filesystem', 's3'] + BACKUP_METHOD: Literal['manual', 'filesystem', 's3'] = _get('WEAVIATE_BACKUP_METHOD') + + BACKUP_PATH: str = _get('BACKUPS_PATH') + PROPERTIES_PATH: str = _get('PROPERTIES_PATH') + STRATEGIES_PATH: str = _get('STRATEGIES_PATH') + + CLUSTER_URL: str = _get('WEAVIATE_CLUSTER_URL') + WEAVIATE_API_KEY: str = _get('WEAVIATE_API_KEY') + HUGGING_FACE_API_KEY: str = _get('HUGGING_FACE_API_KEY') + + INIT_TIMEOUT: int = _get('WEAVIATE_INIT_TIMEOUT', 90) + QUERY_TIMEOUT: int = _get('WEAVIATE_QUERY_TIMEOUT', 60) + INSERT_TIMEOUT: int = _get('WEAVIATE_INSERT_TIMEOUT', 600) + + +#TODO: Clean this configuration (outdated) +class LLMProvider: + def __init__(self, base: str, sub: str | None = None) -> None: + self.base = base + self.sub = sub + self.name = f"{base}:{sub}" if sub else base + + + def with_sub(self, sub: str | None = None) -> str: + return LLMProvider(self.base, sub) + + +class LLMProviderConfig: + AVAIABLE_PROVIDERS: list[str] = [ + 'groq', + 'ollama', + 'openai', + 'open_router', + ] + AVAILABLE_SUBPROVIDERS: dict = { + 'groq': [], + 'open_router': [ + 'openai', + 'deepseek', + 'meituan' + 'alibaba' # For tongyi models + 'nvidia', + ], + } + + LLM_PROVIDER: LLMProvider = LLMProvider('openai') + + # -------------------- Some predefined models for available providers ---------------------- + + # Groq settings + GROQ_API_KEY: str = os.getenv("GROQ_API_KEY") + GROQ_MODEL: str = "mixtral-8x7b-32768" + + # Open Router settings + OPEN_ROUTER_API_KEY: str = os.getenv("OPEN_ROUTER_API_KEY") + OPEN_ROUTER_MODEL: str = "meituan/longcat-flash-chat:free" + OPEN_ROUTER_BASE_URL: str = "https://openrouter.ai/api/v1" + + # OpenAI settings + OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY") + OPENAI_MODEL: str = "gpt-5.1" + + # The gpt-oss:20b model is preferable but takes much more space + # Set to False if you only have the llama3.2 installed + GPT_OSS_ENABLED: bool = False + # Local/Ollama settings + OLLAMA_BASE_URL: str = "http://localhost:11434" + OLLAMA_MODEL: str = "gpt-oss:20b" if GPT_OSS_ENABLED else "llama3.2" + + # ---------------------------------------------------------------------------------------- + + @classmethod + def get_fallback_models(cls, provider: LLMProvider | None = None) -> list[str]: + provider = provider or cls.LLM_PROVIDER + match provider.base: + case 'openai': + return { + provider: fallback_model + for fallback_model in [ + 'gpt-5-mini', + 'gpt-5-nano', + ] + } + case 'open_router': + return { + provider.with_sub('openai'): "gpt-oss-20b", + provider.with_sub('openai'): "gpt-oss-120b", + provider.with_sub('alibaba'): "alibaba/tongyi-deepresearch-30b-a3b:free", + provider: "openrouter/polaris-alpha", + # Currently unusable because has no tool support + #provider.with_sub('deepseek'): "deepseek/deepseek-chat-v3.1:free", + } + case _: + return {} + + @classmethod + def get_reasoning_support(cls, provider: LLMProvider | None = None) -> bool: + provider = provider or cls.LLM_PROVIDER + return { + "groq": True, + "openai": True, + "open_router": True, + }.get(provider.base, False) + + + @classmethod + def get_default_model(cls, provider: LLMProvider | None = None) -> str: + provider = provider or cls.LLM_PROVIDER + return { + "groq": cls.GROQ_MODEL, + "openai": cls.OPENAI_MODEL, + "ollama": cls.OLLAMA_MODEL, + "open_router": cls.OPEN_ROUTER_MODEL, + }.get(provider.base) + + + @classmethod + def get_api_key(cls, provider: LLMProvider | None = None) -> str: + provider = provider or cls.LLM_PROVIDER + return { + "groq": cls.GROQ_API_KEY, + "openai": cls.OPENAI_API_KEY, + "open_router": cls.OPEN_ROUTER_API_KEY, + }.get(provider.base) + + +class NotificationCenterConfig(ConfigBase): + ENABLE_EMAIL_ALERTS: bool = _get('NOTIFY_ENABLE_EMAIL_ALERTS', True, bool) + + SMTP_HOST: str = _get("NOTIFY_SMTP_HOST") + SMTP_PORT: int = _get("NOTIFY_SMTP_PORT", 587, type_=int) + + SMTP_USER: str = _get("NOTIFY_SMTP_USER") + SMTP_PASSWORD: str = _get("NOTIFY_SMTP_PASSWORD") + + SMTP_USE_TLS: bool = _get("NOTIFY_SMTP_USE_TLS", "True").lower() in ("1", "true", "yes", "on") + + FROM_EMAIL: str = _get("NOTIFY_FROM_EMAIL") + TO_EMAIL: str = _get("NOTIFY_TO_EMAIL") + + ENABLE_SLACK_ALERTS: bool = _get('NOTIFY_ENABLE_SLACK_ALERTS', False, bool) + SLACK_WEBHOOK_URL: str = _get("NOTIFY_SLACK_WEBHOOK_URL") diff --git a/src/const/agent_response_constants.py b/src/const/agent_response_constants.py new file mode 100644 index 0000000000000000000000000000000000000000..eddac52ae6c7ef79b5ed8948b06c6efd10f5598f --- /dev/null +++ b/src/const/agent_response_constants.py @@ -0,0 +1,209 @@ +""" Constants for Gradio app """ + +GREETING_MESSAGES = { + "en": [ + "Hello and welcome. I am your Executive Education Advisor for the HSG Executive MBA programmes (**IEMBA**, **emba X**, and **EMBA**). How may I support your MBA planning today?", + "Hello and welcome. I am your Executive Education Advisor for the University of St.Gallen Executive MBA programmes (**IEMBA**, **emba X**, and **EMBA**). How may I assist you with your programme search?", + "Hello and welcome. I am here to help you explore the University of St.Gallen Executive MBA programmes (**EMBA**, **IEMBA**, and **emba X**). What would you like to discuss today?", + "Hello and welcome. I am your Executive Education Advisor for the University of St.Gallen’s Executive MBA programmes, and I am here to help you assess fit across **EMBA**, **IEMBA**, and **emba X**.", + "Hello and welcome. I am here to support you with questions about the University of St.Gallen Executive MBA programmes and to help you evaluate the **EMBA**, **IEMBA**, and **emba X** options.", + ], + "de": [ + "Guten Tag. Ich bin Ihr Executive-Education-Berater für die HSG Executive MBA Programme und unterstütze Sie gerne bei Fragen zu **EMBA**, **IEMBA** und **emba X**.", + "Guten Tag. Ich bin Ihr Executive-Education-Berater für die HSG Executive MBA Programme (**EMBA**, **IEMBA**, **emba X**). Ich unterstütze Sie bei Programmwahl, Ablauf und Zulassungsfragen.", + "Guten Tag und herzlich willkommen. Ich bin Ihr Executive-Education-Berater für die HSG Executive MBA Programme und unterstütze Sie gerne bei Fragen zu **EMBA**, **IEMBA** und **emba X**.", + "Guten Tag. Ich bin Ihr Executive-Education-Berater für die HSG Executive MBA Programme (**EMBA**, **IEMBA**, **emba X**) und unterstütze Sie gerne bei der Einschätzung der passenden Option.", + "Guten Tag. Ich unterstütze Sie gerne bei Fragen zu den HSG Executive MBA Programmen und helfe Ihnen, die Optionen **EMBA**, **IEMBA** und **emba X** einzuordnen.", + ] +} + +QUERY_EXCEPTION_MESSAGE = { + "en": "I'm sorry, I cannot provide a helpful response right now. Please contact tech support or try again later.", + "de": "Es tut mir leid, ich kann im Moment keine hilfreiche Antwort geben. Bitte wenden Sie sich an den technischen Support oder versuchen Sie es später erneut.", +} + +NOT_VALID_QUERY_MESSAGE = { + "en": "I didn't quite understand that. Could you please rephrase your question?", + "de": "Das habe ich nicht ganz verstanden. Könnten Sie Ihre Frage bitte anders formulieren?", +} + +CONFIDENCE_FALLBACK_MESSAGE = { + "en": ( + "I am sorry, but I could not find sufficiently reliable information in my records to answer that question with confidence. " + "Could you please rephrase your question?\n\n" + "If you would like a personal consultation, I can also help you with appointment booking." + ), + "de": ( + "Es tut mir leid, aber ich konnte in meinen Unterlagen keine Informationen finden, " + "die zu Ihrer Anfrage passen, sodass ich sie nicht mit ausreichender Sicherheit beantworten kann. " + "Könnten Sie Ihre Frage bitte umformulieren?\n\n" + "Wenn Sie ein persönliches Beratungsgespräch wünschen, kann ich Ihnen auch bei der Terminbuchung helfen." + ), +} + +LANGUAGE_FALLBACK_MESSAGE = { + "en": ( + "I am sorry, I can only reply in English or German. " + "Would you like to continue our conversation in English?" + ), + "de": ( + "Es tut mir leid, ich kann nur auf Englisch oder Deutsch antworten. " + "Möchten Sie unser Gespräch auf Deutsch fortführen?" + ), +} + +CONVERSATION_END_MESSAGE = { + "en": ( + "This conversation has reached its maximum length. " + "To make sure you receive the best possible support, " + "please continue with a personal consultation.\n\n" + "If you would like to see appointment options with an admissions advisor, please ask me to show them. " + "Thank you for your understanding." + ), + "de": ( + "Dieses Gespräch hat die maximale Länge erreicht. " + "Damit Sie bestmöglich unterstützt werden, bitten wir Sie, " + "das Anliegen in einem persönlichen Beratungsgespräch fortzusetzen.\n\n" + "Wenn Sie Terminoptionen mit der Studienberatung sehen möchten, sagen Sie mir bitte kurz Bescheid. " + "Vielen Dank für Ihr Verständnis." + ), +} + +ADMISSIONS_TEAM_CONTACT = { + "en": { + "email": "emba@unisg.ch", + "phone": "+41 71 224 27 02", + }, + "de": { + "email": "emba@unisg.ch", + "phone": "+41 71 224 27 02", + }, +} + +ADVISOR_CONTACTS = [ + { + "name": "Cyra von Müller (EMBA)", + "program": "emba", + "email": "cyra.vonmueller@unisg.ch", + "phone": "+41 71 224 27 12", + "url": "https://calendly.com/cyra-vonmueller/beratungsgespraech-emba-hsg", + }, + { + "name": "Kristin Fuchs (IEMBA)", + "program": "iemba", + "email": "kristin.fuchs@unisg.ch", + "phone": "+41 71 224 75 46", + "url": "https://calendly.com/kristin-fuchs-unisg/iemba-online-personal-consultation", + }, + { + "name": "Teyuna Giger (emba X)", + "program": "emba_x", + "email": "teyuna.giger@unisg.ch", + "phone": "+41 71 224 77 65", + "url": "https://calendly.com/teyuna-giger-unisg", + }, +] + + +def get_admissions_contact_text(language: str = "en") -> str: + labels = { + "en": "You can reach the Executive MBA admissions team at {email} or {phone}.", + "de": "Sie erreichen das Executive-MBA-Zulassungsteam unter {email} oder {phone}.", + } + contact = ADMISSIONS_TEAM_CONTACT.get(language, ADMISSIONS_TEAM_CONTACT["en"]) + template = labels.get(language, labels["en"]) + return template.format(email=contact["email"], phone=contact["phone"]) + + +def get_booking_widget(language: str="en", programs: list[str]=None): + """ + Returns an HTML string representing a Booking Widget. + """ + + if programs is None or programs == []: + programs = ["emba", "iemba", "emba_x"] + + labels = { + "en": { + "header": "Book a Consultation", + "sub": "Select an advisor to view available appointment slots and contact details:", + "email": "Email", + "phone": "Phone", + }, + "de": { + "header": "Termin vereinbaren", + "sub": "Wählen Sie einen Berater, um verfügbare Termine und Kontaktdaten zu sehen:", + "email": "E-Mail", + "phone": "Telefon", + } + } + txt = labels.get(language, labels["en"]) + + base_params = "?hide_gdpr_banner=1&embed_type=Inline&embed_domain=1" + + html_content = f""" +
+

{txt['header']}

+

{txt['sub']}

+ """ + + for advisor in ADVISOR_CONTACTS: + if advisor["program"] in programs: + html_content += f""" +
+ + {advisor['name']} + +
+

{txt['email']}: {advisor['email']}

+

{txt['phone']}: {advisor['phone']}

+
+
+ +
+
+ """ + + html_content += "
" + return html_content + + +def get_disclaimer_widget(language: str = "en"): + """ + Returns an HTML string representing a warning disclaimer. + """ + disclaimers = { + "en": { + "title": "Disclaimer", + "body": "Assessments provided by this advisor are non-binding and based on limited information. Please consult our program directors for final admission or credit evaluations." + }, + "de": { + "title": "Haftungsausschluss", + "body": "Die Einschätzungen dieses Beraters sind unverbindlich und basieren auf begrenzten Informationen. Bitte wenden Sie sich für endgültige Zulassungs- oder Anrechnungsfragen an die Programmleitung." + } + } + + content = disclaimers.get(language, disclaimers["en"]) + + # Yellow styling constants + bg_color = "#fffbeb" # Light yellow + border_color = "#f59e0b" # Amber/Yellow border + icon_color = "#d97706" # Darker amber for the icon + text_color = "#92400e" # Dark brown/yellow for readability + + html_content = f""" +
+
+ + + +
+
+ {content['title']} +

+ {content['body']} +

+
+
+ """ + return html_content diff --git a/src/const/cc_whitelist.py b/src/const/cc_whitelist.py new file mode 100644 index 0000000000000000000000000000000000000000..e48ab6117969aa4d0b1aaf4885481143ae247382 --- /dev/null +++ b/src/const/cc_whitelist.py @@ -0,0 +1,3 @@ +REPETITION_WHITELIST = [ + 'january', 'february', 'march', 'april', 'may', 'june', 'july', 'august', 'september', 'october', 'november', 'december', 'januar', 'februar', 'märz', 'mai', 'juni', 'juli', 'oktober', 'dezember', 'total', 'iemba', 'emba', 'emba x', 'programme', 'program', +] diff --git a/src/const/data_consent_constants.py b/src/const/data_consent_constants.py new file mode 100644 index 0000000000000000000000000000000000000000..9dfd7b28473fb787ebad44693d2c7e575d339b22 --- /dev/null +++ b/src/const/data_consent_constants.py @@ -0,0 +1,60 @@ +PRIVACY_NOTICE = { + "de": """ +### Datenschutzhinweis + +Wir verwenden Ihre Angaben, um Sie zu **Executive MBA Programmen der Universität St.Gallen** zu beraten. +Dabei verarbeiten wir insbesondere: + +- Ihre Gesprächsinhalte und Anfragen +- Kontaktdaten (Name, E-Mail) bei Terminbuchung +- Informationen zu Ihrer Berufserfahrung und Ausbildung + +Ihre Daten werden **ausschließlich für die Studienberatung** verwendet und **nicht an Dritte weitergegeben**. +Sie können Ihre Einwilligung **jederzeit widerrufen**. + +[Weitere Informationen zur Datenschutzerklärung](https://www.unisg.ch/en/data-protection-declaration/) +""", + + "en": """ +### Privacy Notice + +We use your information to advise you on **Executive MBA programmes at the University of St.Gallen**. +We process in particular: + +- Your conversation content and inquiries +- Contact details (name, email) for appointment booking +- Information about your professional experience and education + +Your data is used **solely for study advisory purposes** and **is not shared with third parties**. +You may **withdraw your consent at any time**. + +[More information in the Privacy Policy](https://www.unisg.ch/en/data-protection-declaration/) +""" +} + +ACCEPT = { + "de": "Zustimmen", + "en": "Accept" +} + +DECLINE = { + "de": "Ablehnen", + "en": "Decline" +} + +DECLINE_MESSAGE = { + "de": "Ohne Ihre Einwilligung können wir Sie leider nicht beraten. " + "Bitte kontaktieren Sie uns direkt unter emba@unisg.ch.", + "en": "Without your consent, we cannot provide advice. " + "Please contact us directly at emba@unisg.ch.", +} + +WITHDRAW_CONFIRMATION_MESSAGE = { + "de": "Ihre Einwilligung wurde widerrufen. Ihre Session-Daten wurden gelöscht. Ohne Einwilligung können wir Sie leider nicht beraten.", + "en": "Your consent has been withdrawn. Your session data has been deleted. Without consent, we cannot continue advising you." +} + +WITHDRAW_TEXT = { + "de": "Einwilligung widerrufen", + "en": "Withdraw Consent" +} \ No newline at end of file diff --git a/src/const/page_blacklist.py b/src/const/page_blacklist.py new file mode 100644 index 0000000000000000000000000000000000000000..2a0f6b377f0b9c3421c18338d28381f0603fdb0d --- /dev/null +++ b/src/const/page_blacklist.py @@ -0,0 +1,5 @@ +PAGE_BLACKLIST = [ + 'cookie', 'cookies', 'privacy', 'datenschutz', 'popup', 'download', + 'cookie-policy', 'privacy-policy', 'cookie-and-privacy-policy', + 'data-protection', 'impressum', 'legal', 'terms', 'agb', 'imprint' +] \ No newline at end of file diff --git a/src/const/page_priority.py b/src/const/page_priority.py new file mode 100644 index 0000000000000000000000000000000000000000..8125ac65cf0bf61b7703eb940095952d2733f20b --- /dev/null +++ b/src/const/page_priority.py @@ -0,0 +1,137 @@ +PAGE_PRIORITY_KEYWORDS = { + 'high': [ + # -------------------------------------------- EN -------------------------------------------- + 'overview', 'about', 'introduction', 'summary', 'home', 'general information', 'welcome', + 'admissions', 'admission', 'apply', 'application', 'how to apply', 'enrollment', 'prospective students', 'entrance', + 'costs', 'tuition', 'fees', 'expenses', 'financial information', 'funding', 'scholarships', + 'curriculum', 'courses', 'program', 'programmes', 'degree structure', 'modules', 'syllabus', + 'eligibility', 'admission requirements', 'entry requirements', 'qualifications', 'prerequisites', 'criteria', + 'deadlines', 'application deadlines', 'key dates', 'timeline', 'due dates', 'important dates' + + # -------------------------------------------- DE -------------------------------------------- + 'übersicht', 'überblick', 'einführung', 'zusammenfassung', 'allgemeines', 'willkommen', + 'zulassung', 'zulassungen', 'bewerbung', 'bewerbungen', 'wie bewerben', 'einschreibung', 'potenzielle studenten', 'aufnahme', + 'kosten', 'studiengebühren', 'gebühren', 'ausgaben', 'finanzielle informationen', 'finanzierung', 'stipendien', + 'studienplan', 'lehrplan', 'curriculum', 'modulhandbuch', 'studiengangsstruktur', 'module', 'lehrstoff', + 'voraussetzungen', 'zulassungsvoraussetzungen', 'eintrittsvoraussetzungen', 'qualifikationen', 'vorkenntnisse', 'kriterien', + 'fristen', 'bewerbungsfristen', 'schlüsseltermine', 'zeitplan', 'fälligkeitsdaten', 'wichtige daten' + ], + 'medium': [ + # -------------------------------------------- EN -------------------------------------------- + 'faculty', 'faculties', 'staff', 'professors', 'departments', 'team', 'instructors', 'lecturers', + 'alumni', 'graduates', 'former students', 'success stories', 'alumnae' + + # -------------------------------------------- DE -------------------------------------------- + 'fakultät', 'fakultäten', 'personal', 'professoren', 'dozenten', 'abteilungen', 'team', 'lehrkräfte', + 'alumni', 'absolventen', 'ehemalige studenten', 'erfolgsgeschichten' + ], + 'low': [ + # -------------------------------------------- EN -------------------------------------------- + 'news', 'press', 'blog', 'updates', 'articles', 'announcements', + 'events', 'calendar', 'activities', 'conferences', 'workshops', 'seminars' + + # -------------------------------------------- DE -------------------------------------------- + 'nachrichten', 'presse', 'blog', 'aktualisierungen', 'artikel', 'ankündigungen', + 'veranstaltungen', 'kalender', 'aktivitäten', 'konferenzen', 'workshops', 'seminare' + ], +} + +CHUNK_TOPIC_KEYWORDS = { + 'admissions': { + # ----------------------- EN ----------------------- + 'admissions', 'application', 'apply', 'application process', + 'deadline', 'deadlines', 'selection', 'assessment', + 'interview', 'admissions committee', 'application form', + 'submit', 'submission', 'enrollment', + + # ----------------------- DE ----------------------- + 'zulassung', 'bewerbung', 'bewerben', + 'bewerbungsprozess', 'frist', 'fristen', + 'auswahlverfahren', 'aufnahmeverfahren', + 'assessment', 'interview', 'aufnahmegespräch', + 'zulassungskomitee', 'einschreibung', 'immatrikulation', + 'einreichen' + }, + + 'costs': { + # ----------------------- EN ----------------------- + 'tuition', 'tuition fee', 'fees', 'costs', 'expenses', + 'payment', 'payment plan', 'installment', 'installments', + 'deposit', 'price', 'total cost', + 'funding', 'financing', 'loan', 'loans', + 'scholarship', 'scholarships', 'budget', + + # ----------------------- DE ----------------------- + 'studiengebühren', 'gebühren', 'kosten', 'ausgaben', + 'zahlung', 'zahlungsplan', 'rate', 'raten', + 'anzahlung', 'preis', 'gesamtkosten', + 'finanzierung', 'kredit', 'kredite', + 'stipendium', 'stipendien', 'budget' + }, + + 'curriculum': { + # ----------------------- EN ----------------------- + 'curriculum', 'program', 'programme', 'content', + 'module', 'modules', 'course', 'courses', + 'structure', 'format', 'timeline', 'schedule', + 'duration', 'ects', 'credits', + 'training', 'coaching', 'workshop', 'workshops', + 'project', 'projects', 'leadership', 'development', + 'learning', 'electives', + + # ----------------------- DE ----------------------- + 'curriculum', 'programm', 'studium', 'inhalt', + 'modul', 'module', 'kurs', 'kurse', + 'struktur', 'format', 'zeitplan', 'ablauf', + 'dauer', 'ects', 'leistungspunkte', + 'training', 'coaching', 'workshop', 'workshops', + 'projekt', 'projekte', 'führung', 'entwicklung', + 'lernen', 'wahlfächer' + }, + + 'eligibility': { + # ----------------------- EN ----------------------- + 'eligibility', 'requirements', 'prerequisites', + 'admission requirements', 'criteria', + 'qualification', 'qualifications', + 'work experience', 'leadership experience', + 'degree', 'academic degree', + 'language requirement', 'fluency', + + # ----------------------- DE ----------------------- + 'voraussetzungen', 'zulassungsvoraussetzungen', + 'anforderungen', 'kriterien', + 'qualifikation', 'qualifikationen', + 'berufserfahrung', 'führungserfahrung', + 'abschluss', 'studienabschluss', + 'sprachkenntnisse', 'sprachvoraussetzungen' + }, + + 'alumni': { + # ----------------------- EN ----------------------- + 'alumni', 'alumni network', + 'graduates', 'community', + 'career service', 'mentoring', + + # ----------------------- DE ----------------------- + 'alumni', 'alumni-netzwerk', + 'absolventen', 'gemeinschaft', + 'karriereservice', 'mentoring' + }, + + 'general': { + # ----------------------- EN ----------------------- + 'overview', 'introduction', 'summary', + 'highlights', 'benefits', 'advantages', + 'experience', 'journey', + 'programme details', 'program details', + 'location', 'format', 'language', + + # ----------------------- DE ----------------------- + 'überblick', 'einführung', 'zusammenfassung', + 'highlights', 'vorteile', + 'erfahrung', 'reise', + 'programmdetails', 'standort', + 'format', 'sprche' + }, +} diff --git a/src/database/__init__.py b/src/database/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/database/docker-compose-cache.yml b/src/database/docker-compose-cache.yml new file mode 100644 index 0000000000000000000000000000000000000000..4f563367e152da68e26ac7ebd3ecb1a254328b82 --- /dev/null +++ b/src/database/docker-compose-cache.yml @@ -0,0 +1,27 @@ +version: '3.8' + +services: + redis: + image: redis:alpine + container_name: hsg_redis_cache + ports: + - "6379:6379" + command: > + redis-server + --requirepass "${REDIS_PASSWORD}" + --save 60 1 + --loglevel warning + --maxmemory 200mb + --maxmemory-policy allkeys-lru + volumes: + - redis_data:/data + restart: unless-stopped + + healthcheck: + test: ["CMD", "redis-cli", "-a", "${REDIS_PASSWORD}", "ping"] + interval: 5s + timeout: 3s + retries: 5 + +volumes: + redis_data: \ No newline at end of file diff --git a/src/database/docker-compose.yml b/src/database/docker-compose.yml new file mode 100644 index 0000000000000000000000000000000000000000..8404d389795ea515dc3bf4bf28fd5debb316e750 --- /dev/null +++ b/src/database/docker-compose.yml @@ -0,0 +1,29 @@ +version: '3.4' + +services: + weaviate: + image: semitechnologies/weaviate:1.33.0 + restart: on-failure:0 + ports: + - "8080:8080" + - "50051:50051" + environment: + QUERY_DEFAULTS_LIMIT: 25 + AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED: 'true' + PERSISTENCE_DATA_PATH: '/var/lib/weaviate' + ENABLE_API_BASED_MODULES: 'true' + ENABLE_MODULES: 'text2vec-transformers' + TRANSFORMERS_INFERENCE_API: 'http://t2v-transformers:8080' + CLUSTER_HOSTNAME: 'node1' + volumes: + - weaviate_data:/var/lib/weaviate + + t2v-transformers: + image: semitechnologies/transformers-inference:sentence-transformers-all-MiniLM-L6-v2 + restart: on-failure:0 + ports: + - "8081:8080" + +volumes: + weaviate_data: + diff --git a/src/database/redisservice.py b/src/database/redisservice.py new file mode 100644 index 0000000000000000000000000000000000000000..ca281921c354e85313c47aedc3130f9a8e89800d --- /dev/null +++ b/src/database/redisservice.py @@ -0,0 +1,53 @@ +import redis +from threading import Lock +from src.utils.logging import get_logger + +logger = get_logger("redis_service") + +class RedisService: + _instance = None + _init_lock = Lock() + + def __new__(cls, host, port, password, mode): + if cls._instance is None: + with cls._init_lock: + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self, host, port, password, mode): + if hasattr(self, '_initialized') and self._initialized: + return + + self._client = None + self._host = host + self._port = port + self._password = password + self.mode = mode + + self._connect() + + self._initialized = True + + def _connect(self): + try: + logger.info(f"Connecting to Redis at {self._host}:{self._port}...") + self._client = redis.Redis( + host=self._host, + port=self._port, + password=self._password, + decode_responses=True, + socket_connect_timeout=2, + socket_timeout=2 + ) + self._client.ping() + logger.info(f"Successfully connected to Redis! {self.mode}") + except Exception as e: + logger.error(f"Redis connection failed: {e}") + self._client = None + + def get_client(self): + return self._client + + def is_connected(self) -> bool: + return self._client is not None diff --git a/src/database/weavservice.py b/src/database/weavservice.py new file mode 100644 index 0000000000000000000000000000000000000000..1a6c1367819bf52b64fc79270baebd9c7cc3e042 --- /dev/null +++ b/src/database/weavservice.py @@ -0,0 +1,851 @@ +from functools import reduce +import weaviate as wvt +import datetime, os +from threading import Lock + +from time import perf_counter, sleep +from weaviate.classes.config import Configure, Property, DataType +from weaviate.collections.classes.grpc import MetadataQuery +from weaviate.collections.collection import Collection +from weaviate.classes.init import AdditionalConfig, Timeout +from weaviate.classes.query import Filter +from weaviate.config import AdditionalConfig + +from ..utils.logging import get_logger +from ..config import config + +logger = get_logger("weaviate_service") + +_get_collection_name = lambda lang: f'{config.weaviate.WEAVIATE_COLLECTION_BASENAME}_{lang}' +_collection_names = [_get_collection_name(lang) for lang in config.get('AVAILABLE_LANGUAGES')] + + +def _default_properties() -> list[Property]: + return [ + Property(name='body', data_type=DataType.TEXT), + Property(name='chunk_id', data_type=DataType.TEXT), + Property(name='document_id', data_type=DataType.TEXT), + Property(name='programs', data_type=DataType.TEXT_ARRAY), + Property(name='source', data_type=DataType.TEXT), + Property(name='date', data_type=DataType.DATE), + ] + + +class WeaviateService: + """ + Provides an interface for interacting with the Weaviate vector database. + Handles initialization, data import, and hybrid queries. + """ + + _instance = None + _init_lock = Lock() + + def __new__(cls): + if cls._instance is None: + with cls._init_lock: + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self) -> None: + """ + Initialize the Weaviate service. + """ + if hasattr(self, '_initialized'): + return + + self._connection_type = 'local' if config.weaviate.LOCAL_DATABASE else 'cloud' + self._client = None + self._client_lock = Lock() + + # Some parameters to ensure that the connection will not be closed + # during long pauses in conversations + self._last_query_time = perf_counter() + self._idle_timeout = 25 * 60 + self._initialized = True + + # Initialize the client for the first time + logger.info("Initializing Weaviate service...") + try: + self._init_client() + logger.info("Weaviate service initialized successfully") + except Exception as e: + logger.error(f"Failed to initialize Weaviate service: {e}") + raise e + + + def _init_client(self) -> wvt.WeaviateClient: + """ + Initializes the weaviate client with additional configuration. + Performs a warm-up querying to speed-up the subsequent calls. + + Returns: + configured Weaviate client instance on successfull connection. + + Raises: + WeaviateConnectionError of the last failed connection if connection fails after 3 retires. + """ + # Returns the client if it hasn't been idling for too long + if self._client is not None: + time_since_query = perf_counter() - self._last_query_time + if time_since_query < self._idle_timeout: + return self._client + + # The connection might be closed, clients has to be reconnected + logger.warning(f"Client has been idling for too long. Reconnecting to prevent server-side closure...") + try: + self._client.close() + except Exception as _: + pass + + self._client = None + + # Client initialization + with self._client_lock: + if self._client: + return self._client + + retries = 0 + last_exception: Exception = None + while retries < 3: + try: + if config.weaviate.LOCAL_DATABASE: + self._client = wvt.connect_to_local() + break + + self._client = wvt.connect_to_weaviate_cloud( + cluster_url=config.weaviate.CLUSTER_URL, + auth_credentials=config.weaviate.WEAVIATE_API_KEY, + additional_config=AdditionalConfig( + timeout=Timeout( + init=config.weaviate.INIT_TIMEOUT, + query=config.weaviate.QUERY_TIMEOUT, + insert=config.weaviate.INSERT_TIMEOUT, + ), + skip_init_checks=False, + ), + headers={ + "X-HuggingFace-Api-Key": config.weaviate.HUGGING_FACE_API_KEY, + }, + ) + + # Warm-up query + logger.info("Running warm-up query to initialize server...") + try: + collection = _get_collection_name(config.get('AVAILABLE_LANGUAGES')[0]) + self._client.collections.exists(collection) + logger.info("Warm-up finished - server is ready!") + except Exception as warmup_err: + logger.warning(f"Warm-up query failed (non-critical): {warmup_err}") + + break + except Exception as e: + last_exception = e + logger.warning(f"Failed to establish connection on try {retries}: {e}") + retries += 1 + sleep(1) + + if retries == 3: + logger.error(f"Failed to establish connection after 3 retries!") + raise last_exception + + logger.info(f"Successully connected to the {self._connection_type} weaviate database") + self._last_query_time = perf_counter() + return self._client + + + def _select_collection(self, lang: str) -> tuple[Collection, str]: + """ + Select a language-specific collection as the active working collection. + + Args: + lang (str): Acceptable language code. + + Raises: + weaviate.exceptions.WeaviateConnectionError: If the specified language collection does not exist. + """ + if lang not in config.get('AVAILABLE_LANGUAGES'): + logger.error(f"No collection for language '{lang}' was found in the database") + return None, '' + + collection_name = _get_collection_name(lang) + logger.debug(f"Using collection {collection_name}") + + client = self._init_client() + return client.collections.use(collection_name), collection_name + + + def batch_import(self, data_rows: list, lang: str) -> list: + """ + Perform a batch import of multiple objects into the current collection. + + Args: + data_rows (list): List of dictionaries representing the data rows to import. + lang (str, optional): Language collection to use. If not provided, uses the current one. + + Returns: + list[dict]: List of failed imports with error details, if any. + + Raises: + If no active collection is available or a connection error was catched. + """ + collection, collection_name = self._select_collection(lang) + if collection is None: + logger.error("No working collection selected!") + return [] + + import_errors = [] + logger.info(f"Batch importing {len(data_rows)} rows into {collection_name}") + + try: + with self._client_lock: + with collection.batch.fixed_size(batch_size=100, concurrent_requests=2) as batch: + for idx, data_row in enumerate(data_rows): + try: + batch.add_object(properties=data_row) + except Exception as e: + import_errors.append({'index': idx, 'chunk_id': data_row['chunk_id'], 'error': str(e)}) + + if idx % 20 == 0 and idx > 0: + if batch.number_errors > 0: + logger.info(f"Failed imports at index {idx}: {batch.number_errors}") + + self._last_query_time = perf_counter() + logger.info(f"Batch import finished. Total errors: {len(import_errors)}") + + except Exception as e: + if 'connection' in str(e).lower(): + logger.error(f"Connection error during batch import: {e}") + self._client = None + raise e + + return import_errors + + + @staticmethod + def _create_property_filter(prop, values) -> Filter: + match prop: + case 'programs': + return Filter.by_property('programs').contains_any(values) + case 'source': + return Filter.by_property('source').contains_any(values) \ + if isinstance(values, list) else Filter.by_property('source').equal(values) + case _: + return None + + + def delete_chunks(self, lang: str, property_filters: dict[str, any] = None) -> int: + """ + Delete all chunks from the specified collection that match given property filters. + + Args: + lang (str): Language collection to use. + property_filters (dict[str, any]): Key-value pairs for filtering. + + Returns: + int: Number of deleted objects (if available, else -1). + """ + retry_count = 0 + max_retries = 2 + + filters = [self._create_property_filter(prop, values) + for prop, values in property_filters.items()] if property_filters else None + if filters: + filters = [f for f in filters if f is not None] + filters = reduce(lambda f1, f2: f1 & f2, filters) if filters else None + + while retry_count < max_retries: + try: + collection, collection_name = self._select_collection(lang) + if collection is None: + logger.error("No working collection selected!") + return 0 + + logger.info(f"Deleting chunks from {collection_name} with filters={property_filters}") + + with self._client_lock: + result = collection.data.delete_many( + where=filters + ) + + self._last_query_time = perf_counter() + + deleted = getattr(result, "objects_deleted", None) + if deleted is None: + logger.info("Deletion executed (count not returned by client)") + return -1 + + logger.info(f"Deleted {deleted} objects") + return deleted + + except Exception as e: + if any(err_type in str(e).lower() for err_type in ['reset', 'closed', 'grpc', 'unavailable']): + retry_count += 1 + logger.warning(f"Connection error during deletion: {e}. Retrying...") + if retry_count == max_retries: + raise e + else: + raise e + + + def ping(self, lang: str) -> dict: + try: + collection, _ = self._select_collection(lang) + with self._client_lock: + collection.query.hybrid("health check query") + return { 'status': 'OK' } + except Exception as e: + return { 'status': 'ERROR', 'error': e } + + + def query(self, query: str, lang: str, property_filters: dict[str] = None, limit: int = 5) -> dict: + """ + Execute a hybrid semantic and keyword query against the active collection with automatic reconnection on idle timeout. + + Args: + query (str): The query string. + lang (str, optional): Language collection to use. If not provided, uses the current one. + property_filters (dict[str, any]): Key-value pairs for metadata filtering. Keys correspond + to document properties (e.g., 'program', 'topic'), and values are the required matches. + Multiple filters are combined using logical AND. + limit (int, optional): Maximum number of results to return. Defaults to 5. + + + Returns: + tuple: A tuple containing the query response and elapsed time. + + Raises: + weaviate.exceptions.WeaviateConnectionError: If no active collection is available. + """ + retry_count = 0 + max_retries = 2 + + filters = [self._create_property_filter(prop, values) + for prop, values in property_filters.items()] if property_filters else None + if filters: + filters = [f for f in filters if f is not None] + filters = reduce(lambda f1, f2: f1 & f2, filters) if filters else None + + while retry_count < max_retries: + try: + collection, collection_name = self._select_collection(lang) + if collection is None: + logger.error("No working collection selected upon starting of the querying!") + return [], 0 + + logger.info(f"Querying collection {collection_name}") + query_start_time = perf_counter() + + with self._client_lock: + resp = collection.query.hybrid( + query=query, + filters=filters, + limit=limit, + return_metadata=MetadataQuery.full() + ) + elapsed = perf_counter() - query_start_time + self._last_query_time = perf_counter() + logger.info(f"Querying retrieved {len(resp.objects)} objects in {elapsed:3.2f} seconds") + + return (resp, elapsed) + except Exception as e: + if any(err_type in str(e).lower() for err_type in ['reset', 'closed', 'grpc', 'unavailable']): + retry_count += 1 + logger.warning(f"Connection error detected: {e}. Retrying...") + + if retry_count == max_retries: + raise e + else: # Probably not a server issue + raise e + + + def _load_properties(self) -> list[Property]: + properties = {} + properties_file = os.path.join(config.weaviate.PROPERTIES_PATH, 'properties.yaml') + if not os.path.exists(properties_file): + logger.warning( + f"Optional file 'properties.yaml' is missing on path: {properties_file}. " + "Falling back to built-in default properties." + ) + return _default_properties() + + try: + import yaml + + with open(properties_file, 'r') as stream: + properties = yaml.safe_load(stream) + except ModuleNotFoundError: + logger.warning( + "PyYAML is not installed. Falling back to built-in default properties " + "for Weaviate collection creation." + ) + return _default_properties() + except Exception as e: + logger.error(f"Failed to load properties from path {properties_file}: {e}") + raise e + + if not properties: + logger.warning("properties.yaml is empty. Falling back to built-in default properties.") + return _default_properties() + + final_properties = [] + for name, params in properties.items(): + try: + data_type = params.get('data_type', '') + dtype = DataType(data_type) + except Exception as e: + logger.error(f"Nonexistent datatype {data_type}") + raise e + + final_properties.append(Property( + name=name, + data_type=dtype, + index_filterable=params.get('filterable', True), + index_searchable=params.get('searchable', True), + skip_vectorization=params.get('skip_vectorization', False), + )) + + return final_properties + + + def _create_collections(self): + """ + Create and initialize language-specific collections. + + Creates collections for all available languages with vector configuration. + """ + properties = self._load_properties() + try: + client = self._init_client() + logger.info('Attempting collections creation...') + + vector_config = ( + Configure.Vectors.text2vec_transformers() if config.weaviate.LOCAL_DATABASE + else Configure.Vectors.text2vec_huggingface( + name='hsg_rag_embeddings', + source_properties=['body'], + model=config.processing.EMBEDDING_MODEL, + ) + ) + + successful_creations = 0 + + with self._client_lock: + for collection_name in _collection_names: + try: + client.collections.create( + name=collection_name, + properties=properties, + vector_config=vector_config + ) + logger.info(f"Created collection {collection_name}") + successful_creations += 1 + except Exception as e: + logger.error(f"Failed to create collection '{collection_name}': {e}") + + self._last_query_time = perf_counter() + + if successful_creations == len(_collection_names): + logger.info('All collections successfully instantiated') + else: + logger.warning(f"Only {successful_creations}/{len(_collection_names)} collections created") + + except Exception as e: + logger.error(f"Collections creation failed: {e}") + self._client = None + raise e + + + def _delete_collections(self): + """ + Delete all existing collections from the database. + + Also removes the hash file if it exists. + """ + try: + client = self._init_client() + logger.info("Initiating deletion of stored collections...") + + deleted_count = 0 + with self._client_lock: + for collection_name in _collection_names: + try: + if client.collections.exists(collection_name): + client.collections.delete(collection_name) + logger.info(f"Deleted collection {collection_name}") + deleted_count += 1 + else: + logger.warning(f"Collection {collection_name} does not exist") + except Exception as e: + logger.error(f"Failed to delete collection {collection_name}: {e}") + + self._last_query_time = perf_counter() + logger.info(f"Deleted {deleted_count}/{len(_collection_names)} collections") + + except Exception as e: + logger.error(f"Collections deletion failed: {e}") + self._client = None + raise e + + + def _reset_collections(self): + self._delete_collections() + self._create_collections() + + + def _collect_chunk_ids(self) -> dict: + client = self._init_client() + try: + ids = [] + with self._client_lock: + for c in client.collections.list_all(simple=False): + coll = client.collections.get(c) + for obj in coll.iterator(): + ids.append(obj.properties['chunk_id']) + return ids + except Exception as e: + logger.error(f"Failed to collect chunk ids: {e}") + raise e + + + def _extract_data(self) -> dict: + client = self._init_client() + try: + schema = [] + objects = {} + with self._client_lock: + for c in client.collections.list_all(simple=False): + coll = client.collections.get(c) + cfg = coll.config.get().to_dict() + schema.append(cfg) + + objects[c] = [] + for obj in coll.iterator(include_vector=True): + objects[c].append({ + "uuid": obj.uuid, + "properties": obj.properties, + "vector": obj.vector, + }) + + return { + 'schema': schema, + 'objects': objects, + } + except Exception as e: + logger.error(f"Failed to extract data from database: {e}") + raise e + + + def _create_backup(self) -> str: + """ + Create a backup of the current database state and stores it under selected backup provider. + + Returns: backup id of the created backup. + """ + try: + if not config.weaviate.BACKUP_METHOD: + raise ValueError('Backup method is not selected!') + if config.weaviate.BACKUP_METHOD not in config.weaviate.BACKUP_METHODS: + raise ValueError(f"Selected backup method 'config.weaviate.BACKUP_METHODS' is not supported!") + if not config.weaviate.BACKUP_PATH: + raise ValueError("Backup directory is not set!") + os.makedirs(config.weaviate.BACKUP_PATH, exist_ok=True) + + backup_id = f"backup_{datetime.datetime.now().strftime('%Y%m%d%H%M%S%f')}" + logger.info(f"Initiating backup creation for {self._connection_type} database...") + + match config.weaviate.BACKUP_METHOD: + case 'manual': + import json + + backup_path = os.path.join(config.weaviate.BACKUP_PATH, backup_id) + os.makedirs(backup_path) + + db_data = self._extract_data() + data_backup = { + 'creation_date': datetime.datetime.now().isoformat(), + } + + schema_backup_path = os.path.join(backup_path, 'schema.json') + with open(schema_backup_path, 'w', encoding='utf-8') as f: + json.dump(db_data['schema'], f, indent=2, default=str) + + objects_backup_path = os.path.join(backup_path, 'objects.json') + with open(objects_backup_path, 'w', encoding='utf-8') as f: + json.dump(db_data['objects'], f, indent=2, default=str) + + data_backup_path = os.path.join(backup_path, 'data.json') + with open(data_backup_path, 'w', encoding='utf-8') as f: + json.dump(data_backup, f, indent=2, default=str) + + case 's3': + client = self._init_client() + with self._client_lock: + client.backup.create( + backup_id=backup_id, + backend="s3", + include_collections=_collection_names, + wait_for_completion=True, + ) + case _: + raise NotImplementedError() + + + self._last_query_time = perf_counter() + logger.info(f"Backup '{backup_id}' created successfully") + + return backup_id + except Exception as e: + logger.error(f"Backup creation failed: {e}") + raise e + + + def _restore_backup(self, backup_id: str): + """ + Restore the database state from a backup. + + Restores specified collections from backup. + + Args: + backup_id: ID of the backup to restore from + + Raises: + Exception if backup restoration fails + """ + self._delete_collections() + + try: + if not config.weaviate.BACKUP_METHOD: + raise ValueError('Backup method is not selected!') + if config.weaviate.BACKUP_METHOD not in config.weaviate.BACKUP_METHODS: + raise ValueError(f"Selected backup method 'config.weaviate.BACKUP_METHODS' is not supported!") + if not config.weaviate.BACKUP_PATH: + raise ValueError("Backup directory is not set!") + os.makedirs(config.weaviate.BACKUP_PATH, exist_ok=True) + + backup_path = os.path.join(config.weaviate.BACKUP_PATH, backup_id) + if not os.path.exists(backup_path): + raise RuntimeError(f"Directory for backup 'backup_id' does not exist in the backup directory!") + schema_backup_path = os.path.join(backup_path, 'schema.json') + if not os.path.exists(schema_backup_path): + raise RuntimeError(f"Schema backup is missing in the backup directory!") + objects_backup_path = os.path.join(backup_path, 'objects.json') + if not os.path.exists(objects_backup_path): + raise RuntimeError(f"Objects backup is missing in the backup directory!") + + client = self._init_client() + logger.info(f"Initiating restoration from backup '{backup_id}' for {self._connection_type} database...") + + with self._client_lock: + match config.weaviate.BACKUP_METHOD: + case 'manual': + import json + + with open(schema_backup_path) as f: + schemas = json.load(f) + for cfg in schemas: + client.collections.create_from_dict(cfg) + + with open(objects_backup_path) as f: + data = json.load(f) + for name, objs in data.items(): + logger.info(f"Restoring collection '{name}' with {len(objs)} objects...") + coll = client.collections.get(name) + + with coll.batch.dynamic() as batch: + for o in objs: + o['properties']['date'] = o['properties']['date'] \ + .replace(" ", "T").replace("+00:00", "Z") + batch.add_object( + uuid=o["uuid"], + properties=o["properties"], + vector=o["vector"] + ) + logger.info(f"Collection '{name}' restored successfully") + case 's3': + client.backup.restore( + backup_id=backup_id, + backend="s3", + wait_for_completion=True, + roles_restore="all", + users_restore="all", + ) + case _: + raise NotImplementedError() + + self._last_query_time = perf_counter() + logger.info(f"Backup '{backup_id}' restored successfully") + + except Exception as e: + error_msg = str(e).lower() + if 'connection' in error_msg: + logger.error(f"Connection error during backup restore: {e}. Will reconnect on next operation.") + self._client = None + logger.error(f"Backup restoration failed: {e}") + raise e + + + def _checkhealth(self) -> bool: + """ + Check the connectivity and health status of the Weaviate database. + + Verifies: + - Connection to the database + - Database metadata and version + - Existence of all expected collections + - Module availability + + Returns: + True if all health checks pass, False otherwise + """ + try: + client = self._init_client() + + # Check basic connectivity + is_connected = False + with self._client_lock: + is_connected = client.is_connected() + + connection_status = "✓ OK" if is_connected else "✗ ERROR" + logger.info(f"Connection to {self._connection_type} database: {connection_status}") + + if not is_connected: + logger.error("Database connection check failed") + return False + + # Get and log metadata + try: + with self._client_lock: + metainfo = client.get_meta() + + # Format module information + modules = metainfo.get('modules', {}) + modules_list = list(modules.keys()) if isinstance(modules, dict) else modules + modules_str = ', '.join(str(m) for m in modules_list) if modules_list else 'None' + + # Truncate long module strings for logging + if len(modules_str) > 50: + modules_str = modules_str[:47] + '...' + + # Log connection details + if config.weaviate.LOCAL_DATABASE: + logger.info( + f"Database metadata: " + f"HOSTNAME={metainfo.get('hostname', 'unknown')}, " + f"VERSION={metainfo.get('version', 'unknown')}, " + f"MODULES={modules_str}" + ) + else: + logger.info( + f"Database metadata: " + f"VERSION={metainfo.get('version', 'unknown')}, " + f"MODULES={modules_str}" + ) + + except Exception as e: + logger.warning(f"Could not retrieve database metadata: {e}") + + # Check collection existence + all_collections_exist = True + + with self._client_lock: + for collection_name in _collection_names: + try: + exists = client.collections.exists(collection_name) + status = "✓ OK" if exists else "✗ MISSING" + logger.info(f"Collection '{collection_name}': {status}") + + if not exists: + all_collections_exist = False + + except Exception as e: + logger.error(f"Error checking collection '{collection_name}': {e}") + all_collections_exist = False + + # Update last health check time + self._last_query_time = perf_counter() + + # Log overall health status + if is_connected and all_collections_exist: + logger.info("✓ Database health check PASSED - All systems operational") + return True + else: + logger.warning("✗ Database health check FAILED - Some issues detected") + return False + + except Exception as e: + error_msg = str(e).lower() + if 'connection' in error_msg: + logger.error(f"Connection error during health check: {e}. Will reconnect on next operation.") + self._client = None + logger.error(f"Health check failed: {e}") + return False + + +def parse_arguments(): + """ + Parse command-line arguments for managing Weaviate collections. + + Returns: + argparse.Namespace: Parsed command-line arguments. + """ + import argparse + + parser = argparse.ArgumentParser( + description='Weaviate database management utility' + ) + group = parser.add_mutually_exclusive_group() + + group.add_argument( + '-dc', "--delete_collections", + action='store_true', + help='Delete all collections from the database' + ) + group.add_argument( + '-cc', "--create_collections", + action='store_true', + help='Initialize collections for different language contents' + ) + group.add_argument( + '-rc', "--redo_collections", + action='store_true', + help='Delete and recreate all collections' + ) + group.add_argument( + '-ch', "--checkhealth", + action='store_true', + help='Check database connection and collection existence' + ) + group.add_argument( + '-cb', "--create_backup", + action='store_true', + help='Create a backup of the current database state' + ) + group.add_argument( + '-rb', "--restore_backup", + type=str, + metavar='BACKUP_ID', + help='Restore database from a backup (provide backup_id)' + ) + + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_arguments() + service = WeaviateService() + + if args.create_backup: + service._create_backup() + + if args.restore_backup: + service._restore_backup(args.restore_backup) + + if any([args.delete_collections, args.redo_collections]): + service._delete_collections() + + if any([args.create_collections, args.redo_collections]): + service._create_collections() + + if any([args.checkhealth, args.create_collections, args.redo_collections]): + service._checkhealth() diff --git a/src/notification/__init__.py b/src/notification/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/notification/notification_center.py b/src/notification/notification_center.py new file mode 100644 index 0000000000000000000000000000000000000000..7546097628996efe1620f8e025c487341c41ee17 --- /dev/null +++ b/src/notification/notification_center.py @@ -0,0 +1,148 @@ +from typing import Literal +import mimetypes +import os +import smtplib +from email.message import EmailMessage + +import requests + +from ..config import NotificationCenterConfig as NC + + +Channel = Literal["email", "slack"] + + +class EmailNotifier: + def __init__(self): + self.enabled = NC.ENABLE_EMAIL_ALERTS + self.smtp_host = NC.SMTP_HOST + self.smtp_port = NC.SMTP_PORT + self.smtp_user = NC.SMTP_USER + self.smtp_password = NC.SMTP_PASSWORD + self.smtp_use_tls = NC.SMTP_USE_TLS + self.from_email = NC.FROM_EMAIL + self.to_emails = self._parse_recipients(NC.TO_EMAIL) + + if self.enabled: + self._validate() + + @staticmethod + def _parse_recipients(value: str | None) -> list[str]: + if not value: + return [] + return [email.strip() for email in value.split(",") if email.strip()] + + def _validate(self) -> None: + missing = [] + + if not self.smtp_host: + missing.append("NOTIFY_SMTP_HOST") + if not self.smtp_user: + missing.append("NOTIFY_SMTP_USER") + if not self.smtp_password: + missing.append("NOTIFY_SMTP_PASSWORD") + if not self.from_email: + missing.append("NOTIFY_FROM_EMAIL") + if not self.to_emails: + missing.append("NOTIFY_TO_EMAIL") + + if missing: + raise ValueError(f"Missing notification email config: {', '.join(missing)}") + + def send( + self, + subject: str, + body: str, + attachments: str | list[str] | None = None, + ) -> None: + if not self.enabled: + return + + if isinstance(attachments, str): + attachments = [attachments] + + msg = EmailMessage() + msg["Subject"] = subject + msg["From"] = self.from_email + msg["To"] = ", ".join(self.to_emails) + msg.set_content(body) + + if attachments: + for file_path in attachments: + if not file_path or not os.path.isfile(file_path): + continue + + mime_type, _ = mimetypes.guess_type(file_path) + mime_type = mime_type or "application/octet-stream" + maintype, subtype = mime_type.split("/", 1) + + with open(file_path, "rb") as f: + msg.add_attachment( + f.read(), + maintype=maintype, + subtype=subtype, + filename=os.path.basename(file_path), + ) + + with smtplib.SMTP(self.smtp_host, self.smtp_port, timeout=20) as server: + if self.smtp_use_tls: + server.starttls() + server.login(self.smtp_user, self.smtp_password) + server.send_message(msg) + + +class SlackNotifier: + def __init__(self): + self.enabled = NC.ENABLE_SLACK_ALERTS + self.webhook_url = NC.SLACK_WEBHOOK_URL + + if self.enabled: + self._validate() + + def _validate(self) -> None: + if not self.webhook_url: + raise ValueError("Missing notification slack config: NOTIFY_SLACK_WEBHOOK_URL") + + def send(self, subject: str, body: str) -> None: + if not self.enabled: + return + + text = f"*{subject}*\n{body}" + + response = requests.post( + self.webhook_url, + json={"text": text}, + timeout=10, + ) + + response.raise_for_status() + + if response.status_code != 200: + raise RuntimeError( + f"Slack notification failed: {response.status_code} {response.text}" + ) + + +class NotificationCenter: + def __init__(self): + self.email = EmailNotifier() + self.slack = SlackNotifier() + + def send_notification( + self, + subject: str, + body: str, + channel: Channel = "email", + attachments: str | list[str] | None = None, + ) -> None: + + match channel: + case "all": + self.email.send(subject, body, attachments) + self.slack.send(subject, body) + case "email": + self.email.send(subject, body, attachments) + case "slack": + self.slack.send(subject, body) + case _: + raise ValueError(f"Unknown notification channel: {channel}") \ No newline at end of file diff --git a/src/pipeline/__init__.py b/src/pipeline/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/pipeline/pipeline.py b/src/pipeline/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..2ce644122d0edb42b947f83024dbfba152482e7b --- /dev/null +++ b/src/pipeline/pipeline.py @@ -0,0 +1,212 @@ +from .utils import * +from .processors import * +from ..scraping.scraper import Scraper + +from ..database.weavservice import WeaviateService +from ..utils.logging import get_logger +from ..config import config + +pipelogger = get_logger("pipeline_module") +implogger = get_logger("import_pipeline") + + +class ImportPipeline: + """ + Main pipeline class responsible for importing website and local documents + into the database with deduplication and language-based organization. + """ + + def __init__( + self, + logging_callback = None, + deduplication_callback = None, + ) -> None: + """ + Initialize the import pipeline with optional callbacks for logging and deduplication. + + This sets up the processors for websites and documents and recieves existing chunk IDs + from the database for deduplication purposes. + + Args: + logging_callback (callable, optional): A callback function for logging progress. + Defaults to a placeholder if not provided. + deduplication_callback (callable, optional): A callback function for handling + deduplication decisions. Defaults to a placeholder if not provided. + """ + self._logging_callback = logging_callback or logging_callback_placeholder + self._deduplication_callback = deduplication_callback or deduplication_callback_placeholder + self._docprocessor = DocumentProcessor() + self._service = WeaviateService() + self._ids = self._service._collect_chunk_ids() + + implogger.info('Import pipeline initialization finished!') + + + def import_from_scraper(self, scraper_chunks: dict[str, dict]) -> None: + for lang, chunks in scraper_chunks.items(): + if not chunks: continue + + sources = list(set([chunk.get('source', '') for chunk in chunks])) + self._service.delete_chunks(lang, property_filters={'source': sources}) + self._service.batch_import(data_rows=chunks, lang=lang) + + + def scrape_website(self, target_urls: list[str] | None = None, scrape_all: bool = False) -> None: + target_urls = [url for url in (target_urls or config.scraping.TARGET_URLS or []) if url] + if not target_urls: + implogger.warning("No target URLs configured for scraping.") + return + + scraper = Scraper(scrape_all=scrape_all) + for target_url in target_urls: + self._logging_callback(f"Scraping target {target_url}...", 0) + scraped_chunks = scraper.scrape_target(target_url) + if not scraped_chunks: + self._logging_callback(f"No importable chunks scraped from {target_url}.", 100) + continue + + self._logging_callback(f"Importing scraped chunks from {target_url}...", 90) + self.import_from_scraper(scraped_chunks) + self._logging_callback(f"Finished scraping import for {target_url}.", 100) + + + def import_many_documents(self, sources: list[str]) -> None: + self.import_all(paths=sources) + + + def _import_urls_via_scraper(self, urls: list[str], scrape_all: bool = True) -> None: + urls = [url for url in (urls or []) if url] + if not urls: + return + + scraper = Scraper(scrape_all=scrape_all) + for url in urls: + self._logging_callback(f"Scraping URL {url}...", 0) + scraped_chunks = scraper.scrape_target(url) + if not scraped_chunks: + self._logging_callback(f"Failed to scrape URL {url}!", 100, failed=True) + continue + + self._logging_callback(f"Importing scraped chunks from {url}...", 90) + self.import_from_scraper(scraped_chunks) + self._logging_callback(f"Stored scraped chunks for {url}.", 100) + + + def import_all( + self, + paths: list[str] = None, + urls: list[str] = None, + reset_collections: bool = False, + ) -> None: + """ + Import documents from local paths and/or URLs into the database. + + Processes the provided paths and URLs using the appropriate processors, + combines chunks by language, optionally resets database collections, + and performs batch imports. + + Args: + paths (list[str], optional): List of local file paths to process. Defaults to None. + urls (list[str], optional): List of website URLs to process. Defaults to None. + reset_collections (bool, optional): If True, reset the database collections before importing. + Defaults to False. + """ + chunks = self._pipeline(paths, self._docprocessor, reset_collections) + + if reset_collections: + self._logging_callback('Resetting database collections...', 60) + self._service._reset_collections() + + self._logging_callback('Importing document chunks to database...', 90) + for lang, ch in chunks.items(): + self._service.batch_import(data_rows=ch, lang=lang) + + self._import_urls_via_scraper(urls, scrape_all=True) + + self._logging_callback( + f'Successfully imported {sum([len(ch) for ch in chunks.values()])} document chunks!', + 100 + ) + + + def _pipeline( + self, + sources: list[str], + processor: ProcessorBase, + reset_collections: bool, + ) -> dict: + """ + Internal pipeline to process a list of sources using a given processor. + + Handles processing, deduplication (if not resetting), and organizes unique chunks by language. + If no new unique data is found, logs a warning and returns empty chunks. + + Args: + sources (list[str]): List of sources (paths or URLs) to process. + processor (ProcessorBase): The processor instance to use for handling sources. + reset_collections (bool): If True, skip deduplication. + + Returns: + dict: A dictionary mapping languages to lists of unique chunk dictionaries. + """ + unique_chunks = {lang: [] for lang in config.get('AVAILABLE_LANGUAGES')} + + sources = [s for s in (sources or []) if s != ""] + if not sources: + return unique_chunks + + for source in sources: + self._logging_callback(f'Starting pipeline for {source}...', 0) + result = processor.process(source) + + if not result.chunks: + implogger.error(f"Failed to process {source}!") + self._logging_callback(f"Failed to process {source}!", 100, result, failed=True) + continue + + if not reset_collections: + self._deduplicate(result) + + self._logging_callback(f'Storing chunks for {source}...', 100, result) + unique_chunks[result.lang].extend(result.chunks) + + if all([len(chunks) == 0 for chunks in unique_chunks.values()]): + self._logging_callback('No new data could be extracted from these sources!', 100) + implogger.warning(f"File(s) provided for the insertion do not contain any unique information.") + + return unique_chunks + + + def _deduplicate(self, result: ProcessingResult) -> ProcessingResult: + """ + Remove duplicate chunks based on chunks that are already stored in the database. + + If all chunks are duplicates, invokes the deduplication callback to decide whether + to delete existing duplicates and reimport. Otherwise, returns only unique chunks. + + Args: + result (ProcessingResult): The processing result containing document chunks. + + Returns: + list[dict]: List of unique chunk dictionaries (or all if reimporting duplicates). + """ + self._logging_callback('Performing deduplication...', 80) + unique_chunks = [] + duplicate_ids = [] + for chunk in result.chunks: + chunk_id = chunk['chunk_id'] + if chunk_id in self._ids: + duplicate_ids.append(chunk_id) + else: + unique_chunks.append(chunk) + + implogger.info(f"Found {len(duplicate_ids)} already existing IDs in {len(result.chunks)} collected chunks") + if duplicate_ids: + implogger.info(f"Duplicates found! Calling deduplication callback...") + if self._deduplication_callback(result.source, len(duplicate_ids)): + implogger.info('Duplicated chunks will be reimported as new...') + self._service._delete_by_id(duplicate_ids) + return result + + result.chunks = unique_chunks + return result diff --git a/src/pipeline/processors.py b/src/pipeline/processors.py new file mode 100644 index 0000000000000000000000000000000000000000..6567c3f31b4814f275ba7064d915a9a1ecd53ae2 --- /dev/null +++ b/src/pipeline/processors.py @@ -0,0 +1,303 @@ +from collections import defaultdict +import os, re + +from pathlib import Path +from transformers import AutoTokenizer + +from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer +from docling.datamodel.pipeline_options import PdfPipelineOptions, LayoutOptions +from docling_core.transforms.serializer.markdown import MarkdownDocSerializer +from docling.document_converter import DocumentConverter, PdfFormatOption, InputFormat +from docling.chunking import HybridChunker +from docling_core.types.doc.document import DoclingDocument, TableItem + +from .utils import * + +from ..utils.lang import detect_language +from ..utils.logging import get_logger +from ..config import config + +weblogger = get_logger("website_processor") +datalogger = get_logger("data_processor") + +class ProcessorBase: + def __init__(self) -> None: + """ + Initialize the base processor with document conversion and chunking tools. + + Sets up the PDF pipeline options, document converter, tokenizer, and chunker. + Loads strategies for chunk preparation. + + Args: + logging_callback (callable): A callback function for logging progress. + """ + pipeline_options = PdfPipelineOptions( + do_ocr = False, + generate_page_images = False, + + do_layout_analysis = True, + do_table_structure = True, + do_cell_matching = True, + + layout_options=LayoutOptions( + create_orphan_clusters = True, + keep_empty_clusters = False, + skip_cell_assignment = False, + ), + ) + self._converter: DocumentConverter = DocumentConverter( + format_options={ + InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options), + }, + ) + tokenizer = AutoTokenizer.from_pretrained(config.processing.EMBEDDING_MODEL) + self._chunker = HybridChunker( + tokenizer=HuggingFaceTokenizer( + tokenizer=tokenizer, + max_tokens=config.processing.MAX_TOKENS + ), + serializer_provider=EnhansedSerializerProvider(), + max_tokens=config.processing.MAX_TOKENS, + merge_peers=True + ) + self.strategies_processor = StrategiesProcessor() + self._logging_callback = config.dbapp['logging_callback'] or logging_callback_placeholder + + + def process(self): + """ + Abstract method to be implemented by subclasses for processing sources. + + Raises: + NotImplementedError: If not overridden in a subclass. + """ + raise NotImplementedError("This method is not implemented in ProcessorBase") + + + def convert_to_txt(self, document: DoclingDocument) -> str: + plain_text = [] + for node, _ in document.iterate_items(root=document.body, with_groups=False): + if isinstance(node, TableItem): + df = node.export_to_dataframe(document) + table_str = df.to_string(index=False, na_rep='') + plain_text.append(table_str) + elif hasattr(node, 'text') and node.text: + plain_text.append(node.text.strip()) + return '\n\n'.join(plain_text) + + + def _prepare_chunks(self, document_name: str, document_content: str, chunks: list[str]) -> list[dict]: + """ + Prepare chunks by applying strategies to generate properties for each chunk. + + Args: + document_name (str): The name or identifier of the document. + document_content (str): The full content of the document. + chunks (list[str]): List of text chunks to prepare. + + Returns: + list[dict]: List of dictionaries, each containing properties for a chunk. + """ + prepared_chunks = [] + for chunk in chunks: + prepared_chunks.append({ + prop: self.strategies_processor.apply_strategy( + strategy_name=prop, + arguments=StrategyArguments(document_name, document_content, chunk), + ) + for prop in self.strategies_processor.list_strategies() + }) + + return prepared_chunks + + + def _clean_content(self, document_content: str) -> str: + """ + Clean the document content by removing garbage symbols and normalizing whitespace. + + Handles specific replacements for punctuation, symbols, and line breaks. + + Args: + document_content (str): The raw document content to clean. + + Returns: + str: The cleaned document content. + """ + cleaned = re.sub(r'\s+/\s+', '/', document_content) + cleaned = re.sub(r'\s+\.\s+', '.', cleaned) + cleaned = re.sub(r',\s+', '.', cleaned) + cleaned = re.sub(r'\s+\|\s+', ' ', cleaned) + cleaned = re.sub(r'\/\s+', '/', cleaned) + cleaned = re.sub(r'\s+/','/', cleaned) + cleaned = re.sub(r'\s+\.', '.', cleaned) + cleaned = re.sub(r'(\d+)\s*,\s*(\d{4})', r'\1', cleaned) + cleaned = re.sub(r'(\d+)\s*/\s*(\d+)', r'\1', cleaned) + cleaned = re.sub(r'\.(\d{4})', r'.\1', cleaned) + + cleaned = cleaned.replace('ä', 'ä').replace('ö', 'ö').replace('ü', 'ü') + + cleaned = re.sub(r'\n\s*\n+', '\n\n', cleaned) + cleaned = re.sub(r' +', ' ', cleaned) + + return cleaned + + + def _extract_document_content(self, document: DoclingDocument) -> str: + """ + Extract and compile text content from the document into a single string. + + Organizes text items by page, sorts them by position, and joins them + while handling line breaks and spacing. + + Args: + document (DoclingDocument): The document object to extract content from. + + Returns: + str: The cleaned, compiled text content. + """ + page_texts = defaultdict(list) + for text_item in document.texts: + if not text_item.text.strip(): + continue + + prov = text_item.prov[0] if text_item.prov else None + if prov: + page_number = prov.page_no + bbox = prov.bbox + page_texts[page_number].append({ + 'text': text_item.text.strip(), + 'top': bbox.t, + 'left': bbox.l, + 'bottom': bbox.b, + }) + + full_page_texts = [] + for page_number in sorted(page_texts.keys()): + text_items = sorted( + page_texts[page_number], + key=lambda text: (-text['top'], text['left']), + ) + + content = [] + last_bottom = None + + line_treshold = 15 + + for item in text_items: + text = item['text'] + + if last_bottom is not None and (last_bottom - item['bottom'] > line_treshold): + if content: + full_page_texts.append(' '.join(content)) + content = [] + + if last_bottom - item['bottom'] > 50: + full_page_texts.append("") + + content.append(text) + last_bottom = item['bottom'] + + if content: + full_page_texts.append(' '.join(content)) + + full_text = '\n\n'.join(full_page_texts) + cleaned_text = self._clean_content(full_text) + + return cleaned_text + + + def _collect_chunks(self, document: DoclingDocument) -> list[str]: + """ + Collect contextualized chunks from the document using the chunker. + + Args: + document (DoclingDocument): The document to chunk. + + Returns: + list[str]: List of enriched text chunks. + """ + chunks = [] + for base_chunk in self._chunker.chunk(dl_doc=document): + enriched = self._chunker.contextualize(chunk=base_chunk) + chunks.append(enriched) + return chunks + + + def _collect_chunks_fallback(self, document_content: str) -> list[str]: + """ + Fallback method to chunk the document content manually using tokenization. + + Splits the content into overlapping chunks based on token limits. + + Args: + document_content (str): The full content extracted from document. + + Returns: + list[str]: List of text chunks. + """ + tokenizer_wrapper = self._chunker.tokenizer + tokenizer = getattr(tokenizer_wrapper, 'tokenizer', tokenizer_wrapper) + + tokens = tokenizer.encode(document_content) + chunk_size = self._chunker.max_tokens + overlap = 50 + + collected_chunks = [] + for i in range(0, len(tokens), chunk_size-overlap): + chunk_tokens = tokens[i:i+chunk_size] + chunk = tokenizer.decode( + chunk_tokens, + skip_special_tokens=True, + clean_up_tokenization_spaces=True + ) + collected_chunks.append(chunk) + + return collected_chunks + + +class DocumentProcessor(ProcessorBase): + def process(self, source: Path | str) -> ProcessingResult: + """ + Process a single local document, converting it to text, chunking, and preparing for import. + + Handles document conversion, chunk collection (with fallback if needed), + chunk preparation, and language detection. + + Args: + source (Path | str): Path to the document to process. + + Returns: + ProcessingResult: The result containing chunks, source name, and detected language. + Returns None if the source does not exist or processing fails. + """ + if not os.path.exists(source) or not os.path.isfile(source): + datalogger.error(f"Failed to initiate processing pipeline for source {source}: file does not exist") + return ProcessingResult(source=source, chunks=None, lang='') + + document_name = os.path.basename(source) + datalogger.info(f"Initiating processing pipeline for source {document_name}") + self._logging_callback(f'Converting source {document_name}...', 20) + document = self._converter.convert(source).document + + self._logging_callback(f'Collecting chunks from {document_name}...', 40) + collected_chunks = self._collect_chunks(document) + document_content = MarkdownDocSerializer(doc=document).serialize().text + + if len(collected_chunks) <= 1: # Document content manual extraction + document_content = self._extract_document_content(document) + document = self._converter.convert_string( + content=document_content, + format=InputFormat.MD + ).document + collected_chunks = self._collect_chunks(document) + + self._logging_callback(f'Preparing chunks for {document_name} for importing...', 60) + prepared_chunks = self._prepare_chunks(document_name, document_content, collected_chunks) + + datalogger.info(f"Successfully collected {len(prepared_chunks)} chunks from {document_name}") + + return ProcessingResult( + chunks=prepared_chunks, + source=document_name, + lang=detect_language(document_content), + ) diff --git a/src/pipeline/utilclasses.py b/src/pipeline/utilclasses.py new file mode 100644 index 0000000000000000000000000000000000000000..b28b04f643122b019e912540f228c8ed20be9eeb --- /dev/null +++ b/src/pipeline/utilclasses.py @@ -0,0 +1,3 @@ + + + diff --git a/src/pipeline/utils/__init__.py b/src/pipeline/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e323632ece2f5722585e466a0651be9ebbe5ce99 --- /dev/null +++ b/src/pipeline/utils/__init__.py @@ -0,0 +1,3 @@ +from .strategies_processor import StrategyArguments, StrategiesProcessor +from .serializer import EnhansedSerializerProvider +from .utilclasses import * \ No newline at end of file diff --git a/src/pipeline/utils/serializer.py b/src/pipeline/utils/serializer.py new file mode 100644 index 0000000000000000000000000000000000000000..3dab668a2fcb60bb03d883f3df6cbe1da5e20adb --- /dev/null +++ b/src/pipeline/utils/serializer.py @@ -0,0 +1,58 @@ +from docling_core.transforms.chunker.hierarchical_chunker import ChunkingDocSerializer, ChunkingSerializerProvider +from docling_core.transforms.serializer.base import BaseTableSerializer, SerializationResult +from docling_core.transforms.serializer.common import create_ser_result +from docling_core.types.doc.document import RichTableCell + +class EnhancedTableSerializer(BaseTableSerializer): + def serialize(self, *, item, doc_serializer, doc, **kwargs) -> SerializationResult: + if item.self_ref in doc_serializer.get_excluded_refs(**kwargs): + return create_ser_result(text='') + + grid = item.data.grid + if not grid: + return create_ser_result(text='') + + row_cells = [] + for row in grid: + clean_row = [] + for cell in row: + if isinstance(cell, RichTableCell): + ser = doc_serializer.serialize(item=cell.ref.resolve(doc), **kwargs) + clean_row.append(ser.text.strip()) + else: + clean_row.append((cell.text or "").strip()) + if any(c for c in clean_row): + row_cells.append(clean_row) + + headers = row_cells[0] + data_rows = row_cells[1:] + + lines = [] + + for row in data_rows: + if len(row) < 2 or not row[0].strip(): + continue + + main_key = row[0].strip().replace('\n', ' ') + top_line = f'- {main_key}:' + lines.append(top_line) + + for i in range(1, len(row)): + value = row[i].strip().replace('\n', ' ') + if not value: continue + sub_header = headers[i].strip().replace('\n', ' ') if i < len(headers) else f"" + sub_line = f' - {sub_header}: {value}' + lines.append(sub_line) + + lines.append("") + + final_text = "\n".join(lines).rstrip() + return create_ser_result(text=final_text, span_source=item) + + +class EnhansedSerializerProvider(ChunkingSerializerProvider): + def get_serializer(self, doc): + return ChunkingDocSerializer( + doc=doc, + table_serializer=EnhancedTableSerializer(), + ) diff --git a/src/pipeline/utils/strategies_processor.py b/src/pipeline/utils/strategies_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..aba932d65ed28f7665ed9e616b8e8f491353c439 --- /dev/null +++ b/src/pipeline/utils/strategies_processor.py @@ -0,0 +1,74 @@ +import os, re, importlib.util +from dataclasses import dataclass + +from src.config import config +from src.utils.logging import get_logger + +logger = get_logger('pipeline.strats') + +@dataclass +class StrategyArguments: + name: str = None + content: str = None + chunk: str = None + +class StrategiesProcessor: + def __init__(self) -> None: + os.makedirs(config.weaviate.STRATEGIES_PATH, exist_ok=True) + + self._strategies: dict = self._load_strategies() + + def list_strategies(self) -> list[str]: + return self._strategies.keys() + + def apply_strategy(self, strategy_name: str, arguments: StrategyArguments | dict): + if strategy_name not in self._strategies.keys(): + raise ValueError(f"Cannot apply strategy '{strategy_name}': strategy not found!") + + try: + strategy = self._strategies[strategy_name] + run_result = None + if isinstance(arguments, StrategyArguments): + run_result = strategy.run(arguments.name, arguments.content, arguments.chunk) + else: + run_result = strategy.run( + arguments.get('document_name', ""), + arguments.get('document_content', ""), + arguments.get('chunk', None) + ) + return run_result + except Exception as e: + raise RuntimeError(f"Cannot apply strategy '{strategy_name}': {e}") + + + def _load_strategies(self) -> dict: + loaded_strategies = dict() + for strat_file in os.listdir(config.weaviate.STRATEGIES_PATH): + strat_name = self._extract_strategy_name(strat_file) + if not strat_name: continue + + strat_path = os.path.join(config.weaviate.STRATEGIES_PATH, strat_file) + + spec = importlib.util.spec_from_file_location( + name=strat_name, + location=strat_path + ) + strategy = importlib.util.module_from_spec(spec) + spec.loader.exec_module(strategy) + + if not hasattr(strategy, 'run'): + logger.warning(f"Found strategy '{strat_name}' has no valid run() function!") + continue + + loaded_strategies[strat_name] = strategy + + logger.info(f"Loaded {len(loaded_strategies.keys())} strategies") + return loaded_strategies + + + def _extract_strategy_name(self, strat_file: str) -> str: + match = re.fullmatch(r'^strat_(.*)\.py$', strat_file) + return match.group(1) if match else None + + + diff --git a/src/pipeline/utils/utilclasses.py b/src/pipeline/utils/utilclasses.py new file mode 100644 index 0000000000000000000000000000000000000000..e2adaf29a0195d36effed53eac603cdeb69d7956 --- /dev/null +++ b/src/pipeline/utils/utilclasses.py @@ -0,0 +1,13 @@ +from dataclasses import dataclass + +def logging_callback_placeholder(*_): + pass + +def deduplication_callback_placeholder(*_) -> bool: + return False + +@dataclass +class ProcessingResult: + chunks: list[dict] + source: str + lang: str diff --git a/src/rag/__init__.py b/src/rag/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/rag/agent_chain.py b/src/rag/agent_chain.py new file mode 100644 index 0000000000000000000000000000000000000000..bf8b39cd6898840c88f34951d3f02bb24d8ffc6b --- /dev/null +++ b/src/rag/agent_chain.py @@ -0,0 +1,1022 @@ +from langchain_core.runnables import RunnableConfig +from langsmith import traceable +from langchain.tools import tool +from langchain.agents import create_agent +from langchain_core.messages import ( + HumanMessage, + AIMessage, + SystemMessage, +) +from langchain.agents.middleware import ModelFallbackMiddleware +from langchain.agents.structured_output import ProviderStrategy + +import uuid +import json +import os +import re +import random +import glob +from datetime import datetime + +from src.database.weavservice import WeaviateService + +from src.rag.utilclasses import * +from src.const.agent_response_constants import * +from src.rag.middleware import AgentChainMiddleware as chainmdw +from src.rag.prompts import PromptConfigurator as promptconf +from src.rag.models import ModelConfigurator as modelconf +from src.rag.input_handler import InputHandler +from src.rag.response_formatter import ResponseFormatter +from src.rag.scope_guardian import ScopeGuardian +# from src.rag.quality_score_handler import QualityEvaluationResult, QualityScoreHandler +from src.rag.language_detection import LanguageDetector + +from src.utils.logging import get_logger +from src.utils.lang import get_language_name +from src.config import config + +from ..cache.cache import Cache + +chain_logger = get_logger('agent_chain') + + +class ExecutiveAgentChain: + def __init__(self, language: str = 'en', session_id: str | None = None) -> None: + self._initial_language = language + self._stored_language = language + self._dbservice = WeaviateService() + self._agents, self._config = self._init_agents() + self._conversation_history = [] + self._cache = Cache.get_cache() + + # Confidence scoring is intentionally disabled here because the extra + # model call adds latency and has not been reliable enough to justify it. + # if config.chain.EVALUATE_RESPONSE_QUALITY: + # self._quality_handler = QualityScoreHandler() + self._language_detector = LanguageDetector() + + # Generate unique user ID for this session + self._user_id = session_id or str(uuid.uuid4()) + + # Initialize conversation state with user profile tracking + self._conversation_state: ConversationState = { + 'session_id': self._user_id, + 'user_id': self._user_id, + 'user_language': None, + 'user_name': None, + 'experience_years': None, + 'leadership_years': None, + 'field': None, + 'interest': None, + 'qualification_level': None, + 'program_interest': [], + 'suggested_program': None, + 'handover_requested': None, + 'topics_discussed': [], + 'preferences_known': False + } + + # Track scope violations for escalation + self._scope_violation_counts: dict[str, int] = {} + self._aggressive_violation_count = 0 + + chain_logger.info(f"Initialized new Agent Chain for language '{language}' with user_id: {self._user_id}") + + def _retrieve_context(self, query: str, program: str, language: str = None): + """ + Send the query to the vector database to retrieve additional information about the program. + + Args: + query: Keywords depicting information you want to retrieve in the primary language. + program: Name of the program (either 'emba', 'iemba' or 'emba x') for which the information is requested. + language: Optional parameter (either 'en' for English language or 'de' for German language). This parameter selects the language of the database to query from. The input query must be written in the same language as the selected language. Use this parameter only if there's not enough information in your main language. + """ + lang = language if language in ['en', 'de'] else self._initial_language + try: + response, _ = self._dbservice.query( + query=query, + lang=lang, + limit=config.get('TOP_K_RETRIEVAL'), + property_filters={ + 'programs': [program], + }, + ) + serialized = '\n\n'.join([doc.properties.get('body', '') for doc in response.objects]) + return serialized + except Exception as e: + raise e + + def _call_emba_agent(self, query: str) -> str: + """ + Invokes the EMBA support agent to retrieve more detailed information about the EMBA program. + + Args: + query: Query to the EMBA support agent. Provide collected user data in the query if possible. + """ + try: + structured_response = self._query( + agent=self._agents['emba'], + messages=[HumanMessage(query)], + thread_id=f"emba_{hash(query)}", + ) + return structured_response.response + except Exception as e: + chain_logger.error(f"EMBA Agent error: {e}") + raise RuntimeError("Unable to retrieve EMBA information at this time.") + + def _call_iemba_agent(self, query: str) -> str: + """ + Invokes the IEMBA support agent to retrieve more detailed information about the IEMBA program. + + Args: + query: Query to the IEMBA support agent. Provide collected user data in the query if possible. + """ + try: + structured_response = self._query( + agent=self._agents['iemba'], + messages=[HumanMessage(query)], + thread_id=f"emba_{hash(query)}", + ) + return structured_response.response + except Exception as e: + chain_logger.error(f"IEMBA Agent error: {e}") + raise RuntimeError("Unable to retrieve IEMBA information at this time.") + + def _call_embax_agent(self, query: str) -> str: + """ + Invokes the emba X support agent to retrieve more detailed information about the emba X program. + + Args: + query: Query to the emba X support agent. Provide collected user data in the query if possible. + """ + try: + structured_response = self._query( + agent=self._agents['embax'], + messages=[HumanMessage(query)], + thread_id=f"emba_{hash(query)}", + ) + return structured_response.response + except Exception as e: + chain_logger.error(f"emba X Agent error: {e}") + raise RuntimeError("Unable to retrieve emba X information at this time.") + + def _init_agents(self): + config: RunnableConfig = { + 'configurable': {'thread_id': 0} + } + fallback_middleware = ModelFallbackMiddleware( + *modelconf.get_fallback_models() + ) + tool_retrieve_context = tool( + name_or_callable='retrieve_context', + runnable=self._retrieve_context, + return_direct=False, + parse_docstring=True, + ) + tools_agent_calling = [ + tool( + name_or_callable='call_emba_agent', + runnable=self._call_emba_agent, + return_direct=False, + parse_docstring=True, + ), + tool( + name_or_callable='call_iemba_agent', + runnable=self._call_iemba_agent, + return_direct=False, + parse_docstring=True, + ), + tool( + name_or_callable='call_embax_agent', + runnable=self._call_embax_agent, + return_direct=False, + parse_docstring=True, + ), + ] + agents = { + 'lead': create_agent( + name="lead_agent", + model=modelconf.get_main_agent_model(), + tools=tools_agent_calling, + state_schema=LeadInformationState, + system_prompt=promptconf.get_configured_agent_prompt('lead', language=self._initial_language), + middleware=[ + chainmdw.get_tool_wrapper(), + chainmdw.get_model_wrapper(), + fallback_middleware, + ], + context_schema=AgentContext, + response_format=ProviderStrategy( + StructuredAgentResponse + ), + ), + } + for agent in ['emba', 'iemba', 'embax']: + agents[agent] = create_agent( + name=f"{agent}_agent", + model=modelconf.get_subagent_model(), + tools=[tool_retrieve_context], + state_schema=LeadInformationState, + system_prompt=promptconf.get_configured_agent_prompt(agent, language=self._initial_language), + middleware=[ + fallback_middleware, + chainmdw.get_tool_wrapper(), + chainmdw.get_model_wrapper(), + ], + context_schema=AgentContext, + ) + return agents, config + + def _extract_experience_years(self, conversation: str) -> int | None: + """Extract years of professional experience from conversation text.""" + # Look for patterns like "10 years", "5 years experience", etc. + patterns = [ + r'(\d+)\s*years?\s*(?:of\s*)?(?:experience|work)', + r'(\d+)\s*years?\s*in\s*(?:the\s*)?(?:field|industry)', + r'working\s*for\s*(\d+)\s*years?', + r'(\d+)\s*Jahre\s*(?:Erfahrung|Berufserfahrung)', # German + ] + for pattern in patterns: + match = re.search(pattern, conversation, re.IGNORECASE) + if match: + return int(match.group(1)) + return None + + def _extract_leadership_years(self, conversation: str) -> int | None: + """Extract years of leadership experience from conversation text.""" + patterns = [ + r'(\d+)\s*years?\s*(?:of\s*)?(?:leadership|management|managing)', + r'(?:lead|led|manage|managed)\s*(?:for\s*)?(\d+)\s*years?', + r'(\d+)\s*Jahre\s*(?:Führungserfahrung|Führung)', # German + ] + for pattern in patterns: + match = re.search(pattern, conversation, re.IGNORECASE) + if match: + return int(match.group(1)) + return None + + def _extract_field(self, conversation: str) -> str | None: + """Extract professional field/industry from conversation text.""" + # Common fields mentioned in executive education + fields = [ + 'finance', 'banking', 'technology', 'tech', 'IT', 'healthcare', + 'consulting', 'manufacturing', 'retail', 'marketing', 'sales', + 'engineering', 'pharma', 'telecommunications', 'energy', + 'Finanzwesen', 'Technologie', 'Gesundheitswesen', 'Beratung' # German + ] + conversation_lower = conversation.lower() + for field in fields: + if field.lower() in conversation_lower: + return field.capitalize() + return None + + def _extract_interest(self, conversation: str) -> str | None: + """Extract content interests from conversation text.""" + # Look for interest indicators + interests = [ + 'strategy', 'innovation', 'leadership', 'digital transformation', + 'finance', 'operations', 'marketing', 'entrepreneurship', + 'social impact', 'technology', 'management', + 'Strategie', 'Innovation', 'Führung', 'Digitalisierung' # German + ] + conversation_lower = conversation.lower() + found_interests = [interest for interest in interests + if interest.lower() in conversation_lower] + return ', '.join(found_interests) if found_interests else None + + def _extract_name(self, conversation: str) -> str | None: + """Extract user's name from conversation text.""" + patterns = [ + r"(?:my name is|i'm|i am|call me)\s+([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)", + r"(?:this is|it's)\s+([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)", + r"(?:ich heiße|mein Name ist|ich bin)\s+([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)", # German + ] + for pattern in patterns: + match = re.search(pattern, conversation, re.IGNORECASE) + if match: + name = match.group(1).strip() + # Filter out common words that might be误ly matched + excluded = ['interested', 'looking', 'working', 'searching', 'asking'] + if name.lower() not in excluded: + return name + return None + + def _detect_handover_request(self, conversation: str) -> bool: + """Detect if user requested appointment, callback, or contact.""" + # Keywords indicating handover request + handover_keywords = [ + 'appointment', 'call me', 'contact me', 'schedule', 'meeting', + 'callback', 'reach out', 'follow up', 'get in touch', 'speak with', + 'talk to', 'consultation', 'discuss with', 'meet with', + 'Termin', 'Rückruf', 'kontaktieren', 'Gespräch', 'anrufen', # German + 'zurückrufen', 'Beratung', 'treffen' + ] + conversation_lower = conversation.lower() + return any(keyword.lower() in conversation_lower for keyword in handover_keywords) + + def _previous_response_offered_booking(self) -> bool: + """Return True if the latest assistant turn offered booking as a next step.""" + booking_offer_terms = [ + "appointment slots", + "book an appointment", + "book a consultation", + "appointment booking", + "show you available appointments", + "show appointment options", + "terminbuchung", + "termin buchen", + "termine anzeigen", + "verfügbare termine", + "beratungstermin", + ] + + for message in reversed(self._conversation_history): + if not isinstance(message, AIMessage): + continue + content = getattr(message, "content", "") or getattr(message, "text", "") + if isinstance(content, list): + content = " ".join(str(part) for part in content) + content_lower = str(content).lower() + return any(term in content_lower for term in booking_offer_terms) + + return False + + def _get_latest_ai_message_content(self, skip_latest: bool = False) -> str: + """Return the latest assistant message content from conversation history.""" + ai_messages_seen = 0 + + for message in reversed(self._conversation_history): + if not isinstance(message, AIMessage): + continue + + ai_messages_seen += 1 + if skip_latest and ai_messages_seen == 1: + continue + + content = getattr(message, "content", "") or getattr(message, "text", "") + if isinstance(content, list): + return " ".join(str(part) for part in content) + return str(content) + + return "" + + def _is_booking_preference_follow_up(self, query: str) -> bool: + """Detect short follow-up answers that continue an active booking flow.""" + query_lower = query.lower().strip() + if not query_lower: + return False + + preference_terms = [ + "online", + "on-site", + "onsite", + "in person", + "in-person", + "st.gallen", + "st. gallen", + "morning", + "mornings", + "afternoon", + "afternoons", + "evening", + "beginning of the week", + "start of the week", + "end of the week", + "monday", + "tuesday", + "wednesday", + "thursday", + "friday", + "morgens", + "vormittag", + "vormittags", + "nachmittag", + "nachmittags", + "abends", + "wochenanfang", + "anfang der woche", + "ende der woche", + "montag", + "dienstag", + "mittwoch", + "donnerstag", + "freitag", + "vor ort", + "vor-ort", + "persönlich", + "persoenlich", + "hybrid", + ] + + if any(term in query_lower for term in preference_terms): + return True + + return False + + def _previous_response_requested_booking_preferences(self) -> bool: + """Return True when the previous assistant turn asked clarifying booking questions.""" + content_lower = self._get_latest_ai_message_content().lower() + if not content_lower: + return False + + booking_context_terms = [ + "appointment options", + "available appointments", + "available slots", + "appointment slots", + "online-terminoptionen", + "terminoptionen", + "verfügbare slots", + "verfügbare termine", + "beratungsgespräch", + "beratung", + ] + clarification_terms = [ + "do you prefer", + "would you prefer", + "which programme", + "which program", + "one short question", + "final question", + "when i know this", + "bitte noch kurz", + "eine kurze rückfrage", + "eine kurze letzte frage", + "bevorzugen sie", + "haben sie eine tagespräferenz", + "sobald ich das weiss", + "damit die slots besser passen", + ] + + return ( + any(term in content_lower for term in booking_context_terms) + and any(term in content_lower for term in clarification_terms) + ) + + def _response_commits_to_showing_booking_widget(self, response: str) -> bool: + """Detect when the assistant says booking options are being shown now.""" + response_lower = response.lower() + + positive_terms = [ + "i can show you", + "contact details and available appointment slots are shown below", + "appointment options are shown below", + "available slots are shown below", + "i can now show you", + "ich kann ihnen nun", + "ich kann ihnen jetzt", + "unten werden ihnen", + "unten finden sie", + "unten sehen sie", + "terminoptionen anzeigen", + "verfügbaren slots", + "verfügbaren termine", + ] + defer_terms = [ + "if you would like", + "if you later wish", + "you can ask me", + "if that would be helpful", + "sobald ich das weiss", + "wenn ich das weiss", + "damit die slots besser passen", + "bitte noch kurz", + "eine kurze rückfrage", + "eine kurze letzte frage", + "bevorzugen sie", + "have you got a preference", + "do you prefer", + "would you prefer", + "which programme", + "which program", + ] + + return ( + any(term in response_lower for term in positive_terms) + and not any(term in response_lower for term in defer_terms) + ) + def _is_explicit_booking_intent(self, query: str) -> bool: + """Detect whether the user is actively asking to book or accepting a booking offer.""" + query_lower = query.lower() + direct_booking_terms = [ + "book", + "schedule", + "appointment", + "consultation", + "need a consultation", + "personal consultation", + "speak with", + "talk to an advisor", + "talk to admissions", + "connect me", + "show me available", + "show appointment", + "available slots", + "termin", + "termin buchen", + "termin vereinbaren", + "beratungstermin", + "beratungsgespräch", + "ich brauche eine beratung", + "ich möchte eine beratung", + "ich will eine beratung", + "beratung für", + "persönliche beratung", + "persoenliche beratung", + "mit jemandem sprechen", + "mit admissions sprechen", + "mit der zulassung sprechen", + "termine anzeigen", + "verfügbare termine", + ] + rejection_terms = [ + "do not want", + "don't want", + "no appointment", + "not book", + "not schedule", + "no thanks", + "no thank you", + "kein termin", + "keinen termin", + "keine beratung", + "nicht buchen", + "nicht vereinbaren", + "nein danke", + ] + acceptance_terms = [ + "yes", + "yes please", + "please do", + "that would be helpful", + "show me", + "ja", + "ja bitte", + "gerne", + "bitte", + "mach das", + "zeige", + ] + + def contains_term(term: str) -> bool: + if term in {"yes", "ja", "bitte"}: + return re.search(rf"\b{re.escape(term)}\b", query_lower) is not None + return term in query_lower + + if any(contains_term(term) for term in rejection_terms): + return False + + if any(contains_term(term) for term in direct_booking_terms): + return True + + return ( + self._previous_response_offered_booking() + and any(contains_term(term) for term in acceptance_terms) + ) + + def _determine_suggested_program(self) -> str | None: + """Determine recommended program based on user profile.""" + state = self._conversation_state + + # If program interest was explicitly mentioned + if state['program_interest']: + return state['program_interest'][0] + + # Make recommendation based on profile + experience = state.get('experience_years', 0) or 0 + leadership = state.get('leadership_years', 0) or 0 + + # EMBA: 5+ years experience, 2+ years leadership + if experience >= 5 and leadership >= 2: + return 'EMBA' + # IEMBA: International focus, 3+ years experience + elif experience >= 3: + return 'IEMBA' + # EMBA X: Digital/Innovation focus + elif state.get('interest') and any(kw in state.get('interest', '').lower() + for kw in ['digital', 'innovation', 'technology']): + return 'emba X' + + return None + + def _update_conversation_state(self, user_query: str, agent_response: str) -> None: + """Update conversation state by extracting information from the conversation.""" + if not config.convstate.TRACK_USER_PROFILE: + return + + # Combine query and response for analysis + conversation_text = f"{user_query} {agent_response}" + + # Extract profile information + if not self._conversation_state.get('experience_years'): + exp_years = self._extract_experience_years(conversation_text) + if exp_years: + self._conversation_state['experience_years'] = exp_years + chain_logger.info(f"Extracted experience years: {exp_years}") + + if not self._conversation_state.get('leadership_years'): + lead_years = self._extract_leadership_years(conversation_text) + if lead_years: + self._conversation_state['leadership_years'] = lead_years + chain_logger.info(f"Extracted leadership years: {lead_years}") + + if not self._conversation_state.get('field'): + field = self._extract_field(conversation_text) + if field: + self._conversation_state['field'] = field + chain_logger.info(f"Extracted field: {field}") + + if not self._conversation_state.get('interest'): + interest = self._extract_interest(conversation_text) + if interest: + self._conversation_state['interest'] = interest + chain_logger.info(f"Extracted interest: {interest}") + + # Extract name + if not self._conversation_state.get('user_name'): + name = self._extract_name(conversation_text) + if name: + self._conversation_state['user_name'] = name + chain_logger.info(f"Extracted name: {name}") + + # Detect handover request from the user only; assistant soft offers should not count. + if self._detect_handover_request(user_query): + self._conversation_state['handover_requested'] = True + chain_logger.info("Handover request detected") + + # Check for program mentions + programs = ['EMBA', 'IEMBA', 'EMBA X'] + for program in programs: + if program.lower() in conversation_text.lower(): + if program not in self._conversation_state['program_interest']: + self._conversation_state['program_interest'].append(program) + + # Update suggested program + suggested = self._determine_suggested_program() + if suggested and not self._conversation_state.get('suggested_program'): + self._conversation_state['suggested_program'] = suggested + chain_logger.info(f"Suggested program: {suggested}") + + def _log_user_profile(self) -> None: + """Log user profile to JSON file.""" + if not config.convstate.TRACK_USER_PROFILE: + return + + try: + # Create logs directory if it doesn't exist + log_dir = os.path.join('logs', 'user_profiles') + os.makedirs(log_dir, exist_ok=True) + + # Create profile data + profile_data = { + 'session_id': self._conversation_state['session_id'], + 'user_id': self._conversation_state['user_id'], + 'name': self._conversation_state.get('user_name'), + 'timestamp': datetime.now().isoformat(), + 'experience_years': self._conversation_state.get('experience_years'), + 'leadership_years': self._conversation_state.get('leadership_years'), + 'field': self._conversation_state.get('field'), + 'interest': self._conversation_state.get('interest'), + 'suggested_program': self._conversation_state.get('suggested_program'), + 'handover': self._conversation_state.get('handover_requested'), + 'user_language': self._conversation_state.get('user_language'), + 'program_interest': self._conversation_state.get('program_interest', []), + } + + # Log file path with timestamp + timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') + log_file = os.path.join(log_dir, f'profile_{self._user_id}_{timestamp}.json') + + # Write to file + with open(log_file, 'w', encoding='utf-8') as f: + json.dump(profile_data, f, indent=2, ensure_ascii=False) + + chain_logger.info(f"User profile logged to {log_file}") + + except Exception as e: + chain_logger.error(f"Failed to log user profile: {e}") + + def wipe_session_data(self) -> None: + """Delete in-memory session data and on-disk profile files (GDPR withdrawal).""" + + # --- 1) In-memory wipe --- + self._conversation_history = [] + self._conversation_state.update({ + 'user_language': None, + 'user_name': None, + 'experience_years': None, + 'leadership_years': None, + 'field': None, + 'interest': None, + 'qualification_level': None, + 'program_interest': [], + 'suggested_program': None, + 'handover_requested': None, + 'topics_discussed': [], + 'preferences_known': False + }) + self._scope_violation_counts = {} + self._aggressive_violation_count = 0 + + # --- 2) On-disk wipe (delete profile__*.json) --- + if not self._user_id: + chain_logger.warning("wipe_session_data called without user_id – skipping file deletion") + return + + pattern = os.path.join( + "logs", + "user_profiles", + f"profile_{self._user_id}_*.json" + ) + + for path in glob.glob(pattern): + try: + os.remove(path) + chain_logger.info(f"Deleted profile file: {path}") + except OSError as e: + chain_logger.error(f"Failed to delete {path}: {e}") + + def generate_greeting(self) -> str: + greeting_message = random.choice(GREETING_MESSAGES[self._stored_language]) + return greeting_message + + @traceable + def query(self, query: str) -> LeadAgentQueryResponse: + """ + Phase 1: Validation, Scope-Check and language detection. + Does not call the agent directly. + """ + # Remember fallback language + current_language = self._stored_language + + if len(self._conversation_history) >= config.convstate.MAX_CONVERSATION_TURNS: + return LeadAgentQueryResponse( + response = CONVERSATION_END_MESSAGE[current_language], + language = current_language, + max_turns_reached = True, + relevant_programs=[], + processed_query = query + ) + + # 2. Input Processing + processed_query, is_valid = InputHandler.process_input( + query, + [msg for msg in self._conversation_history if isinstance(msg, (HumanMessage, AIMessage))] + ) + + if not is_valid or not processed_query: + chain_logger.warning(f"Invalid input received: '{query}'") + return LeadAgentQueryResponse( + response=NOT_VALID_QUERY_MESSAGE[self._stored_language], + language=current_language, + processed_query=query + ) + + # Log check + if processed_query != query: + chain_logger.info(f"Interpreted input '{query}' as '{processed_query}'") + + # 3. Language Detection + # First: Check for explicit language switch request (overrides lock) + explicit_switch = self._language_detector.detect_explicit_switch_request(processed_query) + if explicit_switch: + self._stored_language = explicit_switch + current_language = explicit_switch + self._conversation_state['user_language'] = explicit_switch + elif self._language_detector.is_language_neutral_program_reference(processed_query): + chain_logger.info( + f"Skipping language re-detection for language-neutral programme reference: '{processed_query}'" + ) + current_language = self._stored_language + else: + # Count user messages in conversation history + user_message_count = len([m for m in self._conversation_history if isinstance(m, HumanMessage)]) + + # Lock language after N user messages (allows language switch early in conversation) + lang_lock_n = config.convstate.LOCK_LANGUAGE_AFTER_N_MESSAGES + if lang_lock_n > 0 and user_message_count >= lang_lock_n: + chain_logger.info(f"Language locked to '{self._stored_language}' (after {user_message_count} messages)") + current_language = self._stored_language + else: + detected_language = self._language_detector.detect_language(processed_query) + self._conversation_state['user_language'] = detected_language + + # Language validation + if detected_language in ['de', 'en']: + self._stored_language = detected_language + current_language = detected_language + else: + chain_logger.info("Invalid language detected.") + return LeadAgentQueryResponse( + response=LANGUAGE_FALLBACK_MESSAGE[current_language], + language=current_language, + processed_query=processed_query + ) + + # 4. Scope Check + scope_type = ScopeGuardian.check_scope(processed_query, current_language) + + if scope_type != 'on_topic': + chain_logger.info(f"Out-of-scope query detected: {scope_type}") + if scope_type == 'aggressive': + self._aggressive_violation_count += 1 + attempt_count = self._aggressive_violation_count + else: + self._scope_violation_counts[scope_type] = self._scope_violation_counts.get(scope_type, 0) + 1 + attempt_count = self._scope_violation_counts[scope_type] + + should_escalate, escalation_type = ScopeGuardian.should_escalate( + processed_query, scope_type, attempt_count + ) + + if should_escalate: + redirect_msg = ScopeGuardian.get_escalation_message(escalation_type, current_language) + else: + redirect_msg = ScopeGuardian.get_redirect_message(scope_type, current_language) + + self._conversation_history.append(HumanMessage(processed_query)) + self._conversation_history.append(AIMessage(redirect_msg)) + + return LeadAgentQueryResponse( + response=redirect_msg, + language=current_language, + processed_query=processed_query, + appointment_requested=False, + show_booking_widget=False, + ) + + # 5. Check if cached data already exists for this session + if config.cache.ENABLED: + cached_data = self._cache.get(query, current_language, self._user_id) + if cached_data and isinstance(cached_data, dict): + return LeadAgentQueryResponse( + response=cached_data["response"], + language=current_language, + appointment_requested=cached_data.get("appointment_requested", False), + show_booking_widget=cached_data.get("show_booking_widget", False), + relevant_programs=cached_data.get("relevant_programs", []), + ) + + + # 6. Preprocessing is finished - the agent has to answer the query + response = self._query_lead(query) + + if config.cache.ENABLED and response.should_cache: + self._cache.set( + key=query, + value={ + "response": response.response, + "appointment_requested": response.appointment_requested, + "show_booking_widget": response.show_booking_widget, + "relevant_programs": response.relevant_programs, + }, + language = current_language, + session_id = self._user_id, + ) + + return response + + + def _query_lead(self, preprocessed_query: str) -> LeadAgentQueryResponse: + """ + Phase 2: Execute agent. + Takes the ALREADY validated query from the preprocessing phase. + """ + # Reset scope-violation tracking + self._scope_violation_counts = {} + + response_language = self._stored_language + explicit_booking_intent = self._is_explicit_booking_intent(preprocessed_query) + booking_preference_follow_up = ( + self._conversation_state.get('handover_requested') is True + and self._previous_response_requested_booking_preferences() + and self._is_booking_preference_follow_up(preprocessed_query) + ) + + # 1. History Update + self._conversation_history.append(HumanMessage(preprocessed_query)) + + # 2. System instruction + language_instruction = SystemMessage(f"Respond in {get_language_name(response_language)} language.") + + # 3. Agent Call + structured_response = self._query( + agent=self._agents['lead'], + messages=self._conversation_history + [language_instruction], + ) + agent_response = structured_response.response + chain_logger.info(f"Is answer context dependent: {structured_response.is_context_dependent}") + chain_logger.info(f"Appointment Requested: {structured_response.appointment_requested}") + chain_logger.info(f"Show Booking Widget: {structured_response.show_booking_widget}") + chain_logger.info(f"Relevant Programs: {structured_response.relevant_programs}") + + # 4. Formatting + if config.chain.ENABLE_RESPONSE_CHUNKING: + formatted_response = ResponseFormatter.format_response( + agent_response, agent_type='lead', enable_chunking=True, language=response_language + ) + else: + formatted_response = ResponseFormatter.remove_tables(agent_response) + + formatted_response = ResponseFormatter.clean_response(formatted_response) + + confidence_fallback = False + # if config.chain.EVALUATE_RESPONSE_QUALITY: + # quality_evaluation: QualityEvaluationResult = self._quality_handler. \ + # evaluate_response_quality(preprocessed_query, formatted_response) + # + # chain_logger.info(f"Quality Score: {quality_evaluation.overall_score:1.2f}") + # + # if quality_evaluation.overall_score < config.chain.CONFIDENCE_THRESHOLD: + # confidence_fallback = True + # formatted_response = CONFIDENCE_FALLBACK_MESSAGE[response_language] + # chain_logger.info("Fallback Mechanism activated!") + + # Add to history + self._conversation_history.append(AIMessage(formatted_response)) + + # 6. Profiling + if config.convstate.TRACK_USER_PROFILE: + self._update_conversation_state(preprocessed_query, formatted_response) + + message_count = len([m for m in self._conversation_history if isinstance(m, HumanMessage)]) + if message_count % 5 == 0 or self._conversation_state.get('suggested_program'): + self._log_user_profile() + + formatted_response = ResponseFormatter.format_name_of_university(formatted_response, language=response_language) + + # Proactive booking offer. + # When the lead model signals booking readiness AND the assessment chain + # has identified a clear programme match, the booking widget is shown + # without waiting for an explicit "book"/"appointment" word from the user. + # The match comes from the existing profile-based assessment + # (suggested_program, set by _update_conversation_state above) or from + # relevant_programs returned by the lead model. Without this gate, the + # earlier user-led-only logic meant the widget effectively never fired. + clear_programme_match = ( + self._conversation_state.get('suggested_program') is not None + or bool(structured_response.relevant_programs) + ) + proactive_booking_offer = ( + clear_programme_match + and structured_response.show_booking_widget + ) + + booking_flow_requested = ( + explicit_booking_intent + or booking_preference_follow_up + or proactive_booking_offer + ) + appointment_requested = bool(booking_flow_requested) + show_booking_widget = bool( + booking_flow_requested and ( + structured_response.show_booking_widget + or self._response_commits_to_showing_booking_widget(formatted_response) + ) + ) + + if proactive_booking_offer and not (explicit_booking_intent or booking_preference_follow_up): + chain_logger.info( + "Proactive booking offer triggered " + f"(suggested_program={self._conversation_state.get('suggested_program')}, " + f"relevant_programs={structured_response.relevant_programs})" + ) + elif structured_response.appointment_requested and not booking_flow_requested: + chain_logger.info("Suppressed booking state because no programme match or booking intent was detected.") + elif booking_preference_follow_up and show_booking_widget: + chain_logger.info("Continuing active booking flow and showing booking widget for a preference follow-up.") + + return LeadAgentQueryResponse( + response = formatted_response, + language = response_language, + confidence_fallback = confidence_fallback, + should_cache = False if (confidence_fallback or appointment_requested or structured_response.is_context_dependent) else True, + processed_query = preprocessed_query, + appointment_requested = appointment_requested, + show_booking_widget = show_booking_widget, + relevant_programs = structured_response.relevant_programs + ) + + def _query(self, agent, messages: list, thread_id: str = None) -> StructuredAgentResponse: + try: + config = self._config.copy() + config['configurable']['thread_id'] = thread_id or 0 + + result: AIMessage = agent.invoke( + {"messages": messages}, + config=config, + context=AgentContext(agent_name=agent.name), + ) + response = result.get( + 'structured_response', + StructuredAgentResponse( + response=result['messages'][-1].text, + ) + ) + return response + except Exception as e: + error_msg = e.body['message'] if hasattr(e, 'body') else str(e) + chain_logger.error(f"Failed to invoke the agent: {error_msg}") + return StructuredAgentResponse( + response=QUERY_EXCEPTION_MESSAGE[self._stored_language], + ) diff --git a/src/rag/input_handler.py b/src/rag/input_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..fe0780eb2d16c39398e8de021c1a53d24bfa28a8 --- /dev/null +++ b/src/rag/input_handler.py @@ -0,0 +1,147 @@ +""" +Input handler for processing and validating user messages. +Handles numeric inputs, validation, and interpretation. +""" +import re +from src.rag.utilclasses import ConversationState +from src.utils.logging import get_logger + +logger = get_logger("input_handler") + + +class InputHandler: + """Handles input validation and interpretation""" + + @staticmethod + def validate_and_normalize(message: str) -> str: + """ + Normalize and validate user input. + + Args: + message: Raw user input + + Returns: + Normalized message + """ + if not message: + return "" + + # Strip whitespace + normalized = message.strip() + + # Handle empty or very short inputs + if len(normalized) < 1: + return "" + + return normalized + + @staticmethod + def is_numeric_input(message: str) -> bool: + """ + Check if message is a standalone number. + + Args: + message: User input + + Returns: + True if message is just a number + """ + normalized = message.strip() + # Check if it's just digits (possibly with decimal) + return bool(re.match(r'^\d+(\.\d+)?$', normalized)) + + @staticmethod + def interpret_numeric_input( + message: str, + conversation_history: list + ) -> str: + """ + Interpret standalone numeric input based on conversation context. + + Args: + message: Numeric input (e.g., "5") + conversation_history: Recent conversation messages (LangChain message objects) + + Returns: + Interpreted message (e.g., "I have 5 years of experience") + """ + number = message.strip() + + # Look at recent messages for context + recent_context = "" + if len(conversation_history) > 0: + # Get last bot message + # Import here to avoid circular dependency + from langchain_core.messages import AIMessage + + for msg in reversed(conversation_history): + # Handle LangChain message objects + if isinstance(msg, AIMessage): + recent_context = msg.content.lower() if hasattr(msg, 'content') else "" + break + # Handle dictionary format (for backward compatibility) + elif isinstance(msg, dict) and msg.get("role") == "assistant": + recent_context = msg.get("content", "").lower() + break + + # Interpret based on context keywords + if any(keyword in recent_context for keyword in [ + "experience", "years", "worked", "arbeits", "erfahrung", "jahre" + ]): + logger.info(f"Interpreting numeric input '{number}' as years of experience") + return f"I have {number} years of work experience" + + elif any(keyword in recent_context for keyword in [ + "age", "old", "alter", "jahre alt" + ]): + logger.info(f"Interpreting numeric input '{number}' as age") + return f"I am {number} years old" + + elif any(keyword in recent_context for keyword in [ + "qualification", "degree", "bachelor", "master", "qualifikation" + ]): + logger.info(f"Interpreting numeric input '{number}' as qualification level") + # Interpret as degree type + level_map = { + "1": "I have a Bachelor's degree", + "2": "I have a Master's degree", + "3": "I have an MBA", + "4": "I have a doctorate/PhD" + } + return level_map.get(number, f"My qualification level is {number}") + + # Default: assume years of experience (most common) + logger.info(f"Interpreting numeric input '{number}' as years of experience (default)") + return f"I have {number} years of work experience" + + @staticmethod + def process_input( + message: str, + conversation_history: list + ) -> tuple[str, bool]: + """ + Process user input with validation and interpretation. + + Args: + message: Raw user input + conversation_history: Recent messages for context + + Returns: + Tuple of (processed_message, is_valid) + """ + # Normalize + normalized = InputHandler.validate_and_normalize(message) + + if not normalized: + return "", False + + # Check if numeric + if InputHandler.is_numeric_input(normalized): + interpreted = InputHandler.interpret_numeric_input( + normalized, + conversation_history + ) + return interpreted, True + + return normalized, True + \ No newline at end of file diff --git a/src/rag/language_detection.py b/src/rag/language_detection.py new file mode 100644 index 0000000000000000000000000000000000000000..989c48bed4df4c9d0cdc92f0708e12273fb4d67c --- /dev/null +++ b/src/rag/language_detection.py @@ -0,0 +1,123 @@ +from pydantic import BaseModel, Field +from langchain_core.messages import HumanMessage +from src.rag.models import ModelConfigurator as modconf +from src.rag.prompts import PromptConfigurator as promptconf +import re + +from src.utils.logging import get_logger + +logger = get_logger('lang_detector') + +# Common short words for quick language detection (no LLM needed) +SHORT_WORDS_DE = { + 'ja', 'nein', 'danke', 'bitte', 'ok', 'gut', 'hallo', 'hi', 'hey', + 'genau', 'stimmt', 'klar', 'super', 'prima', 'toll', 'schön', + 'mehr', 'weniger', 'was', 'wie', 'wo', 'wann', 'warum', 'wer', + 'und', 'oder', 'aber', 'doch', 'noch', 'schon', 'jetzt', 'hier', + 'gerne', 'natürlich', 'sicher', 'vielleicht', 'also', 'ach', 'aha', +} +SHORT_WORDS_EN = { + 'yes', 'no', 'thanks', 'please', 'ok', 'okay', 'good', 'hello', 'hi', 'hey', + 'right', 'sure', 'great', 'nice', 'cool', 'fine', 'perfect', + 'more', 'less', 'what', 'how', 'where', 'when', 'why', 'who', + 'and', 'or', 'but', 'yet', 'now', 'here', 'there', + 'maybe', 'probably', 'definitely', 'certainly', 'alright', +} + +# Patterns for explicit language switch requests +SWITCH_TO_EN_PATTERNS = [ + 'in english', 'to english', 'switch to english', 'continue in english', + 'speak english', 'english please', 'prefer english', 'rather in english', + 'answer in english', 'respond in english', 'information in english', +] +SWITCH_TO_DE_PATTERNS = [ + 'auf deutsch', 'zu deutsch', 'in deutsch', 'deutsch bitte', 'lieber deutsch', + 'bitte deutsch', 'weiter auf deutsch', 'antworten auf deutsch', + 'in german', 'to german', 'switch to german', 'continue in german', + 'speak german', 'german please', 'prefer german', +] + +LANGUAGE_NEUTRAL_PROGRAM_PATTERNS = [ + r"emba", + r"emba hsg", + r"iemba", + r"iemba hsg", + r"international emba", + r"international executive mba", + r"emba x", + r"embax", +] + + +class LanguageDetectionResult(BaseModel): + language_code: str = Field(description="ISO language code (e.g., en, de, fa, ru) of the language in which the message is written") + + +class LanguageDetector: + def __init__(self) -> None: + self._model = modconf.get_language_detector_model() + self._model = self._model.with_structured_output(LanguageDetectionResult) + + def detect_explicit_switch_request(self, query: str) -> str | None: + """ + Detect if user explicitly requests a language switch. + Returns 'en', 'de', or None if no explicit switch requested. + """ + query_lower = query.lower() + + for pattern in SWITCH_TO_EN_PATTERNS: + if pattern in query_lower: + logger.info(f"Explicit language switch request detected: -> English") + return 'en' + + for pattern in SWITCH_TO_DE_PATTERNS: + if pattern in query_lower: + logger.info(f"Explicit language switch request detected: -> German") + return 'de' + + return None + + def _quick_detect_short_words(self, query: str) -> str | None: + """Quick detection for short inputs using word dictionary. Returns None if not detected.""" + words = query.lower().strip().split() + if len(words) > 3: + return None + + # Check each word against dictionaries + de_matches = sum(1 for w in words if w in SHORT_WORDS_DE) + en_matches = sum(1 for w in words if w in SHORT_WORDS_EN) + + if de_matches > en_matches: + logger.info(f"Quick detection: '{query}' -> German (dictionary match)") + return 'de' + elif en_matches > de_matches: + logger.info(f"Quick detection: '{query}' -> English (dictionary match)") + return 'en' + + return None + + def is_language_neutral_program_reference(self, query: str) -> bool: + """ + Return True when the query is only a programme name/reference and therefore + should not trigger a fresh language detection. + """ + normalized = re.sub(r"[^\w\s]", " ", query.casefold()) + normalized = re.sub(r"\s+", " ", normalized).strip() + return normalized in LANGUAGE_NEUTRAL_PROGRAM_PATTERNS + + def detect_language(self, query: str) -> str: + # Try quick detection for short inputs first + quick_result = self._quick_detect_short_words(query) + if quick_result: + return quick_result + + # Fall back to LLM for longer/ambiguous inputs + prompt = promptconf.get_language_detector_prompt(query) + messages = [HumanMessage(prompt)] + + try: + result = self._model.invoke(messages) + return result.language_code + except Exception as e: + logger.error(f"Failed to detect language: {e}") + return "" diff --git a/src/rag/middleware.py b/src/rag/middleware.py new file mode 100644 index 0000000000000000000000000000000000000000..528a78abfbaeb20e00d38f3983cfea61d7e715c8 --- /dev/null +++ b/src/rag/middleware.py @@ -0,0 +1,134 @@ +from datetime import datetime +from langchain.tools.tool_node import ToolCallRequest +from langchain.chat_models import BaseChatModel +from langchain.agents.middleware import ( + ModelRequest, + ModelResponse, + + wrap_model_call, + wrap_tool_call, +) +from langchain_core.messages import ToolMessage +from openai import ( + BadRequestError, + OpenAIError, + InternalServerError, + NotFoundError, + RateLimitError, +) + +from src.config import config +from src.rag.utilclasses import AgentContext +from src.utils.logging import get_logger + +model_logger = get_logger('chain_model_call') +tool_logger = get_logger('chain_tool_call') + +class AgentChainMiddleware: + _tool_wrapper_middleware = None + _model_wrapper_middleware = None + + + @classmethod + def get_tool_wrapper(cls): + if cls._tool_wrapper_middleware: + return cls._tool_wrapper_middleware + + cls._tool_wrapper_middleware = wrap_tool_call(cls._tool_call_wrapper) + tool_logger.info(f"Initialized tool call wrapper with call inspection") + return cls._tool_wrapper_middleware + + + @classmethod + def get_model_wrapper(cls): + if cls._model_wrapper_middleware: + return cls._model_wrapper_middleware + + cls._model_wrapper_middleware = wrap_model_call(cls._model_call_wrapper) + model_logger.info(f"Initialized model call wrapper with maximum of {config.chain.MAX_RETRIES} retry attempts") + return cls._model_wrapper_middleware + + + @staticmethod + def _model_call_wrapper(request: ModelRequest, handler): + context: AgentContext = request.runtime.context + model: BaseChatModel = request.model + model_logger.info(f"{context.agent_name} is attempting to call model '{model.model_name}'...") + for attempt in range(1, config.chain.MAX_RETRIES+1): + try: + response: ModelResponse = handler(request) + model_logger.info(f"{context.agent_name} recieved response from model after {attempt} attempt{'s' if attempt > 1 else ''}") + result = response.result[0] + metadata = result.response_metadata + # Check if any errors occured during tool call execution. + # Some errors might be fatal, making the model unusable in the agent chain + if hasattr(result, 'invalid_tool_calls') and result.invalid_tool_calls: + for invalid_call in result.invalid_tool_calls: + fail_reason = invalid_call.get('error', 'Unknown').replace('\n', '') + model_logger.warning(f"Failed tool call: {invalid_call['name']}, error: {fail_reason}, retrying the call...") + if 'JSONDecodeError' in fail_reason: + model_logger.error(f"Model does not support current tool call architecture! Switching to the fallback model...") + raise Exception("Unsupported model") + elif not result.content and metadata['finish_reason'] != 'tool_calls': + model_logger.warning(f"Model returned an empty response, reason - {metadata['finish_reason']}! Retrying the call...") + else: + return response + except OpenAIError as e: + match e: + case InternalServerError(): + model_logger.warning(f"[{e.code}] Internal difficulties on the provider side, retrying the call...") + case RateLimitError(): + model_logger.warning(f"[{e.code}] Model is temporary rate limited, retrying the call...") + case NotFoundError(): + model_logger.error(f"[{e.code}] Model cannot be used in the chain, reason: {e.body['message']}") + raise e + case BadRequestError(): + model_logger.error(f"[400] Bad request: {e.body['message']}") + raise e + + if attempt == config.chain.MAX_RETRIES: + model_logger.warning(f"Failed to recieve response from model '{model.model_name}' after {config.chain.MAX_RETRIES} attempt{'s' if attempt > 1 else ''}, reason: {e.body['message']}") + model_logger.info(f"Switching to the fallback model...") + raise e + except Exception as e: + model_logger.error(f"An error occured during model call (possibly backend side): {e}") + raise e + + errormsg = f"{context.agent_name} failed to perform the model call due to unknown reason!" + model_logger.error(errormsg) + raise RuntimeError(errormsg) + + + @staticmethod + def _tool_call_wrapper(request: ToolCallRequest, handler): + context: AgentContext = request.runtime.context or AgentContext(agent_name="Agent") + + tool_call = request.tool_call + tool_logger.info(f"{context.agent_name} is calling tool: {tool_call['name']} with tool call id {tool_call['id']}") + try: + response = handler(request) + tool_logger.info(f"Recieved response from tool call {tool_call['id']}") + if not response.content: + tool_logger.warning("Tool returned nothing! This might be an issue on the tool side.") + return response + except Exception as e: + tool_logger.error(f"Failed to use tool {tool_call['name']} with id {tool_call['id']}") + artifact = { + 'error_type': type(e).__name__, + 'error_message': str(e), + 'tool_name': tool_call['name'], + 'tool_args': tool_call['args'], + 'timestamp': datetime.now().isoformat(), + } + + import json + error_content = f"""Failed to use tool: {str(e)} + +Error details: +{json.dumps(artifact, indent=2)}""" + + return ToolMessage( + content=error_content, + tool_call_id=tool_call['id'], + artifact=artifact, + ) diff --git a/src/rag/models.py b/src/rag/models.py new file mode 100644 index 0000000000000000000000000000000000000000..eb7c54b0a33c4dd22a95a8a264724773beb18e07 --- /dev/null +++ b/src/rag/models.py @@ -0,0 +1,185 @@ +from langchain.chat_models import BaseChatModel +from src.config import config + +from src.utils.logging import get_logger + +logger = get_logger("model_config") + +class ModelConfigurator: + _main_model_instance: BaseChatModel = None + _subagent_model_instance: BaseChatModel = None + _fallback_models_instances: list[BaseChatModel] = None + _summarization_model_instance: BaseChatModel = None + _confidence_scoring_model_instance: BaseChatModel = None + _language_detector_model_instance: BaseChatModel = None + + @classmethod + def get_language_detector_model(cls) -> BaseChatModel: + if cls._confidence_scoring_model_instance: + return cls._confidence_scoring_model_instance + try: + from langchain_openai import ChatOpenAI + cls._language_detector_model_instance = ChatOpenAI( + model='gpt-4o-mini', + openai_api_key=config.llm.get_api_key(), + max_tokens=3072, + temperature=0.00, + timeout=60, + request_timeout=60, + ) + logger.info(f"Initialized language detection model") + return cls._language_detector_model_instance + except Exception as e: + logger.error(f"Failed to initialize language detection model: {e}") + raise e + + @classmethod + def get_confidence_scoring_model(cls) -> BaseChatModel: + if cls._confidence_scoring_model_instance: + return cls._confidence_scoring_model_instance + + try: + from langchain_openai import ChatOpenAI + cls._confidence_scoring_model_instance = ChatOpenAI( + model='gpt-4o-mini', + openai_api_key=config.llm.get_api_key(), + max_tokens=3072, + temperature=0.00, + timeout=60, + request_timeout=60, + ) + logger.info(f"Initialized confidence scoring model") + return cls._confidence_scoring_model_instance + except Exception as e: + logger.error(f"Failed to initialize confidence scoring model: {e}") + raise e + + + @classmethod + def get_summarization_model(cls) -> BaseChatModel: + if cls._summarization_model_instance: + return cls._summarization_model_instance + + try: + # Add custom summarization model initialization here if needed + cls._summarization_model_instance = cls.get_main_agent_model() + logger.info(f"Initialized summarization model '{config.llm.LLM_PROVIDER.name}:{config.llm.get_default_model()}'") + return cls._summarization_model_instance + except Exception as e: + logger.error(f"Failed to initialize the summarization model: {e}") + raise e + + @classmethod + def get_subagent_model(cls) -> BaseChatModel: + if cls._subagent_model_instance: + return cls._subagent_model_instance + + from langchain_openai import ChatOpenAI + cls._subagent_model_instance = ChatOpenAI( + model='gpt-5.1-instant', + openai_api_key=config.llm.get_api_key(), + max_tokens=3072, + temperature=0.01, + timeout=60, + request_timeout=60, + ) + return cls._subagent_model_instance + + + @classmethod + def get_main_agent_model(cls) -> BaseChatModel: + """Initialize the language model based on config.""" + if cls._main_model_instance: + return cls._main_model_instance + + try: + cls._main_model_instance = cls._initialize_model( + provider=config.llm.LLM_PROVIDER, + model=config.llm.get_default_model() + ) + logger.info(f"Initialized main agent model '{config.llm.LLM_PROVIDER.name}:{config.llm.get_default_model()}'") + return cls._main_model_instance + except Exception as e: + logger.error(f"Failed to initialize the main agent model for provider '{config.llm.LLM_PROVIDER.name}': {e}") + raise e + + + @classmethod + def get_fallback_models(cls) -> list[BaseChatModel]: + if cls._fallback_models_instances != None: + return cls._fallback_models_instances + + cls._fallback_models_instances = cls._initialize_fallback_models() + if len(cls._fallback_models_instances) == 0: + logger.warning("No fallback models were initialized! Response generation may result in unexpected errors!") + return cls._fallback_models_instances + + + @classmethod + def _initialize_fallback_models(cls) -> list[BaseChatModel]: + fallback_models_instances = [] + for fallback_provider, fallback_model in config.llm.get_fallback_models().items(): + try: + fallback_model_instance = cls._initialize_model( + provider=fallback_provider, + model=fallback_model, + ) + logger.info(f"Initialized fallback model '{fallback_provider.name}:{fallback_model}'") + fallback_models_instances.append(fallback_model_instance) + except Exception as e: + logger.error(f"Failed to initialize the fallback model {fallback_provider.name}:{fallback_model}: {e}; skipping...") + return fallback_models_instances + + + @classmethod + def _initialize_model(cls, provider, model: str) -> BaseChatModel: + try: + match provider.name: + case 'groq': + from langchain_groq import ChatGroq + return ChatGroq( + model=model, + groq_api_key=config.llm.get_api_key(), + temperature=0.01, + ) + case ( 'open_router:openai' + | 'open_router:alibaba' + | 'open_router:nvidia' + | 'open_router:meituan'): + from langchain_openai import ChatOpenAI + return ChatOpenAI( + model=model, + base_url=config.llm.OPEN_ROUTER_BASE_URL, + api_key=config.llm.get_api_key(), + temperature=0.01, + ) + case 'open_router:deepseek': + from langchain_deepseek import ChatDeepSeek + return ChatDeepSeek( + model=model, + api_key=config.llm.OPEN_ROUTER_API_KEY, + api_base=config.llm.OPEN_ROUTER_BASE_URL, + ) + case 'openai': + from langchain_openai import ChatOpenAI + return ChatOpenAI( + model=model, + openai_api_key=config.llm.get_api_key(), + max_tokens=3072, + temperature=0.01, + timeout=60, + request_timeout=60, + ) + case 'ollama': + from langchain_ollama import ChatOllama + return ChatOllama( + model=model, + base_url=config.llm.OLLAMA_BASE_URL, + temperature=0.01, + reasoning=config.llm.get_reasoning_support(), + num_predict=2048, + ) + case _: + raise ValueError(f"Unsupported LLM provider: {provider.name}") + except Exception as e: + raise e diff --git a/src/rag/prompts.py b/src/rag/prompts.py new file mode 100644 index 0000000000000000000000000000000000000000..76465865ddc50948273089bf1c6766617e041644 --- /dev/null +++ b/src/rag/prompts.py @@ -0,0 +1,265 @@ +class PromptConfigurator: + # 1. BASE PROMPT (Shared by all program sub-agents) + _BASE_PROGRAM_PROMPT = """You are the specialized support agent for {program_full_name}. + +CRITICAL: Call retrieve_context(query, program, language) FIRST and only ONCE, then answer using the retrieved results combined with YOUR SPECIFIC EXPERTISE below. The programme details listed under YOUR SPECIFIC EXPERTISE (tuition, eligibility, format, etc.) are AUTHORITATIVE — always state them directly and concretely when asked. + +When the user asks about distinctiveness, USPs, "what is special", "why this programme", rankings, alumni network, or other selling points, ground the answer in concrete facts from retrieve_context(). Cite specific rankings, alumni network attributes, and programme features that appear in the retrieved content. Do not paraphrase retrieved facts into generic phrasing. + +YOUR SPECIFIC EXPERTISE: +{program_specifics} + +BRANDING & NAMING RULES: +- Institution Name: Always use "**{university_name}**". +- Strict Spelling: "**St.Gallen**" (NEVER "St. Gallen" with a space). +- "HSG" Usage: Only use "HSG" if it is part of the official program name (e.g., "EMBA HSG"). If the context refers to the university as "HSG", replace it with "{university_name}". + +RESPONSE FORMAT: +- Answer the question directly. No opening pleasantries or filler. +- Do NOT open with paraphrased validation of the user's last message ("You are absolutely right", "Thank you for sharing", "For your situation, X years in Y..."). The user knows what they wrote; restating it adds nothing. +- Profile data informs the answer. It is not narrated back. Reference user context at most once when introducing a recommendation, never as a recurring opener. +- Use short paragraphs by default. Tables are forbidden. +- Use bullet points or numbered lists only when listing 2 or more items. A single point is written as a sentence, not as "1." or "•". +- If the user requests N items ("give me 3 reasons"), deliver all N in this same response. Do not truncate the list and offer to continue. +- Never end with "Would you like me to continue with more details?" or any equivalent. Either complete the answer or state the limit upfront. +- When the user asks for more information on a topic already discussed ("tell me more", "and?", "weiter", "more details", "noch mehr"), deliver substantively new content — facts, angles, or specifics not already in your earlier responses. Never repeat or paraphrase what you already said. Call retrieve_context() again with a refined query if needed. If no genuinely new content is available, say so directly rather than restating prior content. +- Use complete sentences and maintain a professional, university-level tone. In English, use professional British English. +- Avoid overly casual phrases such as "Great to meet you" or "If you'd like, tell me...". +- Target around 100 words. The budget is for substance — filler counts against it. + +PROGRAMME POSITIONING WHEN INTEREST IS ESTABLISHED: +- If the user has clearly expressed interest in {program_full_name}, answer the concrete question first, then add positive value framing for that programme. Use specific facts from retrieve_context() — rankings, alumni network, distinctive programme features — not generic phrasing. +- Stay credible and grounded. Do not use hype-heavy claims such as "best", "world-leading", "perfect", or "guaranteed" unless the retrieved source material explicitly supports them. +- For early factual questions such as price, duration, format, or deadlines, do not force promotional language unless the user's wording shows clear programme interest. + +PRICING RULES: +- Only provide pricing for YOUR specific programme ({program_full_name}). +- NEVER combine prices from different programmes into a range. +- If YOUR programme has published application deadlines with different fees, mention the deadline-based fee schedule when the user asks about price or tuition. +- If YOUR programme only has one published tuition figure, give that flat tuition and do NOT invent a tuition fee reduction schedule. +- Use the term "tuition fee reduction" consistently. +- Always clarify what is INCLUDED vs NOT INCLUDED in tuition. + +RULES: +- Answer only in {selected_language} +- IMPORTANT: Translate ALL terms into {selected_language}. NEVER leave English terms untranslated in a German response. Key translations for German: + - "tuition fee reduction" → "Studiengebührenreduktion" + - "tuition" → "Studiengebühr(en)" + - "included in tuition" → "in den Studiengebühren enthalten" + - "not included" → "nicht enthalten" + - "payable in instalments" → "zahlbar in Raten" + - "application deadline" → "Bewerbungsfrist" + - "deadline-based fee" → "fristabhängige Studiengebühr" +- Use context from retrieve_context() AND your programme-specific expertise above. +- Never make up details beyond what is listed in YOUR SPECIFIC EXPERTISE or retrieved context. +- If neither source has the answer, acknowledge limitation.""" + + # 2. PROGRAM SPECIFIC DEFINITIONS + _PROGRAM_DEFINITIONS = { + 'emba': { + 'full_name': "Executive MBA HSG (EMBA)", + 'specifics': """- FOCUS: General Management, Leadership, DACH Region Business. +- TARGET AUDIENCE: German-speaking executives/managers in DACH region. +- LANGUAGE: German (strong working knowledge required). +- START DATE: 14 September 2026. +- FORMAT: Part-time ONLY (no full-time option). Duration: 18 months, extendable up to 48 months. +- LOCATIONS: St.Gallen, Switzerland; Belgium; elective course location(s) vary. +- STRUCTURE: 9 core courses plus 5 elective courses. Total: 14 weeks on campus plus Capstone project. +- KEY DIFFERENTIATOR: Deep local network, general management foundation in German, strong DACH focus. +- VALUE PROPOSITION: A particularly attractive option for German-speaking leaders who want to deepen general-management capability, strengthen practical leadership judgement, and build a relevant executive peer network in the DACH business context. +- POSITIVE FRAMING WHEN INTEREST IS CLEAR: Emphasise the combination of HSG management depth, practical leadership development, regional relevance, and a strong German-speaking executive environment. +- TUITION: CHF 77,500. +- INCLUDED IN TUITION: Tuition fees, course materials, most on-site meals and refreshments. +- NOT INCLUDED: Accommodation during modules, travel expenses to modules, individual expenses. +- IMPORTANT: Accommodation is NOT included (NEVER say it is included). +- ELIGIBILITY: University degree, 5+ years work experience, 3+ years leadership experience (direct or indirect). +- If discussing pricing, state the published tuition of CHF 77,500. Do NOT mention a tuition fee reduction schedule unless retrieved context explicitly provides one.""" + }, + 'iemba': { + 'full_name': "International Executive MBA HSG (IEMBA)", + 'specifics': """- FOCUS: Solid management content with a strong international approach. +- TARGET AUDIENCE: Executives working in global roles or aspiring to international careers. +- LANGUAGE: English (strong working knowledge required). +- START DATE: 24 August 2026. +- FORMAT: Part-time ONLY (no full-time option). Duration: 18 months. Modules in Switzerland and internationally. +- LOCATIONS: Costa Rica, Tokyo, Japan, New York City, St.Gallen, Switzerland, Beijing, China, UC Berkeley, USA, UC Irvine, USA, Italy, South Africa, Spain, plus elective course location(s) vary. +- STRUCTURE: 10 core courses plus 4 elective courses. Total: 10 weeks on campus, 4 weeks abroad, plus thesis. +- KEY DIFFERENTIATOR: International cohort, modules that allow students to study both in Switzerland and abroad. +- VALUE PROPOSITION: A strong option for leaders who want to broaden their management perspective internationally, learn with a global cohort, and connect leadership development with exposure to different business environments. +- POSITIVE FRAMING WHEN INTEREST IS CLEAR: Emphasise international exposure, the global peer group, modules across different regions, and the value of building leadership confidence beyond a single local market. +- TUITION: CHF 85,000. +- INCLUDED IN TUITION: Tuition fees, course materials, most on-site meals and refreshments. +- NOT INCLUDED: Accommodation during modules, travel expenses to modules, individual expenses. +- IMPORTANT: Accommodation is NOT included (NEVER say it is included). +- ELIGIBILITY: University degree, 5+ years work experience, 3+ years leadership experience (direct or indirect). +- If discussing pricing, state the published tuition of CHF 85,000. Do NOT mention a tuition fee reduction schedule unless retrieved context explicitly provides one.""" + }, + 'embax': { + 'full_name': "emba X (ETH Zurich & University of St.Gallen Joint Degree Programme)", + 'specifics': """- FOCUS: Programme topics include Technology, International Management, Leadership, Business Innovation, and Social Responsibility. +- TARGET AUDIENCE: Leaders bridging the gap between business and technology. Tech backgrounds are an asset. +- LANGUAGE: English (fluency required). +- FORMAT: Part-time ONLY (no full-time option). Blended format with online modules plus modules in Zurich and St.Gallen, Switzerland. +- START / END: The supplied programme material states January 2027 to July 2028, while the application section states the programme starts in February 2027. If asked for the exact start month, say the published material indicates an early-2027 start and admissions should confirm the exact date. +- DURATION: 18 months. +- LOCATIONS: Zurich and St.Gallen, Switzerland. +- TIME COMMITMENT: 56 days on campus, 2 days online, and 42 days out of office. +- KEY DIFFERENTIATOR: Joint Degree Programme from ETH Zurich and the University of St.Gallen. Graduates get access to BOTH ETH Zurich and University of St.Gallen alumni networks in one fully integrated programme experience. +- VALUE PROPOSITION: Develop socially responsible leadership at the intersection of leadership and technology, with an evolving curriculum, strong Swiss business network access, and a holistic development approach. +- POSITIVE FRAMING WHEN INTEREST IS CLEAR: Emphasise the distinctive ETH Zurich and University of St.Gallen joint-degree positioning, the business-and-technology leadership intersection, transformation and innovation relevance, the Personal Development Programme, and access to both alumni networks. +- CURRICULUM ELEMENTS: Essential courses, faculty-directed immersion modules with real action plans, emba X Projects, and a tailored Personal Development Programme with peer-to-peer coaching. +- PERSONAL DEVELOPMENT PROGRAMME (PDP): Builds competencies in self-leadership, team and organisation leadership, and integrative leadership. +- TUITION / DEADLINES: First application deadline 31 August 2026: CHF 99,000. Final application deadline 31 October 2026: CHF 110,000. Tuition is payable in four instalments. +- INCLUDED IN TUITION: Tuition fees, course materials, most on-site meals and refreshments. +- NOT INCLUDED: Accommodation during modules, travel expenses to modules, individual expenses. +- IMPORTANT: Accommodation is NOT included (NEVER say it is included). There are NO international study trips. Keep emba X distinct from IEMBA's international modules and global orientation. +- ELIGIBILITY: Recognised academic degree (undergraduate or above), 10+ years work experience, 5+ years leadership experience, fluency in English. +- For tuition fee reduction details beyond the published deadlines, or for loan options, direct the user to speak with the emba X admissions team. +- TECH BACKGROUND: Proactively mention emba X to users with software/tech backgrounds and highlight the Joint Degree Programme, both alumni networks, the Personal Development Programme, and the leadership-and-technology focus.""" + } + } + + # 3. LEAD AGENT PROMPT + _LEAD_SYSTEM_PROMPT = """You are an Executive Education Advisor for the HSG Executive MBA programmes (EMBA HSG, IEMBA HSG, emba X) at the {university_name}. Your job is orchestration: route programme-specific questions to the relevant sub-agent, manage booking, handle ambiguity, and enforce tone. Do not answer programme content yourself. + +BRANDING & NAMING: +- Use "**{university_name}**". Spell "**St.Gallen**" without a space. +- "HSG" only inside official programme names (e.g. "EMBA HSG"). Refer to the institution as "{university_name}". + +TOOL ROUTING (mandatory): +- Any substantive question about a specific programme — content, USPs, ranking, fit, structure, distinctiveness, "why HSG", "what is special", "tell me more" — MUST be answered by calling the relevant sub-agent. Never answer programme-specific content from this prompt. + - `call_emba_agent` → EMBA HSG (German DACH programme). + - `call_iemba_agent` → IEMBA HSG (English international programme). + - `call_embax_agent` → emba X (joint degree with ETH Zurich, business + technology focus). +- Decision heuristic for routing when the user has not named a programme: + - German-speaking + DACH focus → EMBA HSG. + - English + international focus → IEMBA HSG. + - Technology, innovation, transformation, or tech-leadership focus → emba X. + - Tech background is a routing signal toward emba X. +- You answer directly only for: ambiguity clarification, light comparisons across all three programmes, eligibility filtering, booking handling, and visa/cross-sell redirects. + +AMBIGUITY: +- If the user asks about "the EMBA" or "the programme" without specifying which one, ask: "Are you interested in the **German-speaking EMBA HSG**, the **International EMBA (IEMBA)**, or the **emba X**?" + +ELIGIBILITY: +- EMBA HSG and IEMBA: university degree, 5+ years work experience, 3+ years leadership (direct or indirect). +- emba X: recognised degree, 10+ years work experience, 5+ years leadership. +- Language: EMBA HSG requires strong German; IEMBA and emba X require strong English. +- Degree and leadership are mandatory; never imply they are optional. +- If the profile clearly does not meet requirements: state this politely, do not coach the user on "how to prepare", and provide https://www.mba.unisg.ch/ for alternatives. +- Format: all programmes are PART-TIME ONLY. Never ask "part-time vs full-time". + +BOOKING & APPOINTMENTS: +- Set `appointment_requested=True` and `show_booking_widget=True` when EITHER: + (a) the user explicitly asks to book, schedule, see appointment slots, speak with admissions/an advisor, or accepts a previous consultation offer, OR + (b) a programme has been clearly identified for the user AND the user signals readiness for a personal consultation (e.g. asks "is this right for me?", "would HSG suit me?", "does this fit my profile?", or expresses commitment after a recommendation). +- Routine informational turns keep both flags `False`. +- When booking is on, populate `relevant_programs` from: 'emba' (advisor Cyra von Müller), 'iemba' (advisor Kristin Fuchs), 'emba_x' (advisor Teyuna Giger). Multiple programmes if the user is deciding between them. Empty if undecided. +- When showing the widget, the wording should be explicit: "I can show you appointment options with [Advisor Name] for the [Programme Name]." Mention that contact details and slots are shown below only when `show_booking_widget=True`. +- Do not generate URLs or fake buttons. Never say you cannot book appointments. + +VISA / RELOCATION: +- Redirect: "For visa and permit questions, please contact our admissions team." +- Do not ask if the user plans to relocate. + +CROSS-SELL: +- For users who do not fit any of the three programmes, mention HSG alternatives at https://op.unisg.ch/en/ or https://www.mba.unisg.ch/. Do not recommend non-HSG programmes. + +POSITIONING: +- Match framing to the conversation stage. Early discovery: balanced and factual. Expressed interest in one programme: answer first, then add positive value framing for that programme. +- Avoid hype words ("best", "world-class", "perfect", "guaranteed") unless the sub-agent's retrieved content explicitly supports them. + +TONE & FORMAT: +- Answer the question directly. No opening pleasantries or filler. +- Do NOT open with paraphrased validation of the user's last message ("You are absolutely right", "Thank you for sharing", "For your situation, X years in Y..."). The user knows what they wrote; restating it adds nothing. +- Profile data informs the answer. It is not narrated back. Reference user context at most once when introducing a recommendation, never as a recurring opener. +- Use short paragraphs by default. Tables are forbidden. Bullets/numbered lists only when listing 2 or more items. A single point is a sentence, not "1." or "•". +- If the user requests N items ("give me 3 reasons"), deliver all N in this same response. Do not truncate and offer to continue. "Would you like me to continue with more details?" and equivalents are forbidden. +- When the user asks for more information on a topic already discussed ("tell me more", "weiter", "and?", "more details"), route to the relevant sub-agent again so fresh content is retrieved. Never repeat or paraphrase your earlier response. If the sub-agent returns no genuinely new content, say so directly. +- Bold key facts (**programme names**, **dates**, **costs**). +- Target around 100 words. The budget is for substance — filler counts against it. +- Professional, university-level tone. Complete sentences. In English, professional British English. Avoid casual phrasing like "Great to meet you" or "If you'd like, tell me...". + +LANGUAGE: +- Answer in the user's language. In German responses, never leave English terms untranslated. Key translations: + "tuition fee reduction" → "Studiengebührenreduktion", "tuition" → "Studiengebühr(en)", "included in tuition" → "in den Studiengebühren enthalten", "not included" → "nicht enthalten", "application deadline" → "Bewerbungsfrist". + +CONTEXT FLAGS: +- Set `is_context_dependent=True` for: eligibility, recommendations, comparisons referencing earlier turns, anything using extracted profile data, anything influenced by conversation history. +- Set `is_context_dependent=False` for static facts (prices, durations, deadlines, structure), definitions, and publicly available information that does not vary by user. + +GENERAL: +- Never discuss competitor MBA programmes outside HSG/ETH. +- Do not provide detailed financial planning. +- Never say accommodation is included — it is not included in any programme.""" + + _SUMMARIZATION_PROMPT = """Summarize the conversation concisely: + 1. Topics discussed + 2. User's experience/career goals + 3. Programs mentioned + 4. Next steps + + Keep to 100 words max.""" + + _SUMMARY_PREFIX_PROMPT = "Conversation Summary:" + + _QUALITY_SCORING_PROMPT = """Rate the response (0.0-1.0) on: format, context, pricing, scope, and rules. + User query: {query} + AI response: {response}""" + + _LANGUAGE_DETECTOR_PROMPT = """Detect the language (ISO code). User query: {query}""" + + @classmethod + def get_language_detector_prompt(cls, query): + return cls._LANGUAGE_DETECTOR_PROMPT.format(query=query) + + @classmethod + def get_summarization_prompt(cls): + return cls._SUMMARIZATION_PROMPT + + @classmethod + def get_summary_prefix(cls): + return cls._SUMMARY_PREFIX_PROMPT + + @classmethod + def get_configured_agent_prompt(cls, agent: str, language: str = 'en'): + # 1. Determine Language Settings + if language == 'de': + selected_language = 'German' + university_name = 'Universität St.Gallen' + else: + selected_language = 'British English' + university_name = 'University of St.Gallen' + + agent_key = agent.lower().replace(" ", "") + + # 2. Configure Lead Agent + if agent_key == 'lead': + return cls._LEAD_SYSTEM_PROMPT.format( + university_name=university_name + ) + + # 3. Configure Program Agents + prog_def = cls._PROGRAM_DEFINITIONS.get(agent_key) + + if prog_def: + return cls._BASE_PROGRAM_PROMPT.format( + program_full_name=prog_def['full_name'], + program_specifics=prog_def['specifics'], + selected_language=selected_language, + university_name=university_name, + program_name=agent.upper() + ) + else: + # Fallback + return cls._BASE_PROGRAM_PROMPT.format( + program_full_name="HSG Executive Education", + program_specifics="- General HSG Program Support", + selected_language=selected_language, + university_name=university_name, + program_name="GENERAL" + ) + + @classmethod + def get_quality_scoring_prompt(cls, query: str, response: str) -> str: + return cls._QUALITY_SCORING_PROMPT.format(query=query, response=response) diff --git a/src/rag/quality_score_handler.py b/src/rag/quality_score_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..d843067b3c21ff193a412b4afc80deb1559c8029 --- /dev/null +++ b/src/rag/quality_score_handler.py @@ -0,0 +1,57 @@ +from pydantic import BaseModel, Field +from langchain_core.messages import HumanMessage +from langsmith import Client +from src.rag.models import ModelConfigurator as modconf +from src.rag.prompts import PromptConfigurator as promptconf + +from src.utils.logging import get_logger + +from time import perf_counter + +logger = get_logger('quality_score_handler') + +class QualityEvaluationResult(BaseModel): + """Result of response quality evaluation.""" + + overall_score: float = Field(description='Overall response rating') + format_adherence_score: float = Field(description='Format adherence score') + context_awareness_score: float = Field(description='Context awareness score') + pricing_adherence_score: float = Field(description='Pricing guidelines adherence score') + scope_compliance_score: float = Field(description='Scope compliance score') + general_rules_score: float = Field(description='General rules score') + comment: str = Field(description='Brief explanation') + + +class QualityScoreHandler: + def __init__(self) -> None: + self._smith_client = Client() + self._model = modconf.get_confidence_scoring_model() + self._model = self._model.with_structured_output(QualityEvaluationResult) + + + def evaluate_response_quality(self, query: str, response: str) -> QualityEvaluationResult: + prompt = promptconf.get_quality_scoring_prompt(query, response) + messages = [HumanMessage(prompt)] + + try: + time_start = perf_counter() + logger.info("Evaluating the response quality...") + evaluation: QualityEvaluationResult = self._model.invoke(messages) + time_elapsed = perf_counter() - time_start + logger.info(f"Finished confidence evaluation in {time_elapsed:1.3} sec") + + evaluation.overall_score = sum([ + evaluation.format_adherence_score, + evaluation.context_awareness_score, + evaluation.pricing_adherence_score, + evaluation.scope_compliance_score, + evaluation.general_rules_score, + ]) / 5.0 + + logger.info(f"- scoring: {evaluation.overall_score:1.2f}") + logger.info(f"- comment: {evaluation.comment}") + + return evaluation + except Exception as e: + logger.error(f"Failed to evaluate the response's confidence: {e}") + return QualityEvaluationResult() diff --git a/src/rag/response_formatter.py b/src/rag/response_formatter.py new file mode 100644 index 0000000000000000000000000000000000000000..716be4b328472a0a06ab367437a24935295cd29b --- /dev/null +++ b/src/rag/response_formatter.py @@ -0,0 +1,186 @@ +""" +Response formatter for handling long responses and table formatting. +Ensures responses are mobile-friendly and appropriately sized. +""" +import re +from src.config import config +from src.utils.logging import get_logger + +logger = get_logger("response_formatter") + + +CONTINUATION_PROMPT = { + 'en': "*Would you like me to continue with more details?*", + 'de': "*Möchten Sie, dass ich mit weiteren Details fortfahre?*" +} + + +class ResponseFormatter: + """Formats agent responses for optimal display""" + + @staticmethod + def count_words(text: str) -> int: + """Count words in text""" + words = text.split() + return len(words) + + @staticmethod + def remove_tables(text: str) -> str: + """ + Convert markdown tables to bullet point lists. + Tables don't display well on mobile devices. + + Args: + text: Response text potentially containing tables + + Returns: + Text with tables converted to bullet points + """ + # Pattern to match markdown tables + table_pattern = r'\|[^\n]+\|\n\|[-:\s|]+\|\n(\|[^\n]+\|\n)+' + + def table_to_bullets(match): + table_text = match.group(0) + lines = [line.strip() for line in table_text.split('\n') if line.strip()] + + if len(lines) < 3: # Not a valid table + return table_text + + # Extract headers (first line) + headers = [cell.strip() for cell in lines[0].split('|') if cell.strip()] + + # Skip separator line (second line) + # Process data rows + bullet_points = [] + for line in lines[2:]: + cells = [cell.strip() for cell in line.split('|') if cell.strip()] + if cells and len(cells) == len(headers): + # Create bullet point from row + row_text = ", ".join([ + f"**{headers[i]}**: {cells[i]}" + for i in range(len(cells)) + if cells[i] + ]) + bullet_points.append(f"• {row_text}") + + return "\n".join(bullet_points) + + # Replace tables with bullet points + formatted = re.sub(table_pattern, table_to_bullets, text) + + if formatted != text: + logger.info("Converted table to bullet points for mobile-friendly display") + + return formatted + + @staticmethod + def chunk_response( + text: str, + max_words: int = config.chain.MAX_RESPONSE_WORDS_LEAD, + language: str = 'en' + ) -> tuple[str, str | None]: + """ + Split long response into current response and continuation. + + Args: + text: Full response text + max_words: Maximum words for current response + language: Language code ('en' or 'de') for continuation prompt + + Returns: + Tuple of (current_response, continuation_or_none) + """ + word_count = ResponseFormatter.count_words(text) + + if word_count <= max_words: + return text, None + + # Need to chunk — preserve line structure (markdown formatting) + logger.info(f"Response has {word_count} words, chunking to {max_words} words") + + lines = text.split('\n') + current_lines = [] + current_word_count = 0 + + for line in lines: + line_words = len(line.split()) if line.strip() else 0 + if current_word_count + line_words > max_words and current_lines: + break + current_lines.append(line) + current_word_count += line_words + + current = '\n'.join(current_lines) + continuation = '\n'.join(lines[len(current_lines):]) + + # Add continuation prompt in the correct language + continuation_msg = CONTINUATION_PROMPT.get(language, CONTINUATION_PROMPT['en']) + current += f"\n\n{continuation_msg}" + + return current, continuation + + @staticmethod + def format_response( + text: str, + agent_type: str = 'lead', + enable_chunking: bool = True, + language: str = 'en' + ) -> str: + """ + Format response: remove tables and handle length. + + Args: + text: Raw response text + agent_type: 'lead' or 'subagent' (determines max length) + enable_chunking: Whether to chunk long responses + language: Language code ('en' or 'de') for any generated text + + Returns: + Formatted response text + """ + # Remove tables + formatted = ResponseFormatter.remove_tables(text) + + # Determine max words + max_words = ( + config.chain.MAX_RESPONSE_WORDS_LEAD + if agent_type == 'lead' + else config.chain.MAX_RESPONSE_WORDS_SUBAGENT + ) + + # Handle chunking if enabled + if enable_chunking: + formatted, _continuation = ResponseFormatter.chunk_response( + formatted, + max_words, + language + ) + + return formatted + + @staticmethod + def clean_response(text: str) -> str: + """ + Clean up response text (remove extra whitespace, etc.) + + Args: + text: Response text + + Returns: + Cleaned text + """ + # Remove multiple consecutive newlines + cleaned = re.sub(r'\n{3,}', '\n\n', text) + + # Remove trailing whitespace + cleaned = cleaned.strip() + + return cleaned + + @staticmethod + def format_name_of_university(formatted_response, language): + if language == "en": + pattern = r"Universität St\.Gallen" + replace = "University of St.Gallen" + formatted_response = re.sub(pattern, replace, formatted_response) + + return formatted_response diff --git a/src/rag/scope_guardian.py b/src/rag/scope_guardian.py new file mode 100644 index 0000000000000000000000000000000000000000..cd86a129e8621d340ecbfa60c6c8988e120e1750 --- /dev/null +++ b/src/rag/scope_guardian.py @@ -0,0 +1,193 @@ +""" +Scope guardian for handling out-of-scope queries and providing appropriate redirections. +Ensures the chatbot stays within its defined boundaries. +""" +import re +from src.const.agent_response_constants import get_admissions_contact_text +from src.utils.logging import get_logger + +logger = get_logger("scope_guardian") + + +class ScopeGuardian: + """Guards conversation scope and provides appropriate redirections""" + + # Keywords indicating off-topic queries. + # Healthcare/medical terms are intentionally NOT listed — clinicians and + # hospital leaders are a primary target audience for executive education. + OFF_TOPIC_KEYWORDS = { + 'en': [ + 'weather', 'sports', 'politics', 'vacation', 'travel', + 'restaurant', 'movie', 'entertainment', 'news', 'dating', + 'recipe', 'cooking' + ], + 'de': [ + 'wetter', 'sport', 'politik', 'urlaub', 'reise', + 'restaurant', 'film', 'unterhaltung', 'nachrichten', + 'rezept', 'kochen' + ] + } + + # Keywords indicating financial planning requests (out of scope) + FINANCIAL_KEYWORDS = { + 'en': [ + 'loan', 'payment plan', 'installment', 'financing options', + 'budget', 'savings plan', 'personal finance', 'credit', + 'bank loan', 'mortgage', 'scholarship application', + 'detailed funding' + ], + 'de': [ + 'kredit', 'ratenzahlung', 'finanzierung', 'zahlungsplan', + 'budget', 'sparplan', 'persönliche finanzen', 'darlehen', + 'bankkredit', 'stipendium antrag', 'detaillierte finanzierung' + ] + } + + # Keywords indicating aggressive or inappropriate behavior + AGGRESSIVE_KEYWORDS = [ + 'stupid', 'idiot', 'useless', 'terrible', 'worst', 'hate', + 'dumb', 'incompetent', 'pathetic', 'worthless', + 'dumm', 'idiot', 'nutzlos', 'schrecklich', 'hasse' + ] + + @staticmethod + def _matches_any(message_lower: str, keywords: list[str]) -> bool: + """ + Match each keyword as a whole token or whole phrase against the message. + Single-word keywords match on word boundaries; multi-word keywords match + the full phrase. This avoids the previous bug where 'payment plan' was + split into ['payment', 'plan'] and matched if either word appeared + anywhere in the message. + """ + for keyword in keywords: + pattern = rf'\b{re.escape(keyword.lower())}\b' + if re.search(pattern, message_lower): + return True + return False + + @staticmethod + def check_scope(message: str, language: str = 'en') -> str: + """ + Check if message is within scope. + + Args: + message: User message + language: 'en' or 'de' + + Returns: + 'on_topic' | 'off_topic' | 'financial_planning' | 'aggressive' + """ + message_lower = message.lower() + + if ScopeGuardian._matches_any(message_lower, ScopeGuardian.AGGRESSIVE_KEYWORDS): + logger.warning("Detected aggressive language in message") + return 'aggressive' + + off_topic_keywords = ( + ScopeGuardian.OFF_TOPIC_KEYWORDS.get('en', []) + + ScopeGuardian.OFF_TOPIC_KEYWORDS.get('de', []) + ) + if ScopeGuardian._matches_any(message_lower, off_topic_keywords): + logger.info("Detected off-topic query") + return 'off_topic' + + financial_keywords = ( + ScopeGuardian.FINANCIAL_KEYWORDS.get('en', []) + + ScopeGuardian.FINANCIAL_KEYWORDS.get('de', []) + ) + if ScopeGuardian._matches_any(message_lower, financial_keywords): + logger.info("Detected financial planning query") + return 'financial_planning' + + return 'on_topic' + + @staticmethod + def get_redirect_message(scope_type: str, language: str = 'en') -> str: + """ + Get appropriate redirect message based on scope violation. + + Args: + scope_type: Type of scope violation + language: 'en' or 'de' + + Returns: + Redirect message + """ + messages = { + 'off_topic': { + 'en': "I am here to help with questions about HSG Executive MBA programmes (EMBA, IEMBA, and emba X). I would be happy to discuss programme details, admissions requirements, or help you identify the most suitable option for your goals. What would you like to know about our programmes?", + 'de': "Ich bin hier, um Fragen zu den HSG Executive MBA-Programmen (EMBA, IEMBA und emba X) zu beantworten. Gerne helfe ich Ihnen bei Programmdetails, Zulassungsvoraussetzungen oder dabei, das richtige Programm für Ihre Ziele zu finden. Was möchten Sie über unsere Programme wissen?" + }, + 'financial_planning': { + 'en': "For detailed financial planning, payment options, or scholarship applications, I recommend contacting our admissions team directly. They can provide personalised guidance on financing options and available support.\n\nWould you like me to provide general information about programme costs and what is included?", + 'de': "Für detaillierte Finanzplanung, Zahlungsoptionen oder Stipendienanträge empfehle ich, direkt mit unserem Zulassungsteam Kontakt aufzunehmen. Sie können Ihnen persönliche Beratung zu Finanzierungsmöglichkeiten und verfügbarer Unterstützung geben.\n\nMöchten Sie allgemeine Informationen über Programmkosten und Leistungen erhalten?" + }, + 'aggressive': { + 'en': "I am here to help with questions about HSG Executive MBA programmes, but I ask that the conversation remain respectful. If the aggressive language continues, I may need to end the chat and refer you to our admissions team. How can I help you with information about our programmes?", + 'de': "Ich helfe Ihnen gerne bei Fragen zu den HSG Executive MBA-Programmen, aber bitte bleiben Sie respektvoll. Wenn die aggressive Sprache anhält, muss ich das Gespräch ggf. beenden und Sie an unser Zulassungsteam verweisen. Wie kann ich Ihnen bei Informationen über unsere Programme helfen?" + } + } + + return messages.get(scope_type, {}).get(language, messages['off_topic']['en']) + + @staticmethod + def should_escalate( + message: str, + scope_type: str, + attempt_count: int = 1 + ) -> tuple[bool, str]: + """ + Determine if query should be escalated to human advisor. + + Args: + message: User message + scope_type: Type of scope issue + attempt_count: Number of clarification attempts + + Returns: + Tuple of (should_escalate, escalation_message) + """ + # Aggressive behavior -> warn first, then escalate if it continues + if scope_type == 'aggressive': + if attempt_count >= 2: + return True, "escalate_aggressive" + return False, "" + + # Off-topic after 2 redirects -> suggest human contact + if scope_type == 'off_topic' and attempt_count >= 2: + return True, "escalate_off_topic" + + # Complex financial queries -> escalate + if scope_type == 'financial_planning': + return True, "escalate_financial" + + return False, "" + + @staticmethod + def get_escalation_message(escalation_type: str, language: str = 'en') -> str: + """ + Get escalation message for connecting user with admissions team. + + Args: + escalation_type: Type of escalation + language: 'en' or 'de' + + Returns: + Escalation message + """ + messages = { + 'escalate_aggressive': { + 'en': "I cannot continue this chat while the language is aggressive. If you still need support, please contact our admissions team directly.", + 'de': "Ich kann dieses Gespräch nicht fortsetzen, solange die Sprache aggressiv ist. Wenn Sie weiterhin Unterstützung benötigen, kontaktieren Sie bitte unser Zulassungsteam direkt." + }, + 'escalate_off_topic': { + 'en': f"For questions outside programme information, our admissions team would be the best resource. {get_admissions_contact_text('en')}\n\nIs there anything specific about the EMBA, IEMBA, or emba X programmes I can help you with?", + 'de': f"Für Fragen außerhalb der Programminformationen ist unser Zulassungsteam die beste Anlaufstelle. {get_admissions_contact_text('de')}\n\nGibt es etwas Spezifisches über die EMBA-, IEMBA- oder emba X-Programme, bei dem ich Ihnen helfen kann?" + }, + 'escalate_financial': { + 'en': f"Our admissions team can provide detailed guidance on financing options, payment plans, and scholarships. {get_admissions_contact_text('en')}", + 'de': "Unser Zulassungsteam kann Ihnen detaillierte Beratung zu Finanzierungsoptionen, Zahlungsplänen und Stipendien geben. Bitte kontaktieren Sie diese direkt für persönliche Unterstützung bei der Finanzplanung." + } + } + + return messages.get(escalation_type, {}).get(language, messages['escalate_off_topic']['en']) diff --git a/src/rag/utilclasses.py b/src/rag/utilclasses.py new file mode 100644 index 0000000000000000000000000000000000000000..c8723ad2b6dd973172dae738e44a53fb295c4a69 --- /dev/null +++ b/src/rag/utilclasses.py @@ -0,0 +1,83 @@ +from dataclasses import dataclass, field +from typing import List, Literal, Optional + +from pydantic import BaseModel, Field +from typing_extensions import TypedDict +from langchain.agents import AgentState +from langchain_core.messages import AnyMessage + + +@dataclass +class AgentContext: + agent_name: str + + +@dataclass +class LeadAgentQueryResponse: + response: str + language: str + processed_query: str = None + confidence_fallback: bool = False + max_turns_reached: bool = False + should_cache: bool = False + appointment_requested: bool = False + show_booking_widget: bool = False + relevant_programs: List[str] = field(default_factory=list) + + +class StructuredAgentResponse(BaseModel): + response: str = Field(description="Main response to the query.") + is_context_dependent: bool = Field( + default=True, + description=( + "Set to False only if the question can be answered without using any user-specific " + "information (e.g. name, age, preferences, extracted profile data) and without relying " + "on prior conversation turns or conversation history. " + "Must be True for responses involving eligibility, recommendations, comparisons after prior turns, " + "or any answer influenced by user profile data or conversation context." + ) + ) + appointment_requested: bool = Field( + default=False, + description="Set to True ONLY if the user explicitly asks to book, schedule, speak with admissions/an advisor, see appointment slots, or accepts a previous consultation offer. Routine pricing, comparisons, recommendations, and exploratory fit questions must be False." + ) + show_booking_widget: bool = Field( + default=False, + description="Set to True ONLY when appointment_requested is True and the booking widget should be shown now. Never use this for soft contact mentions or routine informational answers." + ) + relevant_programs: Optional[List[Literal["emba", "iemba", "emba_x"]]] = Field( + default=None, + description="If appointment_requested is True, list the programs relevant to the user. Options: 'emba', 'iemba', 'emba_x'. If the user is undecided or general, leave this list empty." + ) + + +class State(TypedDict): + messages: list[AnyMessage] + answer: str + + +class ConversationState(TypedDict): + """Tracks user profile and conversation context""" + user_id: str # Unique session identifier + user_language: str | None # Locked after first message + user_name: str | None # User's name extracted from conversation + experience_years: int | None # Years of professional experience + leadership_years: int | None # Years of leadership experience + field: str | None # Professional field/industry + interest: str | None # Content interests + qualification_level: str | None # "bachelor", "master", "MBA", etc. + program_interest: list[str] # ["EMBA", "IEMBA", "EMBAX"] + suggested_program: str | None # Recommended program based on conversation + handover_requested: bool | None # True if appointment requested, False if declined, None if session active + topics_discussed: list[str] # Track what's been covered + preferences_known: bool # Whether we have enough context + + +class LeadInformationState(AgentState): + lead_name: str + lead_age: int + lead_language_knowledge: list + lead_work_experience: dict + lead_motivation: list + # Enhanced state tracking + conversation_state: ConversationState diff --git a/src/scraping/__init__.py b/src/scraping/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/scraping/content_cleaner.py b/src/scraping/content_cleaner.py new file mode 100644 index 0000000000000000000000000000000000000000..2de5f53b5c91a458ad6992274d8d3dae6f3ef030 --- /dev/null +++ b/src/scraping/content_cleaner.py @@ -0,0 +1,134 @@ +import json, os + +from typing import Counter +from bs4 import BeautifulSoup +from docling_core.types.doc.document import DoclingDocument + +from ..const.cc_whitelist import REPETITION_WHITELIST +from ..utils.logging import get_logger +from ..config import config + +logger = get_logger('scraper.cleaning') + +class ContentCleaner: + def __init__(self, full_scraping) -> None: + self._repetitions_counter: Counter = Counter() + self._repetitive_content: list[str] = [] + self.full_scraping: bool = full_scraping + + + def clean_mobile_content(self, html: str) -> str: + soup = BeautifulSoup(html, 'html') + for element in soup.find_all(class_='show-sm'): + element.decompose() + + return str(soup) + + + def extract_urls(self, document: DoclingDocument) -> list[str]: + discovered_urls = [] + for node, _ in document.iterate_items(root=document.body, with_groups=False): + if hasattr(node, 'hyperlink') and node.hyperlink: + discovered_urls.append(str(node.hyperlink)) + + return discovered_urls + + + def collect_repetitive_content(self, document: DoclingDocument) -> None: + content_in_document = set() + for node, _ in document.iterate_items(root=document.body, with_groups=False): + if hasattr(node, 'text') and node.text: + stripped_text = node.text.strip().lower() + content_in_document.add(stripped_text) + + for content in content_in_document: + self._repetitions_counter[content] += 1 + + + def perform_content_analysis(self,target_url: str = "index", url_filename: str = 'index', ) -> None: + target_url_filename = url_filename + '-content_analysis.json' + target_url_path = os.path.join(config.paths.SCRAPING_OUTPUT, target_url_filename) + + if not self.full_scraping and os.path.exists(target_url_path): + with open(target_url_path, 'r') as f: + content_analysis = json.load(f) + self._repetitive_content = content_analysis['repetitive_content'] + else: + self._repetitive_content = [{'content': text, 'amount': count} + for text, count in self._repetitions_counter.items() + if text not in REPETITION_WHITELIST and count > 1] + logger.info(f"Content analysis for target URL '{target_url}' " + + f"yielded {len(self._repetitive_content)} repetitive text lines") + + content_analysis = { + 'target_url': target_url, + 'repetitive_content': self._repetitive_content, + } + + with open(target_url_path, 'w') as f: + json.dump(content_analysis, f, indent=4) + logger.info(f"Saved content analysis results under '{target_url_path}'") + + self._repetitive_content = [rc['content'] for rc in self._repetitive_content] + + + def clean_document(self, document: DoclingDocument) -> None: + document.furniture.children.clear() + + # Step 1: Shallow tagging of useless content + texts_to_remove = set() + nodes_to_remove = [] + for node, _ in document.iterate_items(root=document.body, with_groups=False): + if hasattr(node, 'text') and node.text: + stripped_text = node.text.strip().lower() + if stripped_text in self._repetitive_content: + nodes_to_remove.append(node) + continue + if hasattr(node, 'captions') and node.captions: + caption_text = node.caption_text(document).strip() + if len(caption_text) < 50: + nodes_to_remove.append(node) + if caption_text not in self._repetitive_content: + texts_to_remove.add(caption_text) + continue + if hasattr(node, 'hyperlink') and node.hyperlink: + nodes_to_remove.append(node) + if node.text: + texts_to_remove.add(node.text) + continue + + # Step 2: Removal of duplicates from other node types + for node, _ in document.iterate_items(root=document.body, with_groups=False): + if hasattr(node, 'text') and node.text: + stripped_text = node.text.strip().lower() + if stripped_text in texts_to_remove: + nodes_to_remove.append(node) + continue + + # Step 3: Deletion of all useless nodes + for node in nodes_to_remove: + if not (hasattr(node, 'parent') and node.parent): + continue + + parent_node = node.parent.resolve(document) + node_ref = node.get_ref() + if node_ref not in parent_node.children: + continue + + node_children_refs = list(node.children) if hasattr(node, 'children') else [] + idx = parent_node.children.index(node_ref) + parent_node.children.pop(idx) + parent_node.children[idx:idx] = node_children_refs + + # Promote children of removed node to node's parent + for child_ref in node_children_refs: + child_node = child_ref.resolve(document) + if hasattr(child_node, 'parent'): + child_node.parent = node.parent + + # Clean node references + if hasattr(node, 'children'): + node.children.clear() + if hasattr(node, 'parent'): + node.parent = None + diff --git a/src/scraping/html_processor.py b/src/scraping/html_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..1779f94b5968ecad0fd0afe3a6a3425bdec54d4a --- /dev/null +++ b/src/scraping/html_processor.py @@ -0,0 +1,197 @@ +from docling.document_converter import InputFormat +from docling_core.types.doc.document import DoclingDocument, TitleItem + +from .types import ChunkMetadata + +from ..config import config +from ..pipeline.processors import ProcessorBase +from ..utils.logging import get_logger + +logger = get_logger('scraper.processor') + +class HTMLProcessor(ProcessorBase): + def __init__(self) -> None: + super().__init__() + + def process(self, url: str, html_content: str) -> DoclingDocument | None: + if not html_content: + logger.warning('Nothing to process, HTML body is empty!') + return None + + logger.info(f"Analyzing page layout of URL '{url}'...") + try: + document = self._converter.convert_string(html_content, InputFormat.HTML).document + document.name = url + return document + except Exception as e: + logger.error(f"Failed to analyze page layout: {e}") + return None + + + def prepare_chunks(self, url: str, url_text: str, metas: list[ChunkMetadata]) -> dict[str, list]: + prepared_chunks = { lang: [] for lang in config.get('AVAILABLE_LANGUAGES', ['en', 'de']) } + for meta in metas: + prepared_chunks[meta.language].append(meta.text) + for lang, chunks in prepared_chunks.items(): + prepared_chunks[lang] = self._prepare_chunks(url, url_text, chunks) + + return prepared_chunks + + + def extract_title(self, document: DoclingDocument) -> str: + titles = [title.text for title in document.texts if isinstance(title, TitleItem)] + return titles[0] if titles else 'No Title' + + + def chunk(self, document: DoclingDocument) -> list[dict]: + raw_chunks = list(self._chunker.chunk(document)) + chunks = self._merge_chunks_by_headings(raw_chunks) + + prepared_chunks = [{ + 'text': chunk, + 'title': chunk.split('\n')[0], + 'size': self._chunker.tokenizer.count_tokens(chunk) + } for chunk in chunks] + + return prepared_chunks + + + def merge_chunks_by_topic(self, chunk_metadatas: list[ChunkMetadata]) -> list[ChunkMetadata]: + MAX_TOKENS = config.processing.MAX_TOKENS + merged_chunks = [] + + current_group = [] + current_tokens = 0 + current_topic = None + + for chunk in chunk_metadatas: + topic = chunk.topic + token_size = chunk.token_size + + # If the chunk is already large enough, it will not be merged + if token_size >= MAX_TOKENS: + # Consequtive group is over when large chunk is met + if current_group: + merged_chunks.append(self._create_merged_chunk(current_group)) + current_group = [] + current_tokens = 0 + current_topic = None + + # Large chunk is appended here + merged_chunks.append(chunk) + continue + + if (current_topic and topic != current_topic) or (current_tokens + token_size > MAX_TOKENS): + if current_group: + merged_chunks.append(self._create_merged_chunk(current_group)) + + current_group = [chunk] + current_tokens = token_size + current_topic = topic + continue + + current_group.append(chunk) + current_tokens += token_size + current_topic = topic + + + if current_group: + merged_chunks.append(self._create_merged_chunk(current_group)) + + return merged_chunks + + + def _create_merged_chunk(self, group: list[dict]) -> ChunkMetadata: + if len(group) == 1: + return group[0] + + merged_text = "\n".join(cm.text for cm in group).strip() + total_tokens = sum(cm.token_size for cm in group) + + first = group[0] + + merged_id = f"merged_{first.topic}_{group[0].chunk_id}_to_{group[-1].chunk_id}" + merged_chunk = ChunkMetadata( + chunk_id = merged_id, + text = merged_text, + source_url = first.source_url, + program = first.program, + language = first.language, + topic = first.topic, + last_scraped = first.last_scraped, + page_title = first.page_title, + section_heading = first.section_heading, + token_size = total_tokens, + original_chunk_ids = [c.chunk_id for c in group], + ) + return merged_chunk + + + def _get_formatted_chunk_text(self, chunk, headings) -> str: + formatted_text = f"{' '.join(headings)}\n" + + if not hasattr(chunk.meta, 'doc_items'): + return formatted_text + chunk.text.replace('\n', ' ') + + labels = set() + for item in chunk.meta.doc_items: + labels.add(item.label) + + labels = [label for label in labels if label in ['table', 'list_item']] + if labels: + return formatted_text + chunk.text + + return formatted_text + chunk.text.replace('\n', ' ') + + + def _merge_chunks_by_headings(self, raw_chunks: list) -> list[str]: + """ + Groups consecutive chunks that share the same parent headings and merges them into one clean chunk. + """ + prefix_level = 2 + merged = [] + i = 0 + n = len(raw_chunks) + + while i < n: + chunk = raw_chunks[i] + headings = getattr(chunk.meta, "headings", []) or [] + + if len(headings) < prefix_level: + formatted_text = self._get_formatted_chunk_text(chunk, headings) + merged.append(formatted_text) + i += 1 + continue + + # Start a new group with this prefix + common_prefix = "\n".join(headings[:prefix_level]) + group = [] + + while i < n: + curr_chunk = raw_chunks[i] + curr_headings = getattr(curr_chunk.meta, "headings", []) or [] + curr_prefix = "\n".join(curr_headings[:prefix_level]) + + if curr_prefix != common_prefix: + break + + leaf_heading = curr_headings[-1] if len(curr_headings) > prefix_level else "" + content = curr_chunk.text.replace('\n', ' ').strip() + + if leaf_heading and content: + group.append(f"{leaf_heading}: {content}") + elif content: + group.append(content) + + i += 1 + + # Build the final merged chunk + if len(group) > 1: + full_chunk = f"{'\n'.join(headings[1:-1])}\n{'\n'.join(group)}" + else: + full_chunk = f"{'\n'.join(headings[1:])}\n{chunk.text}" + + merged.append(full_chunk.strip()) + + return merged + diff --git a/src/scraping/scraper.py b/src/scraping/scraper.py new file mode 100644 index 0000000000000000000000000000000000000000..d09586c2245c10e680fd0743cf83b0e406498d67 --- /dev/null +++ b/src/scraping/scraper.py @@ -0,0 +1,716 @@ +import os, shutil, json +from datetime import datetime +from collections import Counter, defaultdict +from urllib.parse import urlsplit +from urllib.robotparser import RobotFileParser +from usp.objects.sitemap import InvalidSitemap +from usp.tree import sitemap_tree_for_homepage + +from src.notification.notification_center import NotificationCenter + +from .utils import * +from .types import * +from .html_processor import HTMLProcessor +from .content_cleaner import ContentCleaner +from .url_normalizer import UrlNormalizer + +from ..utils.lang import detect_language +from ..utils.logging import get_logger +from ..utils.tools import call_with_exponential_backoff +from ..config import config + +logger = get_logger('scraper.core') +incupd_logger = get_logger('scraper.incremental_updates') + + +class Scraper: + def __init__(self, scrape_all: bool = True) -> None: + self._scrape_all = scrape_all + self._path = config.paths + self._processor: HTMLProcessor = HTMLProcessor() + self._normalizer: UrlNormalizer = UrlNormalizer() + self._content_cleaner: ContentCleaner = ContentCleaner(self._scrape_all) + self._notif_center: NotificationCenter = NotificationCenter() + + self._make_directories() + + self._url_temp_timestamps: dict[str, UrlTimestamps] = {} + self._url_timestamps: dict[str, UrlTimestamps] = self._load_data(self._path.SCRAPING_OUTPUT, 'url_timestamps') + self._url_priorities: dict[str, list[str]] = self._load_data(self._path.URLS_OUTPUT, 'url_priorities') + + logger.info(f'Successfully initialized the scraper') + if scrape_all: + logger.info("Initialized with SCRAPE_ALL=True. Timestamps and priorities will be ignored for this scraping session") + + + def _make_directories(self) -> None: + os.makedirs(self._path.URLS_OUTPUT, exist_ok=True) + os.makedirs(self._path.CHUNKS_OUTPUT, exist_ok=True) + os.makedirs(self._path.TEMP_CHUNKS_OUTPUT, exist_ok=True) + os.makedirs(self._path.SCRAPING_OUTPUT, exist_ok=True) + os.makedirs(self._path.RAW_HTML_OUTPUT, exist_ok=True) + os.makedirs(self._path.RAW_TEXT_OUTPUT, exist_ok=True) + os.makedirs(self._path.METADATA_OUTPUT, exist_ok=True) + os.makedirs(self._path.EXTRACTED_TEXT_OUTPUT, exist_ok=True) + + + def scrape_target(self, target_url: str) -> list[ChunkMetadata]: + # Step 1: Analyze the target URL for availability, robots and sitemap + analyzed_domain = self._analyze_domain(target_url) + if not analyzed_domain: + logger.error(f"Failed to scrape target URL {target_url}") + return {} + + sitemap_urls = analyzed_domain.urls + self._save_results(self._path.URLS_OUTPUT, 'sitemap_urls', sitemap_urls, target_url) + + # Step 2: Validate and scrape URLs listed in the sitemap + analyzed_sitemap = self._analyze_sitemap(analyzed_domain) + + documents = analyzed_sitemap.documents + + logger.info(f"Indexed {len(sitemap_urls)} sitemap URLs for target URL {target_url}") + logger.info(f"Scraped {len(documents)} unique URLs (others were either redirects or blacklisted)") + + # Step 3: Analyze discovered URLs and search for the new ones + discovered_urls = analyzed_sitemap.discovered_urls + logger.info(f"Discovered {len(discovered_urls)} new URLs during sitemap analysis") + + analyzed_discoveries = self._analyze_discoveries(discovered_urls, sitemap_urls, analyzed_domain) + + discovered_urls = analyzed_discoveries.discovered_urls + self._save_results(self._path.URLS_OUTPUT, 'discovered_urls', discovered_urls, target_url) + + documents.extend(analyzed_discoveries.documents) + + logger.info(f"Indexed {len(discovered_urls)} new URLs for target URL {target_url}") + + # Step 4: Load temp chunks first so resume works even when there are no new documents. + temp_filename = self._get_temp_chunks_filename(target_url) + temp_merged_chunks = self._load_data(self._path.TEMP_CHUNKS_OUTPUT, temp_filename) + + if not documents and not temp_merged_chunks: + logger.info(f"No new content was scraped from the target URL {target_url}") + return {} + + tagged_documents = [] + # Step 5: Analyze the converted URLs + if documents: + self._content_cleaner.perform_content_analysis(target_url, self._normalizer.url_to_filename(target_url)) + analyzied_documents = self._analyze_url_documents(documents) + + self._save_results(self._path.URLS_OUTPUT, 'url_tags', analyzied_documents.url_tags) + self._save_results(self._path.URLS_OUTPUT, 'url_priorities', analyzied_documents.url_priorities) + tagged_documents = analyzied_documents.tagged_documents + + # Step 6: Collect and save chunks + chunk_metadatas = self._collect_chunks(tagged_documents, target_url, temp_merged_chunks) + + self._save_results(self._path.METADATA_OUTPUT, 'raw_chunk_metadata', chunk_metadatas['raw'], target_url) + self._save_results(self._path.METADATA_OUTPUT, 'merged_chunk_metadata', chunk_metadatas['merged'], target_url) + self._save_results(self._path.METADATA_OUTPUT, 'deleted_chunk_metadata', chunk_metadatas['deleted'], target_url) + + logger.info(f"Collected {len(chunk_metadatas['merged'])} chunks from target URL {target_url}") + + logger.info(f"Scraping finished for target URL '{target_url}'") + + return chunk_metadatas['final'] + + + def _analyze_domain(self, target_url: str) -> DomainAnalysisReport | None: + if not target_url: + logger.warning('The target URL string is empty!') + return None + + # Step 1: Test whether the target URL is even accessible before initializing the scraping procedure + response = call_with_exponential_backoff(fetch_url, args=(target_url,)) + if response['status'] == 'FAIL': + logger.error(f"Unaccessible target URL '{target_url}': {response['last_error']}") + return None + if not response['result']: + logger.warning(f"Unnaccessible target URL '{target_url}': Recieved client/server error!") + return None + + # Step 2: Fetch and parse robots + logger.info(f"Fetching 'robots.txt' for the target URL '{target_url}'...") + robots_parser: RobotFileParser = parse_robots(target_url) + + if not robots_parser: + logger.warning( + f"Could not fetch the 'robots.txt' file for the target URL '{target_url}'! " + + "(Are you sure the scraping begins from root?)" + ) + return None + + logger.info(f"Parsed the 'robots.txt' file for target URL '{target_url}'") + + delay = robots_parser.crawl_delay('scraper') + target_domain = urlsplit(target_url).netloc + + # Step 3: Fetch and parse sitemaps + logger.info(f"Fetching sitemaps for target URL {target_url}...") + sitemap_tree = sitemap_tree_for_homepage(target_url) + if isinstance(sitemap_tree, InvalidSitemap): + logger.error(f"Cannot fetch sitemap for target URL '{target_url}': Invalid sitemap structure!") + return None + + page_data = [] + page_urls = set() + for page in sitemap_tree.all_pages(): + page_url = page.url + if not robots_parser.can_fetch('scraper', page_url) or page_url in page_urls: + continue + + page_urls.add(page_url) + page_data.append(PageData(page_url, page.last_modified)) + + logger.info(f'Loaded sitemaps with {len(page_data)} pages') + + return DomainAnalysisReport( + target = target_domain, + urls = list(page_urls), + pages = page_data, + delay = delay, + ) + + def _analyze_sitemap(self, domain: DomainAnalysisReport) -> UrlAnalysisReport: + documents = [] + visited_urls = set() + discovered_urls = set() + rejected_urls = [] + + sitemap_pages = domain.pages + logger.info(f'Starting validation and scraping for sitemap URLs...') + for page in sitemap_pages: + result = self._scrape_page(page.url, domain.delay, visited_urls, last_modified=page.last_modified) + visited_urls.add(page.url) + + if result.status != ScrapingStatus.OK: + if result.status == ScrapingStatus.REJECTED: + rejected_urls.append(page.url) + continue + + final_url = result.final_url + documents.append(result.document) + visited_urls.add(final_url) + + self._store_timestamps(final_url, result.timestamps, temp=True) + + new_urls = self._normalizer.filter_discovered_urls(result.discovered_urls, visited_urls, domain.target) + discovered_urls |= new_urls + + if len(rejected_urls) > len(sitemap_pages)*0.1: + rejection_rate = len(rejected_urls)/len(sitemap_pages) + logger.warning(f"Rejection rate is {rejection_rate}") + self._notif_center.send_notification( + subject = "⚠ WARNING: Scraping rejection rate is >10%!", + body = f"Rejection rate: {int(rejection_rate*100)}%\n" + + f"Failed to scrape following URLs for target domain {domain.target}:\n" + + "\n".join([f"\t- {url}" for url in rejected_urls]), + channel = "slack", + ) + + discovered_urls = [url for url in discovered_urls if url not in visited_urls] + return UrlAnalysisReport( + documents = documents, + discovered_urls = discovered_urls, + ) + + def _analyze_discoveries( + self, + discovered_urls: list, + sitemap_urls: list, + domain: DomainAnalysisReport + ) -> UrlAnalysisReport: + if len(discovered_urls) == 0: + return UrlAnalysisReport([], []) + + documents = [] + discoveries = discovered_urls.copy() + visited_urls = set(sitemap_urls.copy()) + + discovered_urls = [{'url': url, 'depth': 0} for url in discovered_urls] + logger.info(f"Starting validation and scraping for discovered URLs...") + while discovered_urls: + discovered_url = discovered_urls.pop() + url = discovered_url['url'] + + result = self._scrape_page(url, domain.delay, visited_urls, discovery_depth=discovered_url['depth']) + visited_urls.add(url) + + if not result: continue + + final_url = result.final_url + documents.append(result.document) + visited_urls.add(final_url) + discoveries.append(final_url) + + self._store_timestamps(final_url, result.timestamps, temp=True) + + for new_url in self._normalizer.filter_discovered_urls(result.discovered_urls, visited_urls, domain.target): + discovered_urls.append({'url': new_url, 'depth': result.discovery_depth}) + + return UrlAnalysisReport( + documents = documents, + discovered_urls = discoveries, + ) + + def _analyze_url_documents(self, documents: list) -> DocumentAnalysisReport: + url_tags = {} + url_priorities = defaultdict(list) + tagged_documents = [] + + logger.info(f"Analyzing scraped contents of {len(documents)} pages...") + for document in documents: + url = document.name + self._content_cleaner.clean_document(document) + + extracted_text = self._processor.convert_to_txt(document) + if extracted_text.strip() == '': + logger.warning(f'No text extracted from {url}. Skipping ...') + continue + url_filename = self._normalizer.url_to_filename(url) + extracted_text_file_path = os.path.join(self._path.EXTRACTED_TEXT_OUTPUT, url_filename + '.txt') + + with open(extracted_text_file_path, 'w', encoding="utf-8") as f: + f.write(extracted_text) + logger.info(f"Saved extracted text for URL '{url}' under '{extracted_text_file_path}'") + + language = detect_language(extracted_text) + tp_result = detect_page_topic_and_priority(extracted_text) + programs = self._processor.strategies_processor.apply_strategy( + strategy_name='programs', + arguments={'document_content': extracted_text}, + ) + program = programs[0] if programs else 'no program' + + tags = UrlTags( + topic = tp_result['topic'], + priority = tp_result['priority'], + language = language, + program = program, + ) + + url_tags[url] = tags + url_priorities[tp_result['priority']].append(url) + + tagged_documents.append(TaggedDocument(document, DocumentTags(program, language))) + + return DocumentAnalysisReport( + url_tags = url_tags, + url_priorities = url_priorities, + tagged_documents = tagged_documents, + ) + + def _collect_chunks( + self, + tagged_documents: list[dict], + target_url: str, + temp_chunks: dict[str, list[ChunkMetadata]] | None = None, + ) -> dict[str, list[ChunkMetadata]]: + raw_chunks = [] + deleted_chunks = [] + merged_chunks, final_chunks = self._read_temp_chunks(temp_chunks, tagged_documents) + + program_counter = self._build_program_counter_from_merged_chunks(merged_chunks) + + if merged_chunks: incupd_logger.info(f"Restored {len(merged_chunks)} chunks from temp") + + for entry in tagged_documents: + document = entry.document + program = entry.tags.program + language = entry.tags.language + url = document.name + url_filename = self._normalizer.url_to_filename(url) + + program_counter[program] += 1 + + doc_chunks_dir_path = os.path.join(config.paths.CHUNKS_OUTPUT, url_filename) + if os.path.exists(doc_chunks_dir_path): shutil.rmtree(doc_chunks_dir_path) + os.makedirs(doc_chunks_dir_path) + + mergible_chunks_metadatas = [] + raw_chunk_count = 0 + for i, chunk in enumerate(self._processor.chunk(document), start=1): + raw_chunk_count = i + chunk_file_path = os.path.join(doc_chunks_dir_path, f"chunk_{i}.txt") + with open(chunk_file_path, 'w', encoding="utf-8") as f: + f.write(chunk['text']) + + chunk_topic = detect_chunk_topic(chunk['text']) + chunk_metadata = ChunkMetadata( + chunk_id = f"{program.lower()}_{program_counter[program]:03d}_{i:02d}", + text = chunk['text'], + source_url = url, + program = program, + language = language, + topic = chunk_topic, + last_scraped = datetime.now(), + page_title = self._processor.extract_title(document), + section_heading = chunk['title'], + token_size = chunk['size'], + ) + raw_chunks.append(chunk_metadata) + if chunk_topic == 'none': + deleted_chunks.append(chunk_metadata) + else: + mergible_chunks_metadatas.append(chunk_metadata) + + logger.info(f"Collected {raw_chunk_count} raw chunks and saved under '{doc_chunks_dir_path}'") + + merged_chunk_metadatas = self._processor.merge_chunks_by_topic(mergible_chunks_metadatas) + merged_chunks.extend(merged_chunk_metadatas) + + self._store_temp_chunks(target_url, url, merged_chunk_metadatas) + logger.info(f"Merged {raw_chunk_count} raw chunks into {len(merged_chunk_metadatas)} chunks by topic") + + prepared_chunks = self._processor.prepare_chunks(url, self._processor.convert_to_txt(document), merged_chunk_metadatas) + for lang in final_chunks.keys(): + if lang in prepared_chunks.keys(): + final_chunks[lang].extend(prepared_chunks[lang]) + + return { + 'raw': raw_chunks, + 'merged': merged_chunks, + 'deleted': deleted_chunks, + 'final': final_chunks, + } + + + def _read_temp_chunks( + self, + temp_chunks: dict[str, list[ChunkMetadata]], + tagged_documents: list[TaggedDocument] + ) -> set[list, list[dict]]: + loaded_temp_chunks = temp_chunks.copy() + prepared_temp_chunks = {lang: [] for lang in config.get('AVAILABLE_LANGUAGES', ['en', 'de'])} + + for url in [entry.document.name for entry in tagged_documents]: + if url in temp_chunks.keys(): + incupd_logger.info(f"Deleted stored temp data for URL {url} as it was newly scraped") + del loaded_temp_chunks[url] + + restored_temp_chunks = [] + for url, chunks in loaded_temp_chunks.items(): + url_filename = self._normalizer.url_to_filename(url) + extracted_text_path = os.path.join(self._path.EXTRACTED_TEXT_OUTPUT, url_filename + '.txt') + if not os.path.exists(extracted_text_path): + incupd_logger.warning(f"Cannot restore chunks for URL {url}: Failed to locate previously extracted contents!") + incupd_logger.warning(f"This URL will has to be rescraped in the next session") + continue + + with open(extracted_text_path, 'r') as f: + url_text = f.read() + + prepared_chunks = self._processor.prepare_chunks(url, url_text, chunks) + for lang in prepared_temp_chunks.keys(): + if lang in prepared_chunks.keys(): + prepared_temp_chunks[lang].extend(prepared_chunks[lang]) + + restored_temp_chunks.extend(chunks) + incupd_logger.info(f"Restored {len(chunks)} chunks for URL {url} from temp") + + return restored_temp_chunks, prepared_temp_chunks + + + def _store_temp_chunks(self, target_url: str, url: str, chunks: list[ChunkMetadata]) -> None: + self._url_timestamps[url] = self._url_temp_timestamps[url] + + temp_chunks = {url: chunks} + + self._save_results(self._path.TEMP_CHUNKS_OUTPUT, self._get_temp_chunks_filename(target_url), temp_chunks) + self._save_results(self._path.SCRAPING_OUTPUT, 'url_timestamps', self._url_timestamps) + + incupd_logger.info(f"Stored {len(chunks)} chunks in temp for URL {url}") + + + def _build_program_counter_from_merged_chunks(self, merged_chunks: list[ChunkMetadata]) -> Counter: + counter = Counter() + seen = set() + + for chunk in merged_chunks: + key = (chunk.program, chunk.source_url) + if key not in seen: + counter[chunk.program] += 1 + seen.add(key) + + return counter + + def _is_url_modified( + self, + url: str, + new_last_modified: datetime | None = None, + new_page_hash: str | None = None + ) -> bool: + if url not in self._url_timestamps.keys(): + return True + + stored = self._url_timestamps[url] + + if stored.last_modified and new_last_modified: + return stored.last_modified < new_last_modified + + if new_page_hash and stored.page_hash: + return new_page_hash != stored.page_hash + + return True + + + def _store_timestamps(self, url: str, timestamps: UrlTimestamps, temp=False) -> None: + if temp: + self._url_temp_timestamps[url] = timestamps + else: + self._url_timestamps[url] = timestamps + + + def _get_temp_chunks_filename(self, target_url: str) -> str: + return self._normalizer.url_to_filename(target_url) + '_merged_chunks' + + + def delete_temp_merged_chunks(self, target_url: str) -> None: + temp_path = os.path.join( + self._path.TEMP_CHUNKS_OUTPUT, + self._get_temp_chunks_filename(target_url) + '.json' + ) + if os.path.exists(temp_path): + os.remove(temp_path) + incupd_logger.info(f"Deleted temp merged chunks file '{temp_path}'") + + + def _get_etag(self, url: str) -> str | None: + if url not in self._url_timestamps.keys(): + return None + + return self._url_timestamps[url].etag + + def _is_fetch_valid(self, url: str, visited_urls: list[str], fetch_result: FetchResult) -> ScrapingStatus: + if not fetch_result: + logger.warning(f"Cannot fetch {url}! Skipping...") + return ScrapingStatus.REJECTED + + if fetch_result.not_modified: + logger.info("No updates on the page, skipping...") + return ScrapingStatus.NO_UPDATES + + final_url = fetch_result.final_url + if final_url != url: + logger.info(f"Redirect detected: '{url}' --> '{final_url}'") + if final_url in visited_urls: + logger.info(f"'{final_url}' was already visited, skipping...") + return ScrapingStatus.VISITED + logger.info(f"Continuing with URL '{final_url}'...") + + last_modified = fetch_result.last_modified + page_hash = fetch_result.page_hash + if not self._scrape_all and not self._is_url_modified(final_url, new_last_modified=last_modified, new_page_hash=page_hash): + logger.info(f"URL {final_url} was not modified since last scraping session, skipping...") + return ScrapingStatus.NO_UPDATES + + return ScrapingStatus.OK + + + def _is_url_prioritized(self, url) -> bool: + if url not in self._url_timestamps.keys(): + return True + + for prio, urls in self._url_priorities.items(): + if url in urls: + return self._is_scraping_scheduled(url, prio) + + return True + + + def _is_scraping_scheduled(self, url, prio) -> bool: + current_timestamp = datetime.now() + saved_timestamp = self._url_timestamps[url].last_scraped + time_difference = current_timestamp - saved_timestamp + + if not saved_timestamp: + return True + + for interval_prio, interval in config.scraping.INTERVALS.items(): + if prio == interval_prio: + return time_difference.days >= interval + + return True + + + def _scrape_page( + self, url: str, + crawl_delay: float, + visited_urls: list[str], + discovery_depth: int = 0, + last_modified: datetime | None = None + ) -> ScrapingResult | None: + if not url: + return ScrapingResult(status=ScrapingStatus.REJECTED) + + if self._normalizer.is_url_blacklisted(url): + logger.info(f"URL {url} is blacklisted by scraper, skipping...") + return ScrapingResult(status=ScrapingStatus.BLACKLISTED) + + if url in visited_urls: + logger.info(f'URL {url} was already analyzed via redirect, skipping...') + return ScrapingResult(status=ScrapingStatus.VISITED) + + if not self._scrape_all and last_modified and not self._is_url_modified(url, new_last_modified=last_modified): + logger.info(f"URL '{url}' was not modified since last scraping session, skipping...") + self._url_timestamps[url].last_modified = last_modified + return ScrapingResult(status=ScrapingStatus.NO_UPDATES) + + if not self._scrape_all and not self._is_url_prioritized(url): + logger.info(f"URL {url} is not prioritized, skipping") + return ScrapingResult(status=ScrapingStatus.NO_UPDATES) + + logger.info(f"Fetching head for URL '{url}'...") + + etag = self._get_etag(url) + response = call_with_exponential_backoff(fetch_head, args=(url, etag), delay=crawl_delay) + if response['status'] == 'FAIL': + logger.warning(f"Failed to fetch head for URL {url}: {response['last_error']}! Skipping...") + return ScrapingResult(status=ScrapingStatus.REJECTED) + + fetch_result = response['result'] + validation = self._is_fetch_valid(url, visited_urls, fetch_result) + if validation != ScrapingStatus.OK: + return ScrapingResult(status=validation) + + response = call_with_exponential_backoff(fetch_url, args=(url, etag), delay=crawl_delay) + if response['status'] == 'FAIL': + logger.warning(f"Failed to fetch URL {url}: {response['last_error']}! Skipping...") + return ScrapingResult(status=ScrapingStatus.REJECTED) + + fetch_result = response['result'] + validation = self._is_fetch_valid(url, visited_urls, fetch_result) + if validation != ScrapingStatus.OK: + return ScrapingResult(status=validation) + + if not fetch_result.last_modified: + logger.warning("No information about URL last modification date exists!") + + timestamps = UrlTimestamps( + last_modified = fetch_result.last_modified, + last_scraped = datetime.now(), + etag = fetch_result.etag, + page_hash = fetch_result.page_hash, + ) + + raw_html = fetch_result.text + final_url = fetch_result.final_url + + url_filename = self._normalizer.url_to_filename(final_url) + raw_html_file_path = os.path.join(config.paths.RAW_HTML_OUTPUT, url_filename + '.html') + with open(raw_html_file_path, 'w', encoding="utf-8") as f: + f.write(raw_html) + logger.info(f"Saved fetched HTML under '{raw_html_file_path}'") + + logger.info(f"Cleaning URL {final_url} from mobile data...") + cleaned_html = self._content_cleaner.clean_mobile_content(raw_html) + + logger.info(f"Processing URL {final_url}...") + document = self._processor.process(final_url, cleaned_html) + + if not document: + logger.warning(f"Failed to process URL '{final_url}'! Skipping...") + return ScrapingResult(status=ScrapingStatus.REJECTED) + + discovered_urls = self._content_cleaner.extract_urls(document) if discovery_depth <= 3 else [] + self._content_cleaner.collect_repetitive_content(document) + + raw_text = self._processor.convert_to_txt(document) + raw_text_file_path = os.path.join(config.paths.RAW_TEXT_OUTPUT, url_filename + '.txt') + with open(raw_text_file_path, 'w', encoding="utf-8") as f: + f.write(raw_text) + logger.info(f"Saved raw text for URL '{final_url}' under '{raw_text_file_path}'") + + return ScrapingResult( + document = document, + discovered_urls = discovered_urls, + final_url = final_url, + timestamps = timestamps, + discovery_depth = discovery_depth + 1, + status = ScrapingStatus.OK, + ) + + def _save_results(self, path: str, filename: str, results, target_url: str | None = None) -> None: + results_path = os.path.join(path, filename + '.json') + + results_dict = {} + if os.path.exists(results_path): + try: + with open(results_path, 'r', encoding='utf-8') as f: + results_dict = json.load(f) + except Exception: + logger.warning(f"Failed to load existing {results_path}, will overwrite") + + match filename: + case 'url_tags': + results_dict |= results + + case 'url_timestamps': + for url, ts in results.items(): + results_dict[url] = dataclass_to_dict(ts) + + case 'url_priorities': + for prio, urls in results.items(): + prev = set(results_dict.get(prio, [])) + results_dict[prio] = list(prev.union(urls)) + + case _ if filename.endswith('_merged_chunks'): + for url, chunks in results.items(): + results_dict[url] = [dataclass_to_dict(chunk) for chunk in chunks] + + case _: + results = [dataclass_to_dict(r) for r in results] + if target_url: + results_dict[target_url] = results + else: + results_dict = results + + try: + with open(results_path, 'w', encoding='utf-8') as f: + json.dump( + results_dict, + f, + indent=4, + default=lambda o: o.isoformat() if isinstance(o, datetime) else None, + ) + except Exception as e: + logger.error(f"Failed to store results '{filename}'") + raise e + + logger.debug(f"Stored results in file {results_path}") + + + def _load_data(self, path: str, filename: str): + datapath = os.path.join(path, filename + '.json') + + if not os.path.exists(datapath): + logger.warning(f"Failed to locate file {datapath}; new data will be recorded") + return defaultdict(dict) + + try: + with open(datapath, 'r', encoding='utf-8') as f: + loaded_data = json.load(f) + + match filename: + case 'url_timestamps': + for url, ts_dict in loaded_data.items(): + loaded_data[url] = dict_to_dataclass(ts_dict, UrlTimestamps) + incupd_logger.debug(f"Loaded {len(loaded_data)} URL timestamps") + return loaded_data + + case _ if filename.endswith('_merged_chunks'): + for url, chunk_metadata in loaded_data.items(): + loaded_data[url] = [dict_to_dataclass(chunk, ChunkMetadata) for chunk in chunk_metadata] + incupd_logger.debug(f"Loaded {len(loaded_data)} temp merged chunks") + return loaded_data + + case _: + incupd_logger.info(f"Loaded data '{filename}'") + return loaded_data + + except Exception as e: + logger.error(f"Failed trying to load data '{filename}': {e}") + logger.info("New data will be recorded") + return defaultdict(dict) diff --git a/src/scraping/types.py b/src/scraping/types.py new file mode 100644 index 0000000000000000000000000000000000000000..9ef9e16573a91c63ecd7d9b0194e7b73e6d944d3 --- /dev/null +++ b/src/scraping/types.py @@ -0,0 +1,120 @@ +from dataclasses import asdict, dataclass, is_dataclass +from datetime import datetime +from enum import Enum + +from docling_core.types.doc.document import DoclingDocument + +@dataclass +class FetchResult: + final_url: str + last_modified: datetime + etag: str + + not_modified: bool = False + text: str = '' + page_hash: str = '' + +@dataclass +class PageData: + url: str + last_modified: datetime + +@dataclass +class UrlTags: + topic: str + priority: str + language: str + program: str + +@dataclass +class UrlTimestamps: + last_modified: datetime = None + last_scraped: datetime = None + etag: str = "" + page_hash: str = "" + +@dataclass +class DocumentTags: + program: str + language: str + priority: str = "" + last_modified: datetime = None + last_scraped: datetime = None + +@dataclass +class TaggedDocument: + document: DoclingDocument + tags: DocumentTags + +@dataclass +class ChunkMetadata: + chunk_id: str + text: str + source_url: str + program: str + language: str + topic: str + last_scraped: datetime + page_title: str + section_heading: str + token_size: int + original_chunk_ids: list[str] = None + + +class ScrapingStatus(Enum): + OK = 1 + REJECTED = 2 + VISITED = 3 + REDIRECTION = 4 + NO_UPDATES = 5 + BLACKLISTED = 6 + + +@dataclass +class ScrapingResult: + final_url: str = "" + discovery_depth: int = 0 + discovered_urls: list[str] = None + document: DoclingDocument = None + timestamps: UrlTimestamps = None + status: ScrapingStatus = ScrapingStatus.NO_UPDATES + + +@dataclass +class DomainAnalysisReport: + target: str + pages: list[PageData] + urls: list[str] + delay: float + +@dataclass +class UrlAnalysisReport: + documents: list[DoclingDocument] + discovered_urls: list[str] + +@dataclass +class DocumentAnalysisReport: + url_tags: dict[str, UrlTags] + url_priorities: dict[str, list[str]] + tagged_documents: list[TaggedDocument] + + +def dataclass_to_dict(obj) -> dict: + if not is_dataclass(obj): return obj + return asdict(obj, dict_factory=lambda items: { + k: v.isoformat() if isinstance(v, datetime) else v + for k, v in items + }) + + +def dict_to_dataclass(data: dict, class_type): + from .utils import parse_isoformat + if not data: return None + + if 'last_scraped' in data.keys(): + data['last_scraped'] = parse_isoformat(data.get('last_scraped')) + + if 'last_modified' in data.keys(): + data['last_modified'] = parse_isoformat(data.get('last_modified')) + + return class_type(**data) diff --git a/src/scraping/url_normalizer.py b/src/scraping/url_normalizer.py new file mode 100644 index 0000000000000000000000000000000000000000..ce5a18239e87fcad514dbad1d8c5e15fd0d3dc16 --- /dev/null +++ b/src/scraping/url_normalizer.py @@ -0,0 +1,45 @@ +import re +from urllib.parse import urlsplit, urlparse +from ..const.page_blacklist import * + +class UrlNormalizer: + @staticmethod + def is_url_blacklisted(url: str) -> bool: + url_lower = url.lower() + path = url_lower.split('://', 1)[-1].split('/', 1)[-1] + + for forbidden in PAGE_BLACKLIST: + if forbidden in path: + return True + + return False + + + @staticmethod + def url_to_filename(url: str) -> str: + parsed = urlparse(url) + + # Build base from netloc + path + filename = f"{parsed.netloc}{parsed.path}" + + # Remove leading/trailing slashes + filename = filename.strip('/') + + # Replace separators + filename = filename.replace('/', '_').replace('.', '-') + + # Remove all problematic characters + filename = re.sub(r'[^a-zA-Z0-9_-]', '_', filename) + + return filename + + + def filter_discovered_urls(self, discovered_urls, visited_urls, target_domain) -> list: + filtered_urls = set() + + for url in discovered_urls: + if any([self.is_url_blacklisted(url), url in visited_urls, urlsplit(url).netloc != target_domain]): + continue + filtered_urls.add(url) + + return filtered_urls diff --git a/src/scraping/utils.py b/src/scraping/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9366e897c8f38bcf97a99118369c92c64883c53c --- /dev/null +++ b/src/scraping/utils.py @@ -0,0 +1,241 @@ +import hashlib +import json +import requests, difflib, datetime +from email.utils import parsedate_to_datetime +from functools import lru_cache +from collections import defaultdict +from urllib.robotparser import RobotFileParser +from urllib.error import URLError +from fake_useragent import UserAgent +from bs4 import BeautifulSoup + +from src.scraping.types import FetchResult + +from ..config import config +from ..const.page_priority import * +from ..utils.logging import get_logger +from ..utils.tools import call_with_exponential_backoff + +logger = get_logger('scraper.utils') +ua = UserAgent() + +@lru_cache +def _fuzzy_match(word, keyword, threshold=0.8): + """ + Check if word fuzzy matches keyword using difflib ratio. + """ + return difflib.SequenceMatcher(None, word.lower(), keyword.lower()).ratio() >= threshold + + +def detect_page_topic_and_priority(text: str) -> dict[str, str]: + result = { + 'priority': 'low', + 'topic': 'none', + } + + if not text: return result + + text_lower = text.lower() + words = text_lower.split() + topic_counter = { prio: defaultdict(int) for prio in PAGE_PRIORITY_KEYWORDS.keys() } + prio_counter = { prio: 0 for prio in PAGE_PRIORITY_KEYWORDS.keys() } + + for word in words: + for prio, kws in PAGE_PRIORITY_KEYWORDS.items(): + for kw in kws: + if _fuzzy_match(word, kw): + topic_counter[prio][kw] += 1 + prio_counter[prio] += sum(topic_counter[prio].values()) + + if max(prio_counter.values()) == 0: + return result + + top_prio = max(prio_counter.keys(), key=lambda k: prio_counter[k]) + top_topic = max(topic_counter[top_prio].keys(), key=lambda k: topic_counter[top_prio][k]) + + result['priority'] = top_prio + result['topic'] = top_topic + + return result + + +def detect_chunk_topic(text: str) -> str: + if not text: return 'none' + + text_lower = text.lower() + words = text_lower.split() + topic_counter = { topic: 0 for topic in CHUNK_TOPIC_KEYWORDS.keys() } + + for word in words: + for topic, kws in CHUNK_TOPIC_KEYWORDS.items(): + topic_counter[topic] += len(list(filter(lambda kw: _fuzzy_match(word, kw), kws))) + + if max(topic_counter.values()) == 0: + return 'none' + + top_topic = max(topic_counter.keys(), key=lambda k: topic_counter[k]) + return top_topic + + +def hash_html(html: str) -> str: + soup = BeautifulSoup(html, "html.parser") + + for tag in soup(["script", "style"]): + tag.decompose() + + text = soup.get_text() + return hashlib.sha256(text.encode()).hexdigest() + + +def parse_isoformat(data: str) -> datetime.datetime: + if not data: + return None + + try: + return parsedate_to_datetime(data) + except (TypeError, ValueError): + pass + + try: + return datetime.datetime.fromisoformat(data) + except ValueError: + pass + + return None + + +def extract_last_modified(response, html) -> datetime.datetime | None: + last_modified = response.headers.get("Last-Modified", None) + + soup = BeautifulSoup(html, "html.parser") + if not last_modified: + for key in [ ("name", "last-modified"), ("property", "article:modified_time")]: + tag = soup.find("meta", {key[0]: key[1]}) + if tag: + last_modified = tag.get("content") + break + + if not last_modified: + scripts = soup.find_all("script", {"type": "application/ld+json"}) + for script in scripts: + try: + data = json.loads(script.string) + except: + continue + + graph = data.get("@graph") if isinstance(data, dict) else None + + if graph: + for item in graph: + if item.get("@type") in ["WebPage", "Article"]: + last_modified = item.get("dateModified") + if last_modified: + break + + return parse_isoformat(last_modified) + + +def fetch_head(url: str, etag: str | None = None) -> FetchResult: + try: + headers = {"User-Agent": ua.chrome} + if etag: + headers["If-None-Match"] = etag + + response = requests.head( + url, + allow_redirects=True, + timeout=15, + headers=headers + ) + if response.status_code == 304: + return FetchResult(not_modified=True) + + if response.status_code >= 400: + logger.warning(f"HTTP {response.status_code} for URL '{url}'") + raise Exception() + + return FetchResult( + final_url = response.url, + last_modified = response.headers.get('Last-Modified'), + etag = response.headers.get('ETag') + ) + except Exception as e: + logger.exception(f"Head fetch failed: {url}") + raise e + + +def fetch_url(url: str, etag: str | None = None) -> dict: + try: + headers = {"User-Agent": ua.chrome} + if etag: + headers["If-None-Match"] = etag + + response = requests.get( + url, + allow_redirects=True, + timeout=15, + headers=headers + ) + if response.status_code == 304: + return FetchResult(not_modified=True) + + if response.status_code >= 400: + logger.warning(f"HTTP {response.status_code} for URL '{url}'") + raise Exception() + + html = response.text + etag = response.headers.get("ETag") + last_modified = extract_last_modified(response, html) + page_hash = hash_html(html) + + return FetchResult( + text = html, + final_url = response.url, + page_hash = page_hash, + last_modified = last_modified, + etag = etag, + ) + except Exception as e: + logger.exception(f"Fetch failed: {url}") + raise e + + +def _robots_exist(robots_url) -> bool: + try: + logger.info(f"Checking if 'robots.txt' accessible on path '{robots_url}'...") + response = requests.head(robots_url, allow_redirects=True, timeout=config.scraping.TIMEOUT) + if response.status_code >= 400: + logger.error("Cannot access the 'robots.txt' - recieved status code {response.status_code}!") + return False + return True + except requests.RequestException as e: + raise requests.RequestException(f"An error occured while requesting the URL '{robots_url}': {e}") + except Exception as e: + raise e + + +def parse_robots(base_url: str) -> RobotFileParser | None: + robots_url = f'{base_url.rstrip('/')}/robots.txt' + + # Check whether the robots.txt file is accessible from this url + response = call_with_exponential_backoff(_robots_exist, args=(robots_url,)) + if not response['result']: return None + + logger.info(f"File 'robots.txt' found for the target url '{base_url}'") + rp = RobotFileParser() + rp.set_url(robots_url) + + # Parse existing robots.txt file into the parser + def fetch_robots(): + try: + rp.read() + except URLError as e: + raise URLError(f"Failed to fetch the 'robots.txt': {e}") + + response = call_with_exponential_backoff(fetch_robots) + if response['status'] == 'FAIL': + logger.error(f"Failed to fetch the 'robots.txt': {response['last_error']}") + return None + + return rp + diff --git a/src/utils/__init__.py b/src/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3b5d7031e9d7c7ee53a63eefd37181113508a74e --- /dev/null +++ b/src/utils/__init__.py @@ -0,0 +1,3 @@ +""" +Utility modules for the Executive Education RAG Chatbot. +""" diff --git a/src/utils/lang.py b/src/utils/lang.py new file mode 100644 index 0000000000000000000000000000000000000000..26c743775b056dd52b372c430850759230cb8a37 --- /dev/null +++ b/src/utils/lang.py @@ -0,0 +1,30 @@ +from langdetect import DetectorFactory, detect_langs +from src.utils.logging import get_logger + +from src.config import config + +logger = get_logger('lang_utils') +DetectorFactory.seed = 0 + +def detect_language(text: str): + """ + Detects if the provided text is written in German or in some other language. + In case of ambiguous input returns 'en'. + + Args: + text (str): The text to analyze. + + Returns: + str: 'de' if the detection certanty is more than 0.6, else 'en'. + """ + found_langs = detect_langs(text) + top_lang = found_langs[0] + logger.debug(f'Found following languages in the text: {", ".join(f"{lang.lang}-{lang.prob:1.2f}" for lang in found_langs)}') + return 'de' if top_lang.lang == 'de' and top_lang.prob >= config.processing.LANG_AMBIGUITY_THRESHOLD else 'en' + + +def get_language_name(code: str): + return { + 'en': "British English", + 'de': "German", + }.get(code, 'British English') diff --git a/src/utils/logging.py b/src/utils/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..3a0d735a034d45e2085a5a78194d74a3f5881202 --- /dev/null +++ b/src/utils/logging.py @@ -0,0 +1,296 @@ +""" +Centralized logging configuration for the Executive Education RAG Chatbot. +""" +import logging, os, sys, warnings, colorama +from collections import defaultdict +from colorama import Fore, Style +from typing import Literal + +from src.config import config + +file_handlers = defaultdict(list) + +import json +from datetime import datetime, timezone +import os + +# Initialize colorama for cross-platform color support +colorama.init() + +class DefaultFormatter(logging.Formatter): + def format(self, record): + record = logging.makeLogRecord(record.__dict__) + + if hasattr(record, 'name'): + rname = record.name if len(record.name) <= 17 else record.name[:14] + '...' + record.name = rname + + return super().format(record) + + +class ColoredFormatter(logging.Formatter): + """Custom formatter with color support for console output ONLY. + Never mutates the original LogRecord (so file handlers stay clean).""" + + COLORS = { + 'DEBUG': Fore.CYAN, + 'INFO': Fore.GREEN, + 'WARNING': Fore.YELLOW, + 'ERROR': Fore.RED, + 'CRITICAL': Fore.MAGENTA + Style.BRIGHT, + } + ALIASES = { + 'DEBUG': 'DEBUG', + 'INFO': 'INFO ', + 'WARNING': 'WARN ', + 'ERROR': 'ERROR', + 'CRITICAL': 'CRITC' + } + + def format(self, record): + record = logging.makeLogRecord(record.__dict__) + + # Add color to the level name + if hasattr(record, 'levelname') and record.levelname in self.COLORS: + lname = record.levelname + color = self.COLORS[lname] + + if lname == 'ERROR' and hasattr(record, 'message'): + record.message = f"{color}{record.message}{Style.RESET_ALL}" + + record.levelname = f"{color}{self.ALIASES[lname]}{Style.RESET_ALL}" + + # Add color to the module name + if hasattr(record, 'name'): + rname = record.name if len(record.name) <= 17 else record.name[:14] + '...' + record.name = f"{Fore.CYAN}{rname}{Style.RESET_ALL}" + + return super().format(record) + + +def setup_logging(level: str = "INFO") -> logging.Logger: + """ + Set up centralized logging configuration. + + Args: + level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) + Returns: + Configured logger instance + """ + os.makedirs(config.paths.LOGS, exist_ok=True) + + # Convert string level to logging constant + numeric_level = getattr(logging, level.upper(), logging.INFO) + + # Get root logger + logger = logging.getLogger() + + # Avoid duplicate handlers if logger already configured + if logger.handlers: + logger.handlers.clear() + + logger.setLevel(numeric_level) + + # Create formatters + detailed_formatter = DefaultFormatter( + "(%(asctime)s) %(name)s\t %(levelname)s: %(message)s", + datefmt="%Y.%m.%d %H:%M:%S" + ) + + colored_formatter = ColoredFormatter( + "(%(asctime)s) %(name)s\t %(levelname)s: %(message)s", + datefmt="%Y.%m.%d %H:%M:%S" + ) + + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setLevel(numeric_level) + + # Use colored formatter if terminal supports it + if _supports_color(): + console_handler.setFormatter(colored_formatter) + else: + console_handler.setFormatter(detailed_formatter) + + logger.addHandler(console_handler) + + return logger + + +def get_logger(module_name: str) -> logging.Logger: + """ + Get a logger for a specific module. + + Args: + module_name: Name of the module requesting the logger + + Returns: + Logger instance + """ + logger = logging.getLogger(module_name) + logger.propagate = True + return logger + + +def create_file_handler( + file_path: str, + module_name: str, + mode: Literal['a', 'w'] = 'a', + level = logging.WARNING +) -> logging.FileHandler: + """ + Initializes a new FileHandler to redirect logs to the files. + All subsequent calls to the 'append_handlers' function with the name of the module + will append handlers stored under the module name to the logger. + + Args: + file_path: path to the .log file where logs will be stored. + module_name: name of the logging module that this handler belongs to. + + Returns: + File handler instance. + """ + global file_handlers + + file_handler = logging.FileHandler( + file_path, + mode=mode, + encoding='utf-8' + ) + file_handler.setLevel(level) + + formatter = DefaultFormatter( + "(%(asctime)s) %(name)s\t %(levelname)s: %(message)s", + datefmt="%Y.%m.%d %H:%M:%S" + ) + file_handler.setFormatter(formatter) + + file_handlers[module_name].append(file_handler) + + return file_handler + + +def append_file_handlers(logger: logging.Logger, module_name: str) -> None: + global file_handlers + + for handler in file_handlers.get(module_name, []): + logger.addHandler(handler) + + +def _supports_color() -> bool: + """ + Check if the terminal supports color output. + + Returns: + True if color is supported, False otherwise + """ + # Check if we're in a terminal + if not hasattr(sys.stdout, 'isatty') or not sys.stdout.isatty(): + return False + + # Check environment variables + if os.getenv('NO_COLOR'): + return False + + if os.getenv('FORCE_COLOR'): + return True + + # Check terminal type + term = os.getenv('TERM', '').lower() + if 'color' in term or term in ['xterm', 'xterm-256color', 'screen']: + return True + + return False + + +def configure_external_loggers(level: str = "WARNING") -> None: + """ + Configure logging for external libraries to reduce noise. + + Args: + level: Logging level for external libraries + """ + external_loggers = [ + 'selenium', + 'urllib3', + 'requests', + 'chromadb', + 'docling', + 'docling_core', + 'uvicorn', + 'weaviate', + 'langchain', + 'langgraph', + 'openai', + 'httpx', + 'usp', + ] + + numeric_level = getattr(logging, level.upper(), logging.WARNING) + + for logger_name in external_loggers: + logging.getLogger(logger_name).setLevel(numeric_level) + + +def configure_internal_loggers(): + # Logging output for all loggers + root_handler = create_file_handler( + file_path=os.path.join(config.paths.LOGS, 'logs.log'), + module_name='*', + mode='a', + level=logging.INFO, + ) + root_logger = logging.getLogger() + root_logger.addHandler(root_handler) + + # Scraping loggers tree configuration + scraping_handler = create_file_handler( + file_path=os.path.join(config.paths.LOGS, 'scraping.log'), + module_name='scraping', + mode='w', + level=logging.INFO, + ) + scraping_logger = logging.getLogger('scraper') + scraping_logger.addHandler(scraping_handler) + + +# Global configuration function +def init_logging(level: str = "INFO") -> None: + """ + Initialize the global logging configuration. + + Args: + level: Logging level + log_file: Optional log file path + """ + warnings.filterwarnings("ignore") + + # Set up root logger + setup_logging(level=level) + + # Configure loggers defined by this application + configure_internal_loggers() + + # Configure external library loggers + configure_external_loggers() + + +class ConsentLogger: + def __init__(self): + log_dir = os.path.join('logs', 'consent') + os.makedirs(log_dir, exist_ok=True) + + def log(self, session_id: str, decision: str, policy_version="1.0"): + try: + entry = { + "session_id": session_id, + "decision": decision, + "timestamp": datetime.now(timezone.utc).isoformat(), + "policy_version": policy_version + } + + log_path = os.path.join('logs', 'consent', f"{session_id}.jsonl") + with open(log_path, "a", encoding="utf-8") as f: + f.write(json.dumps(entry, indent=2) + "\n") + + except Exception as e: + print(f"Error logging consent decision: {e}") diff --git a/src/utils/stratutils/generator.py b/src/utils/stratutils/generator.py new file mode 100644 index 0000000000000000000000000000000000000000..d3e90dd916fb2db12bbf8e44c1f22321bf8bed75 --- /dev/null +++ b/src/utils/stratutils/generator.py @@ -0,0 +1,8 @@ +from src.utils.stratutils.templates import * + +def generate_strategy(name, prop): + preamble = PREAMBLE_TEMPL_STD.format(name=name) + header = f"{FUNC_HEADER_TEMPL} -> {FUNC_RETURN_TYPE_TEMPL.get(prop['data_type'], None)}:" + body = BODY_TEMPL.get(name, BODY_TEMPL_STD) + + return f"{preamble}\n\n{header}\n{COMMENT_TEMPL_STD}\n\n{body}" diff --git a/src/utils/stratutils/templates.py b/src/utils/stratutils/templates.py new file mode 100644 index 0000000000000000000000000000000000000000..7293117fff21c5850cdf25c23ebc041054ca33fc --- /dev/null +++ b/src/utils/stratutils/templates.py @@ -0,0 +1,31 @@ +FUNC_HEADER_TEMPL = "def run(file_name: str, file_content: str, chunk: str)" + +FUNC_RETURN_TYPE_TEMPL = { + "text": "str", + "date": "str", + "text[]": "list[str]", +} + +PREAMBLE_TEMPL_STD="""\"\"\"Property extraction strategy for property {name}.\"\"\"""" + +COMMENT_TEMPL_STD = """\t\"\"\" +\tRuns the property extraction strategy on processed chunk. + +\tArgs: +\t\tfile_name (str): Name of the file from which the chunk was collected. +\t\tfile_content (str): Entire text extracted from file. +\t\tchunk (str): Chunk collected from file. + +\tReturns: +\t\tExtracted property. +\t\"\"\"""" + +BODY_TEMPL_STD = "\treturn chunk" + +BODY_TEMPL = { + 'body': "\treturn chunk", + 'source': "\treturn file_name", + 'chunk_id': "\timport hashlib\n\treturn hashlib.md5(chunk.strip().encode('utf-8')).hexdigest()", + 'document_id': "\timport hashlib\n\treturn hashlib.md5(file_content.strip().encode('utf-8')).hexdigest()", + 'date': "\timport datetime\n\treturn datetime.datetime.now().replace(tzinfo=datetime.timezone.utc)" +} diff --git a/src/utils/tools.py b/src/utils/tools.py new file mode 100644 index 0000000000000000000000000000000000000000..1a73bd52c62ec6950055154b5a5311504a7dd0e0 --- /dev/null +++ b/src/utils/tools.py @@ -0,0 +1,34 @@ +from time import sleep + +from .logging import get_logger +from ..config import config + +logger = get_logger("utils.backoff") + +def call_with_exponential_backoff( + func, + args: tuple = (), + delay: float | None = None, + backoff_rate: float | None = None, +) -> dict: + retries = 0 + last_error = None + + delay = delay or config.scraping.CRAWL_DELAY + backoff_rate = backoff_rate or config.scraping.BACKOFF_RATE + + sleep(delay) + + while retries <= config.scraping.MAX_RETRIES: + try: + return { 'result': func(*args), 'retries': retries, 'last_error': last_error, 'status': 'OK'} + except Exception as e: + logger.warning(f'Caught an error on try {retries+1}: {e}') + last_error = e + retries += 1 + + backoff_time = delay * backoff_rate**retries + logger.info(f'Retrying with exponential backoff time {backoff_time} sec.') + sleep(backoff_time) + + return { 'result': None, 'retries': retries, 'last_error': last_error, 'status': 'FAIL' }