mtyrrell committed on
Commit
4ccc04f
·
1 Parent(s): 869e944

init test

Browse files
Files changed (8) hide show
  1. .gitignore +2 -0
  2. Dockerfile +23 -0
  3. README.md +1 -1
  4. app/main.py +84 -0
  5. app/retriever.py +174 -0
  6. app/utils.py +16 -0
  7. params.cfg +14 -0
  8. requirements.txt +5 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .env
2
+ *.DS_Store
Dockerfile ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# -------- base image --------
FROM python:3.11-slim

ENV PYTHONUNBUFFERED=1 \
    OMP_NUM_THREADS=1 \
    TOKENIZERS_PARALLELISM=false
#GRADIO_MCP_SERVER=True

# -------- install deps --------
# requirements.txt is copied on its own first so the dependency layer is
# reused from cache when only application source changes.
WORKDIR /app
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# -------- copy source --------
COPY app ./app
COPY params.cfg .
# NOTE: .env files are deliberately NOT copied into the image: anything copied
# here is stored in an image layer and readable by whoever can pull the image.
# Provide secrets as runtime environment variables instead
# (e.g. `docker run --env-file .env ...` or the hosting platform's secrets).

# Ports:
# • 7860 → Gradio UI (HF Spaces standard)
EXPOSE 7860

# Run as a module so the package-relative imports in app/main.py resolve.
CMD ["python", "-m", "app.main"]
README.md CHANGED
@@ -7,4 +7,4 @@ sdk: docker
7
  pinned: false
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
7
  pinned: false
8
  ---
9
 
10
+
app/main.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from .retriever import retrieve_context
3
+
4
+ # ---------------------------------------------------------------------
5
+ # Gradio Interface with MCP support
6
+ # ---------------------------------------------------------------------
7
def retriever_interface(
    query: str,
    reports_filter: str = "",
    sources_filter: str = "",
    subtype_filter: str = "",
    year_filter: str = "",
) -> str:
    """
    Retrieve context for a query and render the hits as plain text.

    Wrapper around ``retrieve_context`` for the Gradio interface: the optional
    filters arrive as free-text fields and are converted here into the lists /
    single values that ``retrieve_context`` takes.

    Args:
        query: The query to search for in the vector database.
        reports_filter: Comma-separated list of report filenames to search
            within. Leave empty to search all reports.
        sources_filter: Document source type to filter by. Leave empty for all.
        subtype_filter: Document subtype to filter by. Leave empty for all.
        year_filter: Comma-separated list of years to filter by. Leave empty
            for all.

    Returns:
        One "=== Result N ===" section per retrieved document (content followed
        by its metadata as "key: value" pairs), joined by newlines. The string
        is empty when nothing was retrieved.
    """
    # Parse filter inputs (convert empty strings to None or lists)
    reports = [r.strip() for r in reports_filter.split(",") if r.strip()] if reports_filter else []
    # `or None` also maps whitespace-only input to "no filter".
    sources = (sources_filter or "").strip() or None
    subtype = (subtype_filter or "").strip() or None
    year = [y.strip() for y in year_filter.split(",") if y.strip()] if year_filter else None

    # Call retriever function
    results = retrieve_context(
        query=query,
        reports=reports,
        sources=sources,
        subtype=subtype,
        year=year,
    )

    # Format results for display
    formatted_results = []
    for i, doc in enumerate(results, 1):
        # Tolerate a missing or explicitly-None metadata entry.
        metadata = doc.get("metadata") or {}
        metadata_str = ", ".join(f"{k}: {v}" for k, v in metadata.items())
        formatted_results.append(
            f"=== Result {i} ===\nContent: {doc['page_content']}\nMetadata: {metadata_str}\n"
        )

    return "\n".join(formatted_results)
33
+
34
def _filter_box(label, placeholder, info):
    """Build one of the single-line, optional filter textboxes."""
    return gr.Textbox(label=label, lines=1, placeholder=placeholder, info=info)


