Park-Hip-02 committed on
Commit
d9762cf
·
1 Parent(s): 7069e9e

initial commit

Browse files
.gitignore ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ../.env
2
+ .venv/
3
+ .deleted/
4
+ test.ipynb
5
+ test.py
6
+ aliases/
7
+ raft_state.json
8
+ mlartifacts/
9
+ mlruns/
10
+ .deleted/
11
+ secret.py
12
+ docker-compose.yml
13
+ .env
.idea/.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Default ignored files
2
+ /shelf/
3
+ /workspace.xml
4
+ # Editor-based HTTP Client requests
5
+ /httpRequests/
.idea/Legal-RAG-Chatbot.iml ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <module type="PYTHON_MODULE" version="4">
3
+ <component name="NewModuleRootManager">
4
+ <content url="file://$MODULE_DIR$">
5
+ <excludeFolder url="file://$MODULE_DIR$/.venv" />
6
+ </content>
7
+ <orderEntry type="jdk" jdkName="Python 3.13 (Legal-RAG-Chatbot)" jdkType="Python SDK" />
8
+ <orderEntry type="sourceFolder" forTests="false" />
9
+ </component>
10
+ </module>
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <settings>
3
+ <option name="USE_PROJECT_PROFILE" value="false" />
4
+ <version value="1.0" />
5
+ </settings>
6
+ </component>
.idea/misc.xml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="Black">
4
+ <option name="sdkName" value="Python 3.13 (Legal-RAG-Chatbot)" />
5
+ </component>
6
+ <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.13 (Legal-RAG-Chatbot)" project-jdk-type="Python SDK" />
7
+ </project>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectModuleManager">
4
+ <modules>
5
+ <module fileurl="file://$PROJECT_DIR$/.idea/Legal-RAG-Chatbot.iml" filepath="$PROJECT_DIR$/.idea/Legal-RAG-Chatbot.iml" />
6
+ </modules>
7
+ </component>
8
+ </project>
.idea/vcs.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="VcsDirectoryMappings">
4
+ <mapping directory="" vcs="Git" />
5
+ </component>
6
+ </project>
README.md CHANGED
@@ -1,13 +1,14 @@
1
  ---
2
- title: Legal RAG Chatbot2
3
- emoji: 🏢
4
- colorFrom: blue
5
- colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 5.42.0
8
  app_file: app.py
9
  pinned: false
10
- short_description: hello
 
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Legal RAG Chatbot
3
+ emoji: 💬
4
+ colorFrom: yellow
5
+ colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 5.0.1
8
  app_file: app.py
9
  pinned: false
10
+ license: apache-2.0
11
+ short_description: A RAG Chatbot that can answer legal questions
12
  ---
13
 
14
+ An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
app.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from rag.rag_production import get_rag_chain
3
+
4
def rag_fn(model_name: str, input: str):
    """Answer one question with the RAG chain, streaming the text as it grows.

    Args:
        model_name: Identifier of the chat model, forwarded to
            ``get_rag_chain``.
        input: The user's question, sent to the chain under the ``'input'``
            key.

    Yields:
        The answer accumulated so far, once for every streamed chunk that
        carries a non-``None`` ``'answer'`` value. If anything raises, the
        traceback is printed and an ``"An error occurred: ..."`` message is
        yielded rather than letting the exception propagate.
    """
    try:
        chain = get_rag_chain(model_name=model_name)
        answer_so_far = ''
        for part in chain.stream({'input': input}):
            # Chunks without answer text contribute nothing to the output.
            if 'answer' not in part or part['answer'] is None:
                continue
            answer_so_far = answer_so_far + part['answer']
            yield answer_so_far
    except Exception as err:
        import traceback
        print(traceback.format_exc())
        yield f"An error occurred: {err}"
19
+
20
# Input widgets, listed in the positional order rag_fn takes them:
# (model_name, input).
model_picker = gr.Dropdown(
    choices=['llama-3.3-70b-versatile', 'openai/gpt-oss-120b'],
    label="MODEL",
)
question_box = gr.Textbox(label='QUESTION')

# Single text output; rag_fn is a generator, so the box is refreshed with
# each yielded partial answer.
answer_box = gr.Textbox(label='ANSWER')

# Clickable sample rows shown in the UI, each one a [model, question] pair.
sample_queries = [
    [
        'llama-3.3-70b-versatile',
        'What is the maximum duration of determinate imprisonment that can be imposed on an offender?',
    ],
    [
        'openai/gpt-oss-120b',
        'If someone voluntarily pays damages after committing a crime, how might this affect their sentencing?',
    ],
]

interface = gr.Interface(
    fn=rag_fn,
    inputs=[model_picker, question_box],
    outputs=answer_box,
    title="Legal RAG Chatbot",
    description="Select a model and ask a question to get an answer from the RAG system.",
    examples=sample_queries,
    cache_examples=False,
)

# Start the web UI only when this file is executed directly.
if __name__ == "__main__":
    interface.queue()
    interface.launch()
