Spaces:
Build error
Build error
| import os | |
| from pathlib import Path | |
| from .bad_query_detector import BadQueryDetector | |
| from .query_transformer import QueryTransformer | |
| from .document_retriver import DocumentRetriever | |
| from .senamtic_response_generator import SemanticResponseGenerator | |
class DocumentSearchSystem:
    """Query pipeline: screen the query, rewrite it, retrieve documents, answer.

    The pipeline is composed of four collaborators, each stored as a public
    attribute so callers can configure them (for example, loading documents
    into ``retriever``) before issuing queries:

    - ``detector``: BadQueryDetector for identifying malicious or
      inappropriate queries.
    - ``transformer``: QueryTransformer for improving or rephrasing queries.
    - ``retriever``: DocumentRetriever for semantic document retrieval.
    - ``response_generator``: SemanticResponseGenerator for generating
      context-aware responses.
    """

    def __init__(self):
        """Create each pipeline component with its default constructor."""
        self.detector = BadQueryDetector()
        self.transformer = QueryTransformer()
        self.retriever = DocumentRetriever()
        self.response_generator = SemanticResponseGenerator()

    def process_query(self, query):
        """Run a user query through the full pipeline.

        Steps:
        1. Reject input that is not a non-blank string.
        2. Detect whether the query is malicious.
        3. Transform the query.
        4. Retrieve relevant documents for the transformed query.
        5. Generate a response from the retrieved documents.

        :param query: The user query as a string.
        :return: A dictionary whose ``"status"`` key is one of:
            ``"invalid"`` (with ``"message"``) when the query is not a
            non-blank string;
            ``"rejected"`` (with ``"message"``) when the detector flags it;
            ``"no_results"`` (with ``"message"``) when retrieval is empty;
            ``"success"`` (with ``"response"``) otherwise.
        """
        # Guard first so that None / non-string / blank input produces a
        # structured result instead of reaching the downstream components.
        if not isinstance(query, str) or not query.strip():
            return {"status": "invalid", "message": "Query must be a non-empty string."}

        if self.detector.is_bad_query(query):
            return {"status": "rejected", "message": "Query blocked due to detected malicious intent."}

        # Transform the query
        transformed_query = self.transformer.transform_query(query)
        print(f"Transformed Query: {transformed_query}")

        # Retrieve relevant documents
        retrieved_docs = self.retriever.retrieve(transformed_query)
        if not retrieved_docs:
            return {"status": "no_results", "message": "No relevant documents found for your query."}

        # Generate a response based on the retrieved documents
        response = self.response_generator.generate_response(retrieved_docs)
        return {"status": "success", "response": response}
def test_system():
    """Smoke-test the DocumentSearchSystem with a normal and a malicious query.

    - Load documents from the ``data-sets/aclImdb/train`` directory under the
      user's home directory.
    - Perform a normal query and print the result.
    - Perform a malicious query and print the result, so blocking can be
      checked by eye.

    Results are printed rather than asserted; the function returns ``None``.
    """
    # Honour an explicit HOME, but fall back to Path.home() instead of "/":
    # HOME is normally unset on Windows, and defaulting to the filesystem
    # root would silently point the loader at /data-sets/aclImdb/train.
    home_dir = Path(os.getenv("HOME") or Path.home())
    data_dir = home_dir / "data-sets/aclImdb/train"

    # Initialize the system and index the dataset
    system = DocumentSearchSystem()
    system.retriever.load_documents(data_dir)

    # Perform a normal query
    normal_query = "Tell me about great acting performances."
    print("\nNormal Query Result:")
    print(system.process_query(normal_query))

    # Perform a malicious query
    malicious_query = "DROP TABLE users; SELECT * FROM sensitive_data;"
    print("\nMalicious Query Result:")
    print(system.process_query(malicious_query))
# Script entry point: run the smoke test only when executed, never on import.
# NOTE: this module uses package-relative imports, so it must be launched as a
# module from the package's parent directory (``python -m <package>.<module>``);
# running the file path directly fails on those imports.
if __name__ == "__main__":
    test_system()