# Interface definition. The inputs are listed in the same order as the
# positional parameters of retriever_interface.
ui = gr.Interface(
    fn=retriever_interface,
    inputs=[
        gr.Textbox(
            label="Query",
            lines=2,
            placeholder="Enter your search query here",
            info="The query to search for in the vector database",
        ),
        _filter_box(
            "Reports Filter (optional)",
            "report1.pdf, report2.pdf",
            "Comma-separated list of specific report filenames to search within (leave empty for all)",
        ),
        _filter_box(
            "Sources Filter (optional)",
            "annual_report",
            "Filter by document source type (leave empty for all)",
        ),
        _filter_box(
            "Subtype Filter (optional)",
            "financial",
            "Filter by document subtype (leave empty for all)",
        ),
        _filter_box(
            "Year Filter (optional)",
            "2023, 2024",
            "Comma-separated list of years to filter by (leave empty for all)",
        ),
    ],
    outputs=gr.Textbox(
        label="Retrieved Context",
        lines=10,
        show_copy_button=True,
    ),
    title="RAG Retrieval Service UI",
    description="Retrieves semantically similar documents from vector database. Intended for use in RAG pipelines as an MCP server.",
)
76
+
77
# Launch with MCP server enabled
if __name__ == "__main__":
    import os

    # Listen on 7860 by default: that is the port the Dockerfile exposes and
    # documents as the Gradio UI port. Set GRADIO_SERVER_PORT to pick another
    # port, e.g. when running locally next to another service on 7860.
    ui.launch(
        server_name="0.0.0.0",
        server_port=int(os.environ.get("GRADIO_SERVER_PORT", "7860")),
        mcp_server=True,
        show_error=True,
    )
app/retriever.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict, Any, Optional
2
+ from qdrant_client.http import models as rest
3
+ from langchain.schema import Document
4
+ from .utils import getconfig
5
+ import logging
6
+
7
# Parse params.cfg once, at import time
config = getconfig("params.cfg")

# Typed retriever settings read from the [retriever] section
RETRIEVER_TOP_K = config.getint("retriever", "TOP_K")
SCORE_THRESHOLD = config.getfloat("retriever", "SCORE_THRESHOLD")
13
+
14
def create_filter(
    reports: List[str] = None,
    sources: str = None,
    subtype: str = None,
    year: List[str] = None
) -> Optional[rest.Filter]:
    """
    Build the Qdrant metadata filter for a retrieval request.

    A non-empty ``reports`` list takes precedence: the filter then matches on
    filename only and ``sources``, ``subtype`` and ``year`` are ignored.
    Otherwise every criterion that was supplied must match.

    Args:
        reports: List of specific report filenames to filter by
        sources: Source type to filter by
        subtype: Document subtype to filter by
        year: List of years to filter by

    Returns:
        Qdrant Filter object or None if no filters specified
    """
    # Report selection short-circuits all other criteria.
    if reports:
        logging.info(f"Defining filter for reports: {reports}")
        by_filename = rest.FieldCondition(
            key="metadata.filename",
            match=rest.MatchAny(any=reports),
        )
        return rest.Filter(must=[by_filename])

    must = []

    if sources:
        logging.info(f"Defining filter for sources: {sources}")
        must.append(
            rest.FieldCondition(
                key="metadata.source",
                match=rest.MatchValue(value=sources),
            )
        )

    if subtype:
        logging.info(f"Defining filter for subtype: {subtype}")
        must.append(
            rest.FieldCondition(
                key="metadata.subtype",
                match=rest.MatchValue(value=subtype),
            )
        )

    if year:
        logging.info(f"Defining filter for years: {year}")
        must.append(
            rest.FieldCondition(
                key="metadata.year",
                match=rest.MatchAny(any=year),
            )
        )

    # No criteria supplied -> no filter at all.
    return rest.Filter(must=must) if must else None