data/processed_data/criminal_code_of_vietnam.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a1eeb278744619c4a50994e9d04247094daeb1c2a0870eb54deba9657de4e6cb
3
+ size 733443
embeddings/__pycache__/embedder.cpython-313.pyc ADDED
Binary file (4.89 kB). View file
 
embeddings/embedder.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from langchain_huggingface import HuggingFaceEmbeddings
3
+ from langchain_qdrant import QdrantVectorStore, RetrievalMode
4
+ from qdrant_client import QdrantClient, models
5
+ import logging
6
+ import pickle
7
+ from pathlib import Path
8
+
9
# Root logging setup performed at import time: INFO level, with timestamp,
# level name and message separated by em dashes.
_LOG_FORMAT = '%(asctime)s — %(levelname)s — %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-scoped logger used by the functions in this file.
logger = logging.getLogger(__name__)
14
+
15
def _existing_vector_size(collection_info):
    """Return the vector size recorded for an existing collection, or None.

    Looks for a ``size`` attribute first on ``config.vectors`` and then on
    ``config.params.vectors``; the first one found wins. ``None`` means
    neither location exposed a ``size``.
    """
    config = collection_info.config
    for holder in (config, getattr(config, 'params', None)):
        vectors = getattr(holder, 'vectors', None)
        if hasattr(vectors, 'size'):
            return vectors.size
    return None


def _load_documents(doc_path):
    """Unpickle and return the document list stored at ``doc_path``."""
    # NOTE: pickle.load can execute arbitrary code. Only ever point this at
    # the data file shipped with the project, never at user-supplied input.
    with open(doc_path, 'rb') as f:
        return pickle.load(f)


def get_vectorstore() -> QdrantVectorStore:
    """Connect to (or create and populate) the ``legal_db`` Qdrant collection.

    Reads ``QDRANT_URL`` and ``QDRANT_API_KEY`` from the environment and
    builds a CPU ``BAAI/bge-large-en`` embedding model. If the collection
    already exists its vector size is checked against the model's output
    size and the store is attached to it. Otherwise the collection is
    created and filled from the pickled corpus under
    ``data/processed_data``. The corpus is only read in that second case.
    Keyword payload indexes on the metadata fields are requested before
    returning.

    Returns:
        A ``QdrantVectorStore`` bound to the ``legal_db`` collection.

    Raises:
        ValueError: If the existing collection's vector size differs from
            the embedding model's output size.
    """
    qdrant_api_key = os.getenv('QDRANT_API_KEY')
    qdrant_url = os.getenv('QDRANT_URL')

    collection_name = 'legal_db'
    client = QdrantClient(url=qdrant_url, api_key=qdrant_api_key)

    embeddings = HuggingFaceEmbeddings(
        model_name='BAAI/bge-large-en',
        model_kwargs={'device': 'cpu'},
        encode_kwargs={'normalize_embeddings': False},
    )
    logger.info('Embedding created.')

    # Embed a throwaway query to learn the model's output dimension rather
    # than hard-coding it.
    vector_dim = len(embeddings.embed_query('A dummy to test embedding dimension'))

    existing_names = {c.name for c in client.get_collections().collections}
    if collection_name in existing_names:
        logger.info('Collection exists. Connecting...')
        existing_dim = _existing_vector_size(client.get_collection(collection_name))
        logger.info('Existing dimension: %s', existing_dim)
        if existing_dim != vector_dim:
            raise ValueError(
                f'Dimension mismatch: existing collection has {existing_dim}, but embedding model gives {vector_dim}'
            )

        db = QdrantVectorStore.from_existing_collection(
            embedding=embeddings,
            collection_name=collection_name,
            prefer_grpc=False,
            url=qdrant_url,
            api_key=qdrant_api_key,
        )
    else:
        logger.info('Collection "%s" does not exist. Creating new collection...', collection_name)

        # Read the corpus before touching Qdrant so a missing or unreadable
        # data file cannot leave an empty collection behind.
        base_dir = Path(__file__).resolve().parent.parent
        doc_path = base_dir / 'data' / 'processed_data' / 'criminal_code_of_vietnam.pkl'
        doc_list = _load_documents(doc_path)

        client.create_collection(
            collection_name=collection_name,
            vectors_config=models.VectorParams(
                size=vector_dim,
                distance=models.Distance.COSINE,
            ),
        )
        db = QdrantVectorStore.from_documents(
            documents=doc_list,
            embedding=embeddings,
            url=qdrant_url,
            prefer_grpc=False,
            collection_name=collection_name,
            retrieval_mode=RetrievalMode.DENSE,
            api_key=qdrant_api_key,
        )
        logger.info('Qdrant Index created.')

    # Metadata fields for which a keyword payload index is requested.
    fields_to_index = {
        'metadata.article': "keyword",
        'metadata.chapter': "keyword",
        'metadata.id': "keyword",
        'metadata.source': "keyword",
        'metadata.title': "keyword",
    }
    for field, schema in fields_to_index.items():
        client.create_payload_index(
            collection_name=collection_name,
            field_name=field,
            field_schema=schema,
        )

    return db
100
+
101
+
102
+
rag/__pycache__/rag_production.cpython-312.pyc ADDED
Binary file (3.02 kB). View file
 
