Ahmed-Alghamdi committed
Commit e820a8a · verified · 1 Parent(s): 7da3b85

Upload 11 files

Files changed (11)
  1. RAG DATA.txt +0 -0
  2. README +49 -0
  3. config.py +10 -0
  4. document_processor.py +19 -0
  5. embedding_generator.py +20 -0
  6. main.py +79 -0
  7. requirements.txt +26 -0
  8. response_generator.py +73 -0
  9. search_engine.py +30 -0
  10. streamlit_app.py +63 -0
  11. utils.py +19 -0
RAG DATA.txt ADDED
The diff for this file is too large to render. See raw diff
 
README ADDED
@@ -0,0 +1,49 @@
+# ArabicRAG: Arabic Retrieval-Augmented Generation
+
+## Project Structure
+```
+arabic_legal_search/
+├── config.py
+├── document_processor.py
+├── embedding_generator.py
+├── search_engine.py
+├── response_generator.py
+├── utils.py
+├── main.py
+└── requirements.txt
+```
+
+## Overview
+ArabicRAG is an open-source project that applies retrieval-augmented generation (RAG) to Arabic legal documents: it retrieves the documents most relevant to a query and uses them as context to generate an answer.
+
+## Features
+- **Document Processing**: Load and preprocess Arabic text documents efficiently.
+- **Embedding Generation**: Use multilingual models to generate embeddings for Arabic text.
+- **Efficient Search**: Use FAISS for fast similarity search over large document collections.
+- **Response Generation**: Use transformer models to generate answers from the retrieved context.
+
+## Installation
+To set up your environment and run ArabicRAG, follow these steps:
+
+1. Clone the repository:
+   ```bash
+   git clone https://github.com/maljefairi/arabicRAG
+   ```
+2. Install the required packages:
+   ```bash
+   pip install -r requirements.txt
+   ```
+
+## Usage
+After installation, run the main script in interactive mode:
+```bash
+python main.py
+```
+For batch processing, pass a file of queries and an output file: `python main.py --batch input_file output_file`.
+
+## Contributing
+Contributions are welcome! For major changes, please open an issue first to discuss what you would like to change. Please update tests as appropriate.
+
+## License
+This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
+
+## Contact
+- **Dr. Mohammed Al-Jefairi** - maljefairi@sidramail.com
+- **GitHub**: [maljefairi](https://github.com/maljefairi/arabicRAG)
config.py ADDED
@@ -0,0 +1,10 @@
+# config.py
+import os
+
+class Config:
+    # Every setting falls back to an environment variable of the same name.
+    DOCUMENT_FOLDER = os.environ.get('DOCUMENT_FOLDER', 'data')
+    EMBEDDING_MODEL = os.environ.get('EMBEDDING_MODEL', 'sentence-transformers/paraphrase-multilingual-mpnet-base-v2')
+    LLM_MODEL = os.environ.get('LLM_MODEL', 'CAMeL-Lab/bert-base-arabic-camelbert-ca')
+    BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 32))
+    TOP_K = int(os.environ.get('TOP_K', 5))
+    MAX_LENGTH = int(os.environ.get('MAX_LENGTH', 1024))
+
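
Since every field in `Config` falls back to an environment variable, the defaults can be overridden without editing the file. A minimal sketch with made-up values, assuming the variables are set before `config` is first imported:

```python
# Hypothetical override of the defaults via environment variables.
import os
os.environ["DOCUMENT_FOLDER"] = "my_corpus"   # made-up folder name
os.environ["TOP_K"] = "10"                    # retrieve 10 documents instead of 5

from config import Config                     # values are read at import time
print(Config.DOCUMENT_FOLDER, Config.TOP_K)   # -> my_corpus 10
```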
document_processor.py ADDED
@@ -0,0 +1,19 @@
+# document_processor.py
+import os
+import glob
+from tqdm import tqdm
+import pandas as pd
+from utils import clean_text, setup_logger
+
+logger = setup_logger('document_processor')
+
+def load_documents(folder_path):
+    documents = []
+    for file_path in tqdm(glob.glob(os.path.join(folder_path, '*.txt')), desc="Loading documents"):
+        try:
+            with open(file_path, 'r', encoding='utf-8') as file:
+                content = clean_text(file.read())
+                documents.append({'path': file_path, 'content': content})
+        except Exception as e:
+            logger.error(f"Error reading {file_path}: {e}")
+    return pd.DataFrame(documents)
embedding_generator.py ADDED
@@ -0,0 +1,20 @@
+# embedding_generator.py
+import numpy as np
+from tqdm import tqdm
+from sentence_transformers import SentenceTransformer
+from utils import setup_logger
+from config import Config
+
+logger = setup_logger('embedding_generator')
+
+def generate_embeddings(documents):
+    model = SentenceTransformer(Config.EMBEDDING_MODEL)
+    embeddings = []
+    for i in tqdm(range(0, len(documents), Config.BATCH_SIZE), desc="Generating embeddings"):
+        batch = documents['content'][i:i+Config.BATCH_SIZE].tolist()
+        try:
+            batch_embeddings = model.encode(batch, show_progress_bar=False)
+            embeddings.extend(batch_embeddings)
+        except Exception as e:
+            logger.error(f"Error encoding batch: {e}")
+    return np.array(embeddings)
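
A quick smoke test of what `generate_embeddings` returns; the two sample strings are made up, and the 768-dimensional output assumes the default paraphrase-multilingual-mpnet-base-v2 embedding model:

```python
# Hypothetical check of the embedding output shape.
import pandas as pd
from embedding_generator import generate_embeddings

docs = pd.DataFrame({"content": ["نص تجريبي أول", "نص تجريبي ثانٍ"]})  # two sample texts
embeddings = generate_embeddings(docs)
print(embeddings.shape)  # expected: (2, 768) with the default multilingual MPNet model
```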
main.py ADDED
@@ -0,0 +1,79 @@
+# main.py
+import os
+import sys
+from document_processor import load_documents
+from embedding_generator import generate_embeddings
+from search_engine import SearchEngine
+from response_generator import ResponseGenerator
+from config import Config
+from utils import setup_logger
+
+logger = setup_logger('main')
+
+def initialize_system():
+    """Initialize the search and response system."""
+    logger.info("Initializing the system...")
+
+    # Load and process documents
+    documents = load_documents(Config.DOCUMENT_FOLDER)
+    logger.info(f"Loaded {len(documents)} documents")
+
+    # Generate embeddings
+    embeddings = generate_embeddings(documents)
+    logger.info(f"Generated embeddings of shape {embeddings.shape}")
+
+    # Initialize search engine
+    search_engine = SearchEngine(documents, embeddings)
+    logger.info("Search engine initialized")
+
+    # Initialize response generator
+    response_generator = ResponseGenerator()
+    logger.info("Response generator initialized")
+
+    return search_engine, response_generator
+
+def process_query(query, search_engine, response_generator):
+    """Process a single query and return the response."""
+    relevant_docs = search_engine.search(query)
+    logger.info(f"Found {len(relevant_docs)} relevant documents")
+
+    response = response_generator.generate_response(query, relevant_docs)
+    return response
+
+def interactive_mode(search_engine, response_generator):
+    """Run the system in interactive mode, processing queries from user input."""
+    print("Enter your queries. Type 'quit' to exit.")
+    while True:
+        query = input("Query: ").strip()
+        if query.lower() == 'quit':
+            break
+
+        response = process_query(query, search_engine, response_generator)
+        print(f"Response: {response}\n")
+
+def batch_mode(input_file, output_file, search_engine, response_generator):
+    """Process queries from an input file and write responses to an output file."""
+    with open(input_file, 'r', encoding='utf-8') as infile, open(output_file, 'w', encoding='utf-8') as outfile:
+        for line in infile:
+            query = line.strip()
+            response = process_query(query, search_engine, response_generator)
+            outfile.write(f"Query: {query}\nResponse: {response}\n\n")
+    logger.info(f"Batch processing completed. Results written to {output_file}")
+
+def main():
+    search_engine, response_generator = initialize_system()
+
+    if len(sys.argv) > 1:
+        if sys.argv[1] == '--batch':
+            if len(sys.argv) != 4:
+                print("Usage for batch mode: python main.py --batch input_file output_file")
+                sys.exit(1)
+            batch_mode(sys.argv[2], sys.argv[3], search_engine, response_generator)
+        else:
+            print("Unknown argument. Use --batch for batch mode or no arguments for interactive mode.")
+            sys.exit(1)
+    else:
+        interactive_mode(search_engine, response_generator)
+
+if __name__ == "__main__":
+    main()
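
Besides the interactive and `--batch` CLI modes, the two helpers can be called directly from Python (the Streamlit app later in this commit does exactly that). A short sketch with a made-up query:

```python
# Hypothetical programmatic use of main.py's helpers.
from main import initialize_system, process_query

search_engine, response_generator = initialize_system()
answer = process_query("ما هي نسبة الحضور المطلوبة؟", search_engine, response_generator)
print(answer)
```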
requirements.txt ADDED
@@ -0,0 +1,26 @@
+# requirements.txt
+# Standard libraries for data handling and computation
+numpy==1.26.4
+pandas==1.3.5
+scipy==1.14.1
+
+# Machine Learning and NLP libraries
+torch==2.4.0+cu116        # pick the build matching your CUDA version if needed
+torchvision==0.19.0+cu116 # pick the build matching your CUDA version if needed
+transformers==4.15.0
+sentence-transformers==2.2.0
+
+# For efficient nearest neighbor search
+faiss-gpu==1.8.0  # use the GPU build for CUDA compatibility
+
+# Utilities
+tqdm==4.62.3
+setuptools_rust    # required for compiling tokenizers with Rust dependencies
+tokenizers==0.10.3 # ensure compatibility with transformers
+
+# Note: the Rust toolchain itself is not a pip package; if tokenizers has to be
+# compiled from source, install Rust separately (e.g. via rustup).
+
+# Additional Python packages that may be required
+huggingface-hub==0.24.6
+streamlit
response_generator.py ADDED
@@ -0,0 +1,73 @@
+# response_generator.py
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from utils import setup_logger
+from config import Config
+
+logger = setup_logger('response_generator')
+
+class ResponseGenerator:
+    def __init__(self):
+        self.tokenizer = AutoTokenizer.from_pretrained(Config.LLM_MODEL)
+        self.model = AutoModelForCausalLM.from_pretrained(Config.LLM_MODEL)
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.model.to(self.device)
+        logger.info(f"Model loaded and moved to {self.device}")
+
+    def generate_response(self, query, relevant_docs):
+        try:
+            context = self._prepare_context(relevant_docs)
+            prompt = self._create_prompt(query, context)
+            input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
+
+            attention_mask = input_ids.ne(self.tokenizer.pad_token_id).float()
+
+            with torch.no_grad():
+                output = self.model.generate(
+                    input_ids,
+                    attention_mask=attention_mask,
+                    max_length=Config.MAX_LENGTH,
+                    num_return_sequences=1,
+                    no_repeat_ngram_size=2,
+                    do_sample=True,
+                    top_k=50,
+                    top_p=0.95,
+                    temperature=0.7
+                )
+
+            response = self.tokenizer.decode(output[0], skip_special_tokens=True)
+            return self._extract_answer(response)
+
+        except Exception as e:
+            logger.error(f"Error generating response: {e}")
+            return "عذرًا، لم أتمكن من إنشاء استجابة بسبب خطأ ما."  # "Sorry, I couldn't generate a response due to an error."
+
+    def _prepare_context(self, relevant_docs):
+        # Combine content from relevant documents
+        combined_content = "\n".join(relevant_docs['content'].tolist())
+        # Truncate if too long
+        max_context_length = Config.MAX_LENGTH // 2  # Use half of max_length for context
+        return combined_content[:max_context_length]
+
+    def _create_prompt(self, query, context):
+        # Prompt labels (in Arabic): "Legal document:", "Question:", "Answer:"
+        return f"""مستند قانوني:
+{context}
+
+سؤال:
+{query}
+
+إجابة:"""
+
+    def _extract_answer(self, response):
+        # Extract the generated answer from the full response
+        answer_start = response.find("إجابة:") + len("إجابة:")
+        return response[answer_start:].strip()
+
+    def update_model(self, new_model_name):
+        try:
+            self.tokenizer = AutoTokenizer.from_pretrained(new_model_name)
+            self.model = AutoModelForCausalLM.from_pretrained(new_model_name)
+            self.model.to(self.device)
+            logger.info(f"Model updated to {new_model_name}")
+        except Exception as e:
+            logger.error(f"Error updating model: {e}")
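
`update_model` lets the generator be swapped at runtime without rebuilding the search index. A sketch, assuming some other Arabic causal LM is available on the Hugging Face Hub (the model name below is only an example, not part of this commit):

```python
# Hypothetical model swap; any causal LM with an Arabic tokenizer should work the same way.
from response_generator import ResponseGenerator

generator = ResponseGenerator()                    # loads Config.LLM_MODEL
generator.update_model("aubmindlab/aragpt2-base")  # assumed alternative model name
```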
search_engine.py ADDED
@@ -0,0 +1,30 @@
+# search_engine.py
+import faiss
+import numpy as np
+import pandas as pd
+from sentence_transformers import SentenceTransformer
+from utils import setup_logger
+from config import Config
+
+logger = setup_logger('search_engine')
+
+class SearchEngine:
+    def __init__(self, documents, embeddings):
+        self.documents = documents
+        self.index = self._build_faiss_index(embeddings)
+        self.model = SentenceTransformer(Config.EMBEDDING_MODEL)
+
+    def _build_faiss_index(self, embeddings):
+        dimension = embeddings.shape[1]
+        index = faiss.IndexFlatL2(dimension)
+        index.add(embeddings.astype('float32'))
+        return index
+
+    def search(self, query):
+        try:
+            query_embedding = self.model.encode([query])
+            _, indices = self.index.search(query_embedding.astype('float32'), Config.TOP_K)
+            return self.documents.iloc[indices[0]]
+        except Exception as e:
+            logger.error(f"Error searching documents: {e}")
+            return pd.DataFrame()
+
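
Putting the pieces so far together, retrieval on its own (without generation) looks like this; the folder name and query are placeholders:

```python
# Hypothetical end-to-end retrieval using the modules in this commit.
from document_processor import load_documents
from embedding_generator import generate_embeddings
from search_engine import SearchEngine

documents = load_documents("data")            # DataFrame with 'path' and 'content' columns
embeddings = generate_embeddings(documents)   # ndarray of shape (len(documents), dim)
engine = SearchEngine(documents, embeddings)

hits = engine.search("ما هي مدة العقد؟")      # top Config.TOP_K rows of `documents`
print(hits["path"].tolist())
```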
streamlit_app.py ADDED
@@ -0,0 +1,63 @@
+import streamlit as st
+from main import initialize_system, process_query
+
+# ---------------------------------------------------------
+# Streamlit page configuration
+# ---------------------------------------------------------
+st.set_page_config(
+    page_title="Arabic RAG Chatbot 🤖",
+    page_icon="🤖",
+    layout="wide",
+)
+
+# ---------------------------------------------------------
+# Title and description
+# ---------------------------------------------------------
+st.title("🤖 Arabic RAG Chatbot")
+# (Arabic) "Welcome! Type your question in Arabic and the system will answer
+# based on your locally stored documents."
+st.markdown("""
+مرحبًا! 👋
+اكتب سؤالك بالعربية وسيتولى النظام الإجابة استنادًا إلى مستنداتك المخزّنة محليًا.
+""")
+
+# ---------------------------------------------------------
+# Cached system initialization (so it doesn't reload every time)
+# ---------------------------------------------------------
+@st.cache_resource
+def load_rag_system():
+    search_engine, response_generator = initialize_system()
+    return search_engine, response_generator
+
+search_engine, response_generator = load_rag_system()
+
+# ---------------------------------------------------------
+# Input section
+# ---------------------------------------------------------
+st.divider()
+# (Arabic) Label: "Enter your question here:"; placeholder: "Example: what attendance rate is required?"
+query = st.text_input("📝 أدخل سؤالك هنا:", placeholder="مثال: ما هي نسبة الحضور المطلوبة؟")
+
+# ---------------------------------------------------------
+# Query handling
+# ---------------------------------------------------------
+if st.button("بحث") or query:  # "بحث" = "Search"
+    if not query.strip():
+        st.warning("يرجى كتابة سؤال أولاً.")  # "Please write a question first."
+    else:
+        with st.spinner("⏳ جارٍ البحث عن الإجابة..."):  # "Searching for the answer..."
+            try:
+                response = process_query(query, search_engine, response_generator)
+                if response:
+                    st.success("💬 الإجابة:")  # "Answer:"
+                    st.write(response)
+                else:
+                    st.info("لم يتم العثور على إجابة ذات صلة في المستندات.")  # "No relevant answer found in the documents."
+            except Exception as e:
+                st.error(f"حدث خطأ أثناء توليد الإجابة: {e}")  # "An error occurred while generating the answer."
+
+# ---------------------------------------------------------
+# Footer
+# ---------------------------------------------------------
+st.divider()
+st.markdown(
+    "<p style='text-align:center; color:gray;'>تم التطوير باستخدام Streamlit و RAG ❤️</p>",  # "Built with Streamlit and RAG"
+    unsafe_allow_html=True
+)
utils.py ADDED
@@ -0,0 +1,19 @@
+# utils.py
+import re
+import logging
+
+def clean_text(text):
+    # Remove special characters and multiple spaces
+    text = re.sub(r'[^\w\s\u0600-\u06FF]', ' ', text)
+    text = re.sub(r'\s+', ' ', text).strip()
+    return text
+
+def setup_logger(name):
+    logger = logging.getLogger(name)
+    logger.setLevel(logging.INFO)
+    handler = logging.StreamHandler()
+    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+    handler.setFormatter(formatter)
+    logger.addHandler(handler)
+    return logger
+
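
`clean_text` keeps word characters, whitespace, and the Arabic block (U+0600-U+06FF) and collapses everything else to single spaces. A small illustration with a made-up string:

```python
# Hypothetical illustration of clean_text.
from utils import clean_text

print(clean_text("المادة (12): مدة   العقد!"))
# -> "المادة 12 مدة العقد"  (punctuation stripped, whitespace collapsed)
```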