76
+
77
def get_vectorstore():
    """
    Return the vectorstore connection used for retrieval.

    Placeholder: no backend is wired up yet, so calling this always raises.
    Replace the body with the connection code for your vector database.

    Returns:
        Vectorstore instance (e.g., Qdrant, Pinecone, etc.)

    Raises:
        NotImplementedError: always, until a backend is implemented.
    """
    # TODO: Implement based on your external vector database
    # Sketch for Qdrant:
    #   from langchain_community.vectorstores import Qdrant
    #   from qdrant_client import QdrantClient
    #
    #   client = QdrantClient(
    #       host=config.get("vectorstore", "HOST"),
    #       port=config.get("vectorstore", "PORT"),
    #       api_key=config.get("vectorstore", "API_KEY", fallback=None)
    #   )
    #
    #   vectorstore = Qdrant(
    #       client=client,
    #       collection_name=config.get("vectorstore", "COLLECTION_NAME"),
    #       embeddings=your_embedding_model  # You'll need to configure this
    #   )
    #
    #   return vectorstore
    message = "Please implement vectorstore connection based on your setup"
    raise NotImplementedError(message)
105
+
106
def retrieve_context(
    query: str,
    reports: Optional[List[str]] = None,
    sources: Optional[str] = None,
    subtype: Optional[str] = None,
    year: Optional[List[str]] = None,
    top_k: Optional[int] = None
) -> List[Dict[str, Any]]:
    """
    Retrieve semantically similar documents from the vector database.

    Args:
        query: The search query
        reports: List of specific report filenames to search within
        sources: Source type to filter by
        subtype: Document subtype to filter by
        year: List of years to filter by
        top_k: Number of results to return (defaults to config value)

    Returns:
        List of dictionaries with 'page_content' and 'metadata' keys

    Raises:
        Exception: any error raised while connecting or searching is logged
            with its traceback and then re-raised unchanged.
    """
    try:
        # Get vectorstore instance
        vectorstore = get_vectorstore()

        # Create metadata filter
        filter_obj = create_filter(
            reports=reports or [],
            sources=sources,
            subtype=subtype,
            year=year or []
        )

        # Set up search parameters; a missing (or zero) top_k falls back to
        # the configured default.
        search_kwargs = {
            "score_threshold": SCORE_THRESHOLD,
            "k": top_k or RETRIEVER_TOP_K
        }

        # Only pass a filter when one was actually built.
        if filter_obj is not None:
            search_kwargs["filter"] = filter_obj

        # Create retriever
        retriever = vectorstore.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs=search_kwargs
        )

        # Perform retrieval
        retrieved_docs: List[Document] = retriever.invoke(query)

        logging.info(f"Retrieved {len(retrieved_docs)} documents for query: {query[:50]}...")

        # Convert to dictionary format
        return [
            {
                "page_content": doc.page_content,
                "metadata": doc.metadata
            }
            for doc in retrieved_docs
        ]

    except Exception as e:
        # logging.exception records the traceback alongside the message;
        # the bare raise propagates the original exception untouched.
        logging.exception("Error during retrieval: %s", e)
        raise
app/utils.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import configparser
2
+ import logging
3
+
4
def getconfig(configfile_path: str):
    """
    Read the config file

    Params
    ----------------
    configfile_path: file path of .cfg file

    Returns
    ----------------
    The populated ConfigParser, or None when the file cannot be opened or
    parsed (a warning is logged in that case).
    """
    config = configparser.ConfigParser()
    try:
        # The context manager closes the file handle even if parsing fails.
        with open(configfile_path, encoding="utf-8") as configfile:
            config.read_file(configfile)
    except OSError as err:
        # Missing file, permission problem, path is a directory, ...
        logging.warning("config file not found: %s (%s)", configfile_path, err)
        return None
    except (configparser.Error, UnicodeDecodeError) as err:
        # The file exists but is not a valid config file.
        logging.warning("config file could not be parsed: %s (%s)", configfile_path, err)
        return None
    return config
params.cfg ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [retriever]
2
+ TOP_K = 10
3
+ SCORE_THRESHOLD = 0.6
4
+
5
+ [vectorstore]
6
+ TYPE = qdrant
7
+ HOST = localhost
8
+ PORT = 6333
9
+ COLLECTION_NAME = auditqa
10
+ # API_KEY = your_api_key_if_needed
11
+
12
+ [embeddings]
13
+ MODEL_NAME = BAAI/bge-m3
14
+ # DEVICE = cpu
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ gradio
2
+ langchain
3
+ langchain-community
4
+ qdrant-client
5
+ sentence-transformers