rag/__pycache__/rag_production.cpython-313.pyc ADDED
Binary file (2.99 kB). View file
 
rag/rag_production.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.chains.combine_documents import create_stuff_documents_chain
2
+ from langchain.chains.retrieval import create_retrieval_chain
3
+ from langchain_groq import ChatGroq
4
+ from langchain_core.prompts import ChatPromptTemplate
5
+ import os
6
+ from langchain_core.prompts import PromptTemplate
7
+ import logging
8
+ from embeddings.embedder import get_vectorstore
9
+
10
+
11
# Log line layout: timestamp, logger name, level, message.
# The level field must be spelled ``levelname`` (all lowercase): that is the
# LogRecord attribute name. The previous ``levelName`` spelling is not a
# LogRecord attribute, so a handler using this format cannot render records.
LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'

logging.basicConfig(
    level=logging.INFO,
    format=LOG_FORMAT,
)

# Module-scoped logger for this file.
logger = logging.getLogger(__name__)
16
+
17
+
18
# System instructions for the legal assistant. The text tells the model to
# open with an "According to Article ..., Chapter ..., <title>:" citation,
# summarise the relevant point, then answer, or to reply with a fixed
# "cannot be determined" sentence when the context is insufficient.
# The template has two placeholders: {input} (the user's question) and
# {context} (the retrieved documents).
rag_prompt_template = '''
You are a legal assistant trained to answer questions using legal documents.

If the answer cannot be determined from the available legal text, you must answer without include the
According to Article <article>, Chapter <chapter>, <title> phrase, for example:
> The answer cannot be determined from the available legal text.

Otherwise, you must STRICTLY follow this 3-step structure:

1. **Begin your response ONLY with this format** (fill in the values from metadata):
> According to Article <article>, Chapter <chapter>, <title>:


2. **Then, extract and summarize the most relevant point** from the provided context.

3. **Finally, answer the user’s question clearly and formally**, referring only to the point above.

---
💥 IMPORTANT RULES:
- Do NOT repeat or paraphrase the law reference later.
- Do NOT invent any legal information — use ONLY the provided context and metadata.
- Do NOT add phrase like 'The most relevant point is that' when mentioning the context.
- If the context is insufficient, respond with:
> "The answer cannot be determined from the available legal text."
---

**User Question:**
{input}

**Retrieved Legal Context:**
{context}
'''

# Per-document layout used when retrieved documents are rendered into
# {context}: article, chapter and title followed by the document text.
# Presumably {article}, {chapter} and {title} are filled from each document's
# metadata; verify against how the corpus was built.
document_prompt = PromptTemplate.from_template(
    '''
Article: {article}
Chapter: {chapter}
Title: {title}
Content: {page_content}
'''
)

# Chat prompt sent to the model: the system instructions above plus a user
# turn carrying the context and the question.
# NOTE(review): {context} and {input} are interpolated both in the system
# template and in the user message, so the model receives each of them twice.
# Confirm this duplication is intentional.
prompt = ChatPromptTemplate.from_messages([
    ('system', rag_prompt_template),
    ('user', "Context:\n{context}\n\nQuestion:\n{input}\n\nAnswer:")
])
64
+
65
# Process-wide vector store, created on first use. Building it loads the
# embedding model and talks to Qdrant, so it is done once and then reused
# instead of being repeated for every chain that is requested.
_vectorstore = None


def _get_cached_vectorstore():
    """Return the shared vector store, building it on the first call only."""
    global _vectorstore
    if _vectorstore is None:
        _vectorstore = get_vectorstore()
    return _vectorstore


def get_rag_chain(model_name='llama-3.3-70b-versatile', k=1):
    """Build a retrieval-augmented chain for the given Groq chat model.

    The chain retrieves the ``k`` most similar documents from the vector
    store, renders them with ``document_prompt``, and passes them together
    with the question to the chat model through ``prompt``.

    Args:
        model_name: Name of the Groq-hosted chat model to use.
        k: Number of documents the retriever returns per query.

    Returns:
        The runnable produced by ``create_retrieval_chain``. It is invoked
        with a dict containing an ``'input'`` key.
    """
    db = _get_cached_vectorstore()

    groq_api_key = os.getenv('GROQ_API_KEY')

    llm = ChatGroq(
        model_name=model_name,
        temperature=0,  # deterministic decoding
        max_tokens=10000,
        api_key=groq_api_key,
    )

    retriever = db.as_retriever(
        search_type='similarity',
        search_kwargs={'k': k},
    )

    combine_doc_chain = create_stuff_documents_chain(
        prompt=prompt,
        llm=llm,
        document_prompt=document_prompt,
    )
    return create_retrieval_chain(retriever, combine_doc_chain)
88
+
89
+
90
+
91
+
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio>=4.38.0
2
+ fastapi>=0.111.0
3
+ uvicorn>=0.28.0
4
+ langchain
5
+ langchain-community
6
+ langchain-qdrant
7
+ huggingface_hub
8
+ qdrant-client
9
+ sentence-transformers
10
+ transformers
11
+ torch
12
+ langchain_groq
13
+ langchain_huggingface