Spaces:
Running
Running
Upload 18 files
Browse files- advanced_rag.py +410 -0
- agent_chat_stream.py +111 -0
- agent_service.py +503 -0
- app.py +31 -404
- batch_index_pdfs.py +151 -0
- cag_service.py +229 -0
- conversation_service.py +308 -0
- embedding_service.py +173 -0
- feedback_tracking_service.py +103 -0
- main.py +1326 -0
- multimodal_pdf_parser.py +390 -0
- pdf_parser.py +371 -0
- prompts/feedback_agent.txt +51 -0
- prompts/sales_agent.txt +47 -0
- qdrant_service.py +446 -0
- requirements.txt +38 -0
- stream_utils.py +86 -0
- tools_service.py +242 -0
advanced_rag.py
ADDED
|
@@ -0,0 +1,410 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Advanced RAG techniques for improved retrieval and generation (Best Case 2025)
|
| 3 |
+
Includes: LLM-Based Query Expansion, Cross-Encoder Reranking, Contextual Compression, Hybrid Search
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from typing import List, Dict, Optional, Tuple
|
| 7 |
+
import numpy as np
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
import re
|
| 10 |
+
from sentence_transformers import CrossEncoder
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclass
|
| 14 |
+
class RetrievedDocument:
|
| 15 |
+
"""Document retrieved from vector database"""
|
| 16 |
+
id: str
|
| 17 |
+
text: str
|
| 18 |
+
confidence: float
|
| 19 |
+
metadata: Dict
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class AdvancedRAG:
|
| 23 |
+
"""Advanced RAG system with 2025 best practices"""
|
| 24 |
+
|
| 25 |
+
def __init__(self, embedding_service, qdrant_service):
|
| 26 |
+
self.embedding_service = embedding_service
|
| 27 |
+
self.qdrant_service = qdrant_service
|
| 28 |
+
|
| 29 |
+
# Initialize Cross-Encoder for reranking (multilingual for Vietnamese support)
|
| 30 |
+
print("Loading Cross-Encoder model for reranking...")
|
| 31 |
+
# Use multilingual model instead of English-only ms-marco
|
| 32 |
+
self.cross_encoder = CrossEncoder('cross-encoder/mmarco-mMiniLMv2-L12-H384-v1')
|
| 33 |
+
print("✓ Cross-Encoder loaded (multilingual)")
|
| 34 |
+
|
| 35 |
+
def expand_query_llm(
|
| 36 |
+
self,
|
| 37 |
+
query: str,
|
| 38 |
+
hf_client=None
|
| 39 |
+
) -> List[str]:
|
| 40 |
+
"""
|
| 41 |
+
Expand query using LLM (Best Case 2025)
|
| 42 |
+
Generates query variations, sub-queries, and hypothetical answers
|
| 43 |
+
|
| 44 |
+
Args:
|
| 45 |
+
query: Original user query
|
| 46 |
+
hf_client: HuggingFace InferenceClient (optional)
|
| 47 |
+
|
| 48 |
+
Returns:
|
| 49 |
+
List of expanded queries
|
| 50 |
+
"""
|
| 51 |
+
queries = [query]
|
| 52 |
+
|
| 53 |
+
# Fallback to rule-based if no LLM client
|
| 54 |
+
if not hf_client:
|
| 55 |
+
return self._expand_query_rule_based(query)
|
| 56 |
+
|
| 57 |
+
try:
|
| 58 |
+
# LLM-based expansion prompt
|
| 59 |
+
expansion_prompt = f"""Given this user question, generate 2-3 alternative phrasings or sub-questions that would help retrieve relevant information.
|
| 60 |
+
|
| 61 |
+
User Question: {query}
|
| 62 |
+
|
| 63 |
+
Alternative queries (one per line):"""
|
| 64 |
+
|
| 65 |
+
# Generate expansions
|
| 66 |
+
response = ""
|
| 67 |
+
for msg in hf_client.chat_completion(
|
| 68 |
+
messages=[{"role": "user", "content": expansion_prompt}],
|
| 69 |
+
max_tokens=256,
|
| 70 |
+
stream=True,
|
| 71 |
+
temperature=0.7,
|
| 72 |
+
model="openai/gpt-oss-20b"
|
| 73 |
+
):
|
| 74 |
+
if msg.choices and msg.choices[0].delta.content:
|
| 75 |
+
response += msg.choices[0].delta.content
|
| 76 |
+
|
| 77 |
+
# Parse expansions
|
| 78 |
+
lines = [line.strip() for line in response.split('\n') if line.strip()]
|
| 79 |
+
# Filter out numbered lists, dashes, etc.
|
| 80 |
+
clean_lines = []
|
| 81 |
+
for line in lines:
|
| 82 |
+
# Remove common list markers
|
| 83 |
+
cleaned = re.sub(r'^[\d\-\*\•]+[\.\)]\s*', '', line)
|
| 84 |
+
if cleaned and len(cleaned) > 5:
|
| 85 |
+
clean_lines.append(cleaned)
|
| 86 |
+
|
| 87 |
+
queries.extend(clean_lines[:3]) # Add top 3 expansions
|
| 88 |
+
|
| 89 |
+
except Exception as e:
|
| 90 |
+
print(f"LLM expansion failed, using rule-based: {e}")
|
| 91 |
+
return self._expand_query_rule_based(query)
|
| 92 |
+
|
| 93 |
+
return queries[:4] # Original + 3 expansions
|
| 94 |
+
|
| 95 |
+
def _expand_query_rule_based(self, query: str) -> List[str]:
|
| 96 |
+
"""
|
| 97 |
+
Fallback rule-based query expansion
|
| 98 |
+
Simple but effective Vietnamese-aware expansion
|
| 99 |
+
"""
|
| 100 |
+
queries = [query]
|
| 101 |
+
|
| 102 |
+
# Vietnamese question words
|
| 103 |
+
question_words = ['ai', 'gì', 'nào', 'đâu', 'khi nào', 'như thế nào',
|
| 104 |
+
'sao', 'tại sao', 'có', 'là', 'được', 'không', 'làm sao']
|
| 105 |
+
|
| 106 |
+
query_lower = query.lower()
|
| 107 |
+
for qw in question_words:
|
| 108 |
+
if qw in query_lower:
|
| 109 |
+
variant = query_lower.replace(qw, '').strip()
|
| 110 |
+
if variant and variant != query_lower:
|
| 111 |
+
queries.append(variant)
|
| 112 |
+
break # One variation is enough
|
| 113 |
+
|
| 114 |
+
# Extract key phrases
|
| 115 |
+
words = query.split()
|
| 116 |
+
if len(words) > 3:
|
| 117 |
+
key_phrases = ' '.join(words[1:]) if words[0].lower() in question_words else ' '.join(words[:3])
|
| 118 |
+
if key_phrases not in queries:
|
| 119 |
+
queries.append(key_phrases)
|
| 120 |
+
|
| 121 |
+
return queries[:3]
|
| 122 |
+
|
| 123 |
+
def multi_query_retrieval(
|
| 124 |
+
self,
|
| 125 |
+
query: str,
|
| 126 |
+
top_k: int = 5,
|
| 127 |
+
score_threshold: float = 0.5,
|
| 128 |
+
expanded_queries: Optional[List[str]] = None
|
| 129 |
+
) -> List[RetrievedDocument]:
|
| 130 |
+
"""
|
| 131 |
+
Retrieve documents using multiple query variations
|
| 132 |
+
Combines results from all query variations with deduplication
|
| 133 |
+
"""
|
| 134 |
+
if expanded_queries is None:
|
| 135 |
+
expanded_queries = [query]
|
| 136 |
+
|
| 137 |
+
all_results = {} # Deduplicate by doc_id
|
| 138 |
+
|
| 139 |
+
for q in expanded_queries:
|
| 140 |
+
# Generate embedding for each query variant
|
| 141 |
+
query_embedding = self.embedding_service.encode_text(q)
|
| 142 |
+
|
| 143 |
+
# Search in Qdrant
|
| 144 |
+
results = self.qdrant_service.search(
|
| 145 |
+
query_embedding=query_embedding,
|
| 146 |
+
limit=top_k,
|
| 147 |
+
score_threshold=score_threshold
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
# Add to results (keep highest score for duplicates)
|
| 151 |
+
for result in results:
|
| 152 |
+
doc_id = result["id"]
|
| 153 |
+
if doc_id not in all_results or result["confidence"] > all_results[doc_id].confidence:
|
| 154 |
+
# Lấy text từ metadata - hỗ trợ cả "text" (string) và "texts" (array)
|
| 155 |
+
metadata = result["metadata"]
|
| 156 |
+
doc_text = metadata.get("text", "")
|
| 157 |
+
if not doc_text and "texts" in metadata:
|
| 158 |
+
# Nếu là array, join thành string
|
| 159 |
+
texts_arr = metadata.get("texts", [])
|
| 160 |
+
if isinstance(texts_arr, list):
|
| 161 |
+
doc_text = "\n".join(texts_arr)
|
| 162 |
+
else:
|
| 163 |
+
doc_text = str(texts_arr)
|
| 164 |
+
|
| 165 |
+
all_results[doc_id] = RetrievedDocument(
|
| 166 |
+
id=doc_id,
|
| 167 |
+
text=doc_text,
|
| 168 |
+
confidence=result["confidence"],
|
| 169 |
+
metadata=metadata
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
# Sort by confidence and return top_k
|
| 173 |
+
sorted_results = sorted(all_results.values(), key=lambda x: x.confidence, reverse=True)
|
| 174 |
+
return sorted_results[:top_k * 2] # Return more for reranking
|
| 175 |
+
|
| 176 |
+
def rerank_documents_cross_encoder(
|
| 177 |
+
self,
|
| 178 |
+
query: str,
|
| 179 |
+
documents: List[RetrievedDocument],
|
| 180 |
+
top_k: int = 5
|
| 181 |
+
) -> List[RetrievedDocument]:
|
| 182 |
+
"""
|
| 183 |
+
Rerank documents using Cross-Encoder (Best Case 2025)
|
| 184 |
+
Cross-Encoder provides superior relevance scoring compared to bi-encoders
|
| 185 |
+
|
| 186 |
+
Args:
|
| 187 |
+
query: Original user query
|
| 188 |
+
documents: Retrieved documents to rerank
|
| 189 |
+
top_k: Number of top documents to return
|
| 190 |
+
|
| 191 |
+
Returns:
|
| 192 |
+
Reranked documents
|
| 193 |
+
"""
|
| 194 |
+
if not documents:
|
| 195 |
+
return documents
|
| 196 |
+
|
| 197 |
+
# Prepare query-document pairs for Cross-Encoder
|
| 198 |
+
pairs = [[query, doc.text] for doc in documents]
|
| 199 |
+
|
| 200 |
+
# Get Cross-Encoder scores (raw logits)
|
| 201 |
+
ce_scores = self.cross_encoder.predict(pairs)
|
| 202 |
+
ce_scores = [float(s) for s in ce_scores]
|
| 203 |
+
|
| 204 |
+
# Min-Max normalization để scale về 0-1
|
| 205 |
+
# Thay vì sigmoid (cho điểm rất thấp với logits âm)
|
| 206 |
+
min_score = min(ce_scores)
|
| 207 |
+
max_score = max(ce_scores)
|
| 208 |
+
|
| 209 |
+
if max_score - min_score > 0.001: # Có sự khác biệt giữa các scores
|
| 210 |
+
ce_scores_normalized = [
|
| 211 |
+
(score - min_score) / (max_score - min_score)
|
| 212 |
+
for score in ce_scores
|
| 213 |
+
]
|
| 214 |
+
else:
|
| 215 |
+
# Tất cả scores gần như bằng nhau -> giữ original confidence
|
| 216 |
+
ce_scores_normalized = [doc.confidence for doc in documents]
|
| 217 |
+
|
| 218 |
+
# Combine: 70% Cross-Encoder ranking + 30% original cosine similarity
|
| 219 |
+
# Để giữ lại một phần semantic similarity từ embedding
|
| 220 |
+
reranked = []
|
| 221 |
+
for doc, ce_norm in zip(documents, ce_scores_normalized):
|
| 222 |
+
combined_score = 0.7 * ce_norm + 0.3 * doc.confidence
|
| 223 |
+
reranked.append(RetrievedDocument(
|
| 224 |
+
id=doc.id,
|
| 225 |
+
text=doc.text,
|
| 226 |
+
confidence=float(combined_score),
|
| 227 |
+
metadata=doc.metadata
|
| 228 |
+
))
|
| 229 |
+
|
| 230 |
+
# Sort by combined score
|
| 231 |
+
reranked.sort(key=lambda x: x.confidence, reverse=True)
|
| 232 |
+
return reranked[:top_k]
|
| 233 |
+
|
| 234 |
+
def compress_context(
|
| 235 |
+
self,
|
| 236 |
+
query: str,
|
| 237 |
+
documents: List[RetrievedDocument],
|
| 238 |
+
max_tokens: int = 500
|
| 239 |
+
) -> List[RetrievedDocument]:
|
| 240 |
+
"""
|
| 241 |
+
Compress context - giữ nguyên nội dung quan trọng, chỉ truncate nếu quá dài
|
| 242 |
+
KHÔNG dùng word overlap vì nó loại bỏ sai thông tin quan trọng
|
| 243 |
+
"""
|
| 244 |
+
compressed_docs = []
|
| 245 |
+
|
| 246 |
+
for doc in documents:
|
| 247 |
+
text = doc.text.strip()
|
| 248 |
+
|
| 249 |
+
# Chỉ truncate nếu text quá dài (ước tính ~4 chars/token)
|
| 250 |
+
max_chars = max_tokens * 4
|
| 251 |
+
if len(text) > max_chars:
|
| 252 |
+
# Cắt thông minh tại câu gần nhất
|
| 253 |
+
truncated = text[:max_chars]
|
| 254 |
+
last_period = max(
|
| 255 |
+
truncated.rfind('.'),
|
| 256 |
+
truncated.rfind('!'),
|
| 257 |
+
truncated.rfind('?'),
|
| 258 |
+
truncated.rfind('\n')
|
| 259 |
+
)
|
| 260 |
+
if last_period > max_chars * 0.5: # Nếu tìm thấy dấu câu ở nửa sau
|
| 261 |
+
truncated = truncated[:last_period + 1]
|
| 262 |
+
text = truncated.strip()
|
| 263 |
+
|
| 264 |
+
compressed_docs.append(RetrievedDocument(
|
| 265 |
+
id=doc.id,
|
| 266 |
+
text=text,
|
| 267 |
+
confidence=doc.confidence,
|
| 268 |
+
metadata=doc.metadata
|
| 269 |
+
))
|
| 270 |
+
|
| 271 |
+
return compressed_docs
|
| 272 |
+
|
| 273 |
+
def _split_sentences(self, text: str) -> List[str]:
|
| 274 |
+
"""Split text into sentences (Vietnamese-aware)"""
|
| 275 |
+
sentences = re.split(r'[.!?]+', text)
|
| 276 |
+
return [s.strip() for s in sentences if s.strip()]
|
| 277 |
+
|
| 278 |
+
def hybrid_rag_pipeline(
|
| 279 |
+
self,
|
| 280 |
+
query: str,
|
| 281 |
+
top_k: int = 5,
|
| 282 |
+
score_threshold: float = 0.5,
|
| 283 |
+
use_reranking: bool = True,
|
| 284 |
+
use_compression: bool = True,
|
| 285 |
+
use_query_expansion: bool = True,
|
| 286 |
+
max_context_tokens: int = 500,
|
| 287 |
+
hf_client=None
|
| 288 |
+
) -> Tuple[List[RetrievedDocument], Dict]:
|
| 289 |
+
"""
|
| 290 |
+
Complete advanced RAG pipeline (Best Case 2025)
|
| 291 |
+
1. LLM-based query expansion
|
| 292 |
+
2. Multi-query retrieval
|
| 293 |
+
3. Cross-Encoder reranking
|
| 294 |
+
4. Contextual compression
|
| 295 |
+
|
| 296 |
+
Args:
|
| 297 |
+
query: User query
|
| 298 |
+
top_k: Number of documents to return
|
| 299 |
+
score_threshold: Minimum relevance score
|
| 300 |
+
use_reranking: Enable Cross-Encoder reranking
|
| 301 |
+
use_compression: Enable context compression
|
| 302 |
+
use_query_expansion: Enable LLM-based query expansion
|
| 303 |
+
max_context_tokens: Max tokens for compression
|
| 304 |
+
hf_client: HuggingFace InferenceClient for expansion
|
| 305 |
+
|
| 306 |
+
Returns:
|
| 307 |
+
(documents, stats)
|
| 308 |
+
"""
|
| 309 |
+
stats = {
|
| 310 |
+
"original_query": query,
|
| 311 |
+
"expanded_queries": [],
|
| 312 |
+
"initial_results": 0,
|
| 313 |
+
"after_rerank": 0,
|
| 314 |
+
"after_compression": 0,
|
| 315 |
+
"used_cross_encoder": use_reranking,
|
| 316 |
+
"used_llm_expansion": use_query_expansion and hf_client is not None
|
| 317 |
+
}
|
| 318 |
+
|
| 319 |
+
# Step 1: Query Expansion (LLM-based or rule-based)
|
| 320 |
+
if use_query_expansion:
|
| 321 |
+
expanded_queries = self.expand_query_llm(query, hf_client)
|
| 322 |
+
else:
|
| 323 |
+
expanded_queries = [query]
|
| 324 |
+
|
| 325 |
+
stats["expanded_queries"] = expanded_queries
|
| 326 |
+
|
| 327 |
+
# Step 2: Multi-query retrieval
|
| 328 |
+
documents = self.multi_query_retrieval(
|
| 329 |
+
query=query,
|
| 330 |
+
top_k=top_k * 2, # Get more candidates for reranking
|
| 331 |
+
score_threshold=score_threshold,
|
| 332 |
+
expanded_queries=expanded_queries
|
| 333 |
+
)
|
| 334 |
+
stats["initial_results"] = len(documents)
|
| 335 |
+
|
| 336 |
+
# Step 3: Cross-Encoder Reranking (Best Case 2025)
|
| 337 |
+
if use_reranking and documents:
|
| 338 |
+
documents = self.rerank_documents_cross_encoder(
|
| 339 |
+
query=query,
|
| 340 |
+
documents=documents,
|
| 341 |
+
top_k=top_k
|
| 342 |
+
)
|
| 343 |
+
else:
|
| 344 |
+
documents = documents[:top_k]
|
| 345 |
+
stats["after_rerank"] = len(documents)
|
| 346 |
+
|
| 347 |
+
# Step 4: Contextual compression (optional)
|
| 348 |
+
if use_compression and documents:
|
| 349 |
+
documents = self.compress_context(
|
| 350 |
+
query=query,
|
| 351 |
+
documents=documents,
|
| 352 |
+
max_tokens=max_context_tokens
|
| 353 |
+
)
|
| 354 |
+
stats["after_compression"] = len(documents)
|
| 355 |
+
|
| 356 |
+
return documents, stats
|
| 357 |
+
|
| 358 |
+
def format_context_for_llm(
|
| 359 |
+
self,
|
| 360 |
+
documents: List[RetrievedDocument],
|
| 361 |
+
include_metadata: bool = True
|
| 362 |
+
) -> str:
|
| 363 |
+
"""
|
| 364 |
+
Format retrieved documents into context string for LLM
|
| 365 |
+
Uses better structure for improved LLM understanding
|
| 366 |
+
"""
|
| 367 |
+
if not documents:
|
| 368 |
+
return ""
|
| 369 |
+
|
| 370 |
+
context_parts = ["RELEVANT CONTEXT:\n"]
|
| 371 |
+
|
| 372 |
+
for i, doc in enumerate(documents, 1):
|
| 373 |
+
context_parts.append(f"\n--- Document {i} (Relevance: {doc.confidence:.2%}) ---")
|
| 374 |
+
context_parts.append(doc.text)
|
| 375 |
+
|
| 376 |
+
if include_metadata and doc.metadata:
|
| 377 |
+
# Add useful metadata
|
| 378 |
+
meta_str = []
|
| 379 |
+
for key, value in doc.metadata.items():
|
| 380 |
+
if key not in ['text', 'texts'] and value:
|
| 381 |
+
meta_str.append(f"{key}: {value}")
|
| 382 |
+
if meta_str:
|
| 383 |
+
context_parts.append(f"[Metadata: {', '.join(meta_str)}]")
|
| 384 |
+
|
| 385 |
+
context_parts.append("\n--- End of Context ---\n")
|
| 386 |
+
return "\n".join(context_parts)
|
| 387 |
+
|
| 388 |
+
def build_rag_prompt(
|
| 389 |
+
self,
|
| 390 |
+
query: str,
|
| 391 |
+
context: str,
|
| 392 |
+
system_message: str = "You are a helpful AI assistant."
|
| 393 |
+
) -> str:
|
| 394 |
+
"""
|
| 395 |
+
Build optimized RAG system prompt for LLM
|
| 396 |
+
Query sẽ được gửi riêng trong user message
|
| 397 |
+
"""
|
| 398 |
+
prompt_template = f"""{system_message}
|
| 399 |
+
|
| 400 |
+
{context}
|
| 401 |
+
|
| 402 |
+
HƯỚNG DẪN TRẢ LỜI:
|
| 403 |
+
1. Đóng vai trò là một trợ lý ảo thân thiện, trả lời tự nhiên bằng tiếng Việt.
|
| 404 |
+
2. Dựa vào CONTEXT được cung cấp để trả lời câu hỏi.
|
| 405 |
+
3. KHÔNG copy nguyên văn text từ context. Hãy tổng hợp lại thông tin một cách mạch lạc.
|
| 406 |
+
4. Bắt đầu câu trả lời bằng các cụm từ tự nhiên như: "Dựa trên dữ liệu tôi tìm thấy...", "Tôi có thông tin về các sự kiện sau...", "Có vẻ như đây là những gì bạn ��ang tìm...".
|
| 407 |
+
5. Nếu có nhiều kết quả, hãy liệt kê ngắn gọn các điểm chính (Tên, Thời gian, Địa điểm).
|
| 408 |
+
6. Nếu context không liên quan, hãy lịch sự nói rằng bạn chưa tìm thấy thông tin phù hợp trong hệ thống."""
|
| 409 |
+
|
| 410 |
+
return prompt_template
|
agent_chat_stream.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Agent Chat Streaming Endpoint
|
| 3 |
+
SSE-based real-time streaming for Sales & Feedback agents
|
| 4 |
+
"""
|
| 5 |
+
from typing import AsyncGenerator
|
| 6 |
+
from stream_utils import format_sse, EVENT_STATUS, EVENT_TOKEN, EVENT_DONE, EVENT_ERROR, EVENT_METADATA
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
async def agent_chat_stream(
|
| 11 |
+
request,
|
| 12 |
+
agent_service,
|
| 13 |
+
conversation_service
|
| 14 |
+
) -> AsyncGenerator[str, None]:
|
| 15 |
+
"""
|
| 16 |
+
Stream agent responses in real-time (SSE format)
|
| 17 |
+
|
| 18 |
+
Args:
|
| 19 |
+
request: ChatRequest with message, session_id, mode, user_id
|
| 20 |
+
agent_service: AgentService instance
|
| 21 |
+
conversation_service: ConversationService instance
|
| 22 |
+
|
| 23 |
+
Yields SSE events:
|
| 24 |
+
- status: Processing updates
|
| 25 |
+
- token: Text chunks
|
| 26 |
+
- metadata: Session info
|
| 27 |
+
- done: Completion signal
|
| 28 |
+
- error: Error messages
|
| 29 |
+
"""
|
| 30 |
+
try:
|
| 31 |
+
# === SESSION MANAGEMENT ===
|
| 32 |
+
session_id = request.session_id
|
| 33 |
+
if not session_id:
|
| 34 |
+
session_id = conversation_service.create_session(
|
| 35 |
+
metadata={"user_agent": "api", "created_via": "agent_stream"},
|
| 36 |
+
user_id=request.user_id
|
| 37 |
+
)
|
| 38 |
+
yield format_sse(EVENT_METADATA, {"session_id": session_id})
|
| 39 |
+
|
| 40 |
+
# Get conversation history
|
| 41 |
+
history = conversation_service.get_conversation_history(session_id)
|
| 42 |
+
|
| 43 |
+
# Convert to messages format
|
| 44 |
+
messages = []
|
| 45 |
+
for h in history:
|
| 46 |
+
messages.append({"role": h["role"], "content": h["content"]})
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# Determine mode
|
| 50 |
+
mode = getattr(request, 'mode', 'sales') # Default to sales
|
| 51 |
+
user_id = getattr(request, 'user_id', None)
|
| 52 |
+
access_token = getattr(request, 'access_token', None)
|
| 53 |
+
|
| 54 |
+
# Debug logging
|
| 55 |
+
print(f"📋 Request Info:")
|
| 56 |
+
print(f" - Mode: {mode}")
|
| 57 |
+
print(f" - User ID: {user_id}")
|
| 58 |
+
print(f" - Access Token: {'✅ Present' if access_token else '❌ Missing'}")
|
| 59 |
+
if access_token:
|
| 60 |
+
print(f" - Token preview: {access_token[:20]}...")
|
| 61 |
+
|
| 62 |
+
# === STATUS UPDATE ===
|
| 63 |
+
if mode == 'feedback':
|
| 64 |
+
yield format_sse(EVENT_STATUS, "Đang kiểm tra lịch sử sự kiện của bạn...")
|
| 65 |
+
else:
|
| 66 |
+
yield format_sse(EVENT_STATUS, "Đang tư vấn...")
|
| 67 |
+
|
| 68 |
+
# === CALL AGENT ===
|
| 69 |
+
result = await agent_service.chat(
|
| 70 |
+
user_message=request.message,
|
| 71 |
+
conversation_history=messages,
|
| 72 |
+
mode=mode,
|
| 73 |
+
user_id=user_id,
|
| 74 |
+
access_token=access_token
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
agent_response = result["message"]
|
| 78 |
+
|
| 79 |
+
# === STREAM RESPONSE TOKEN BY TOKEN ===
|
| 80 |
+
# Simple character-by-character streaming
|
| 81 |
+
chunk_size = 5 # Characters per chunk
|
| 82 |
+
for i in range(0, len(agent_response), chunk_size):
|
| 83 |
+
chunk = agent_response[i:i+chunk_size]
|
| 84 |
+
yield format_sse(EVENT_TOKEN, chunk)
|
| 85 |
+
# Small delay for smoother streaming
|
| 86 |
+
import asyncio
|
| 87 |
+
await asyncio.sleep(0.02)
|
| 88 |
+
|
| 89 |
+
# === SAVE HISTORY ===
|
| 90 |
+
conversation_service.add_message(
|
| 91 |
+
session_id=session_id,
|
| 92 |
+
role="user",
|
| 93 |
+
content=request.message
|
| 94 |
+
)
|
| 95 |
+
conversation_service.add_message(
|
| 96 |
+
session_id=session_id,
|
| 97 |
+
role="assistant",
|
| 98 |
+
content=agent_response
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
# === DONE ===
|
| 102 |
+
yield format_sse(EVENT_DONE, {
|
| 103 |
+
"session_id": session_id,
|
| 104 |
+
"timestamp": datetime.utcnow().isoformat(),
|
| 105 |
+
"mode": mode,
|
| 106 |
+
"tool_calls": len(result.get("tool_calls", []))
|
| 107 |
+
})
|
| 108 |
+
|
| 109 |
+
except Exception as e:
|
| 110 |
+
print(f"⚠️ Agent Stream Error: {e}")
|
| 111 |
+
yield format_sse(EVENT_ERROR, str(e))
|
agent_service.py
ADDED
|
@@ -0,0 +1,503 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Agent Service - Central Brain for Sales & Feedback Agents
|
| 3 |
+
Manages LLM conversation loop with tool calling
|
| 4 |
+
"""
|
| 5 |
+
from typing import Dict, Any, List, Optional
|
| 6 |
+
import os
|
| 7 |
+
from tools_service import ToolsService
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class AgentService:
|
| 11 |
+
"""
|
| 12 |
+
Manages the conversation loop between User -> LLM -> Tools -> Response
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
def __init__(
|
| 16 |
+
self,
|
| 17 |
+
tools_service: ToolsService,
|
| 18 |
+
embedding_service,
|
| 19 |
+
qdrant_service,
|
| 20 |
+
advanced_rag,
|
| 21 |
+
hf_token: str,
|
| 22 |
+
feedback_tracking=None # NEW: Optional feedback tracking
|
| 23 |
+
):
|
| 24 |
+
self.tools_service = tools_service
|
| 25 |
+
self.embedding_service = embedding_service
|
| 26 |
+
self.qdrant_service = qdrant_service
|
| 27 |
+
self.advanced_rag = advanced_rag
|
| 28 |
+
self.hf_token = hf_token
|
| 29 |
+
self.feedback_tracking = feedback_tracking
|
| 30 |
+
|
| 31 |
+
# Load system prompts
|
| 32 |
+
self.prompts = self._load_prompts()
|
| 33 |
+
|
| 34 |
+
def _load_prompts(self) -> Dict[str, str]:
|
| 35 |
+
"""Load system prompts from files"""
|
| 36 |
+
prompts = {}
|
| 37 |
+
prompts_dir = "prompts"
|
| 38 |
+
|
| 39 |
+
for mode in ["sales_agent", "feedback_agent"]:
|
| 40 |
+
filepath = os.path.join(prompts_dir, f"{mode}.txt")
|
| 41 |
+
try:
|
| 42 |
+
with open(filepath, 'r', encoding='utf-8') as f:
|
| 43 |
+
prompts[mode] = f.read()
|
| 44 |
+
print(f"✓ Loaded prompt: {mode}")
|
| 45 |
+
except Exception as e:
|
| 46 |
+
print(f"⚠️ Error loading {mode} prompt: {e}")
|
| 47 |
+
prompts[mode] = ""
|
| 48 |
+
|
| 49 |
+
return prompts
|
| 50 |
+
|
| 51 |
+
async def chat(
|
| 52 |
+
self,
|
| 53 |
+
user_message: str,
|
| 54 |
+
conversation_history: List[Dict],
|
| 55 |
+
mode: str = "sales", # "sales" or "feedback"
|
| 56 |
+
user_id: Optional[str] = None,
|
| 57 |
+
access_token: Optional[str] = None, # NEW: For authenticated API calls
|
| 58 |
+
max_iterations: int = 3
|
| 59 |
+
) -> Dict[str, Any]:
|
| 60 |
+
"""
|
| 61 |
+
Main conversation loop
|
| 62 |
+
|
| 63 |
+
Args:
|
| 64 |
+
user_message: User's input
|
| 65 |
+
conversation_history: Previous messages [{"role": "user", "content": ...}, ...]
|
| 66 |
+
mode: "sales" or "feedback"
|
| 67 |
+
user_id: User ID (for feedback mode to check purchase history)
|
| 68 |
+
access_token: JWT token for authenticated API calls
|
| 69 |
+
max_iterations: Maximum tool call iterations to prevent infinite loops
|
| 70 |
+
|
| 71 |
+
Returns:
|
| 72 |
+
{
|
| 73 |
+
"message": "Bot response",
|
| 74 |
+
"tool_calls": [...], # List of tools called (for debugging)
|
| 75 |
+
"mode": mode
|
| 76 |
+
}
|
| 77 |
+
"""
|
| 78 |
+
print(f"\n🤖 Agent Mode: {mode}")
|
| 79 |
+
print(f"👤 User Message: {user_message}")
|
| 80 |
+
print(f"🔑 Auth Info:")
|
| 81 |
+
print(f" - User ID: {user_id}")
|
| 82 |
+
print(f" - Access Token: {'✅ Received' if access_token else '❌ None'}")
|
| 83 |
+
|
| 84 |
+
# Store user_id and access_token for tool calls
|
| 85 |
+
self.current_user_id = user_id
|
| 86 |
+
self.current_access_token = access_token
|
| 87 |
+
if access_token:
|
| 88 |
+
print(f" - Stored access_token for tools: {access_token[:20]}...")
|
| 89 |
+
if user_id:
|
| 90 |
+
print(f" - Stored user_id for tools: {user_id}")
|
| 91 |
+
|
| 92 |
+
# Select system prompt
|
| 93 |
+
system_prompt = self._get_system_prompt(mode)
|
| 94 |
+
|
| 95 |
+
# Build conversation context
|
| 96 |
+
messages = self._build_messages(system_prompt, conversation_history, user_message)
|
| 97 |
+
|
| 98 |
+
# Agentic loop: LLM may call tools multiple times
|
| 99 |
+
tool_calls_made = []
|
| 100 |
+
current_response = None
|
| 101 |
+
|
| 102 |
+
for iteration in range(max_iterations):
|
| 103 |
+
print(f"\n🔄 Iteration {iteration + 1}")
|
| 104 |
+
|
| 105 |
+
# Call LLM
|
| 106 |
+
llm_response = await self._call_llm(messages)
|
| 107 |
+
print(f"🧠 LLM Response: {llm_response[:200]}...")
|
| 108 |
+
|
| 109 |
+
# Check if LLM wants to call a tool
|
| 110 |
+
tool_call = self._parse_tool_call(llm_response)
|
| 111 |
+
|
| 112 |
+
if not tool_call:
|
| 113 |
+
# No tool call -> This is the final response
|
| 114 |
+
current_response = llm_response
|
| 115 |
+
break
|
| 116 |
+
|
| 117 |
+
# Execute tool
|
| 118 |
+
print(f"🔧 Tool Called: {tool_call['tool_name']}")
|
| 119 |
+
|
| 120 |
+
# Auto-inject real user_id for get_purchased_events
|
| 121 |
+
if tool_call['tool_name'] == 'get_purchased_events' and self.current_user_id:
|
| 122 |
+
print(f"🔄 Auto-injecting real user_id: {self.current_user_id}")
|
| 123 |
+
tool_call['arguments']['user_id'] = self.current_user_id
|
| 124 |
+
|
| 125 |
+
tool_result = await self.tools_service.execute_tool(
|
| 126 |
+
tool_call['tool_name'],
|
| 127 |
+
tool_call['arguments'],
|
| 128 |
+
access_token=self.current_access_token # Pass access_token
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
# Record tool call
|
| 132 |
+
tool_calls_made.append({
|
| 133 |
+
"function": tool_call['tool_name'],
|
| 134 |
+
"arguments": tool_call['arguments'],
|
| 135 |
+
"result": tool_result
|
| 136 |
+
})
|
| 137 |
+
|
| 138 |
+
# Add tool result to conversation
|
| 139 |
+
messages.append({
|
| 140 |
+
"role": "assistant",
|
| 141 |
+
"content": llm_response
|
| 142 |
+
})
|
| 143 |
+
messages.append({
|
| 144 |
+
"role": "system",
|
| 145 |
+
"content": f"Tool Result:\n{self._format_tool_result({'result': tool_result})}"
|
| 146 |
+
})
|
| 147 |
+
|
| 148 |
+
# If tool returns "run_rag_search", handle it specially
|
| 149 |
+
if isinstance(tool_result, dict) and tool_result.get("action") == "run_rag_search":
|
| 150 |
+
rag_results = await self._execute_rag_search(tool_result["query"])
|
| 151 |
+
messages[-1]["content"] = f"RAG Search Results:\n{rag_results}"
|
| 152 |
+
|
| 153 |
+
# Clean up response
|
| 154 |
+
final_response = current_response or llm_response
|
| 155 |
+
final_response = self._clean_response(final_response)
|
| 156 |
+
|
| 157 |
+
return {
|
| 158 |
+
"message": final_response,
|
| 159 |
+
"tool_calls": tool_calls_made,
|
| 160 |
+
"mode": mode
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
def _get_system_prompt(self, mode: str) -> str:
|
| 164 |
+
"""Get system prompt for selected mode with tools definition"""
|
| 165 |
+
prompt_key = f"{mode}_agent" if mode in ["sales", "feedback"] else "sales_agent"
|
| 166 |
+
base_prompt = self.prompts.get(prompt_key, "")
|
| 167 |
+
|
| 168 |
+
# Add tools definition
|
| 169 |
+
tools_definition = self._get_tools_definition()
|
| 170 |
+
|
| 171 |
+
return f"{base_prompt}\n\n{tools_definition}"
|
| 172 |
+
|
| 173 |
+
def _get_tools_definition(self) -> str:
|
| 174 |
+
"""Get tools definition in text format for prompt"""
|
| 175 |
+
return """
|
| 176 |
+
# AVAILABLE TOOLS
|
| 177 |
+
|
| 178 |
+
You can call the following tools when needed. To call a tool, output a JSON block like this:
|
| 179 |
+
|
| 180 |
+
```json
|
| 181 |
+
{
|
| 182 |
+
"tool_call": "tool_name",
|
| 183 |
+
"arguments": {
|
| 184 |
+
"arg1": "value1",
|
| 185 |
+
"arg2": "value2"
|
| 186 |
+
}
|
| 187 |
+
}
|
| 188 |
+
```
|
| 189 |
+
|
| 190 |
+
## Tools List:
|
| 191 |
+
|
| 192 |
+
### 1. search_events
|
| 193 |
+
Search for events matching user criteria.
|
| 194 |
+
Arguments:
|
| 195 |
+
- query (string): Search keywords
|
| 196 |
+
- vibe (string, optional): Mood/vibe (e.g., "chill", "sôi động")
|
| 197 |
+
- time (string, optional): Time period (e.g., "cuối tuần này")
|
| 198 |
+
|
| 199 |
+
Example:
|
| 200 |
+
```json
|
| 201 |
+
{"tool_call": "search_events", "arguments": {"query": "nhạc rock", "vibe": "sôi động"}}
|
| 202 |
+
```
|
| 203 |
+
|
| 204 |
+
### 2. get_event_details
|
| 205 |
+
Get detailed information about a specific event.
|
| 206 |
+
Arguments:
|
| 207 |
+
- event_id (string): Event ID from search results
|
| 208 |
+
|
| 209 |
+
Example:
|
| 210 |
+
```json
|
| 211 |
+
{"tool_call": "get_event_details", "arguments": {"event_id": "6900ae38eb03f29702c7fd1d"}}
|
| 212 |
+
```
|
| 213 |
+
|
| 214 |
+
### 3. get_purchased_events (Feedback mode only)
|
| 215 |
+
Check which events the user has attended.
|
| 216 |
+
Arguments:
|
| 217 |
+
- user_id (string): User ID
|
| 218 |
+
|
| 219 |
+
Example:
|
| 220 |
+
```json
|
| 221 |
+
{"tool_call": "get_purchased_events", "arguments": {"user_id": "user_123"}}
|
| 222 |
+
```
|
| 223 |
+
|
| 224 |
+
### 4. save_feedback
|
| 225 |
+
Save user's feedback/review for an event.
|
| 226 |
+
Arguments:
|
| 227 |
+
- event_id (string): Event ID
|
| 228 |
+
- rating (integer): 1-5 stars
|
| 229 |
+
- comment (string, optional): User's comment
|
| 230 |
+
|
| 231 |
+
Example:
|
| 232 |
+
```json
|
| 233 |
+
{"tool_call": "save_feedback", "arguments": {"event_id": "abc123", "rating": 5, "comment": "Tuyệt vời!"}}
|
| 234 |
+
```
|
| 235 |
+
|
| 236 |
+
### 5. save_lead
|
| 237 |
+
Save customer contact information.
|
| 238 |
+
Arguments:
|
| 239 |
+
- email (string, optional): Email address
|
| 240 |
+
- phone (string, optional): Phone number
|
| 241 |
+
- interest (string, optional): What they're interested in
|
| 242 |
+
|
| 243 |
+
Example:
|
| 244 |
+
```json
|
| 245 |
+
{"tool_call": "save_lead", "arguments": {"email": "user@example.com", "interest": "Rock show"}}
|
| 246 |
+
```
|
| 247 |
+
|
| 248 |
+
**IMPORTANT:**
|
| 249 |
+
- Call tools ONLY when you need real-time data
|
| 250 |
+
- After receiving tool results, respond naturally to the user
|
| 251 |
+
- Don't expose raw JSON to users - always format nicely
|
| 252 |
+
"""
|
| 253 |
+
|
| 254 |
+
def _build_messages(
|
| 255 |
+
self,
|
| 256 |
+
system_prompt: str,
|
| 257 |
+
history: List[Dict],
|
| 258 |
+
user_message: str
|
| 259 |
+
) -> List[Dict]:
|
| 260 |
+
"""Build messages array for LLM"""
|
| 261 |
+
messages = [{"role": "system", "content": system_prompt}]
|
| 262 |
+
|
| 263 |
+
# Add conversation history
|
| 264 |
+
messages.extend(history)
|
| 265 |
+
|
| 266 |
+
# Add current user message
|
| 267 |
+
messages.append({"role": "user", "content": user_message})
|
| 268 |
+
|
| 269 |
+
return messages
|
| 270 |
+
|
| 271 |
+
async def _call_llm(self, messages: List[Dict]) -> str:
|
| 272 |
+
"""
|
| 273 |
+
Call HuggingFace LLM directly using chat_completion (conversational)
|
| 274 |
+
"""
|
| 275 |
+
try:
|
| 276 |
+
from huggingface_hub import AsyncInferenceClient
|
| 277 |
+
|
| 278 |
+
# Create async client
|
| 279 |
+
client = AsyncInferenceClient(token=self.hf_token)
|
| 280 |
+
|
| 281 |
+
# Call HF API with chat completion (conversational)
|
| 282 |
+
response_text = ""
|
| 283 |
+
async for message in await client.chat_completion(
|
| 284 |
+
messages=messages, # Use messages directly
|
| 285 |
+
model="openai/gpt-oss-20b", # GPT-OSS 20B
|
| 286 |
+
max_tokens=512,
|
| 287 |
+
temperature=0.7,
|
| 288 |
+
stream=True
|
| 289 |
+
):
|
| 290 |
+
if message.choices and message.choices[0].delta.content:
|
| 291 |
+
response_text += message.choices[0].delta.content
|
| 292 |
+
|
| 293 |
+
return response_text
|
| 294 |
+
except Exception as e:
|
| 295 |
+
print(f"⚠️ LLM Call Error: {e}")
|
| 296 |
+
return "Xin lỗi, tôi đang gặp chút vấn đề kỹ thuật. Bạn thử lại sau nhé!"
|
| 297 |
+
|
| 298 |
+
def _messages_to_prompt(self, messages: List[Dict]) -> str:
|
| 299 |
+
"""Convert messages array to single prompt string"""
|
| 300 |
+
prompt_parts = []
|
| 301 |
+
|
| 302 |
+
for msg in messages:
|
| 303 |
+
role = msg["role"]
|
| 304 |
+
content = msg["content"]
|
| 305 |
+
|
| 306 |
+
if role == "system":
|
| 307 |
+
prompt_parts.append(f"[SYSTEM]\n{content}\n")
|
| 308 |
+
elif role == "user":
|
| 309 |
+
prompt_parts.append(f"[USER]\n{content}\n")
|
| 310 |
+
elif role == "assistant":
|
| 311 |
+
prompt_parts.append(f"[ASSISTANT]\n{content}\n")
|
| 312 |
+
|
| 313 |
+
return "\n".join(prompt_parts)
|
| 314 |
+
|
| 315 |
+
def _format_tool_result(self, tool_result: Dict) -> str:
|
| 316 |
+
"""Format tool result for feeding back to LLM"""
|
| 317 |
+
result = tool_result.get("result", {})
|
| 318 |
+
|
| 319 |
+
# Special handling for purchased events list
|
| 320 |
+
if isinstance(result, list):
|
| 321 |
+
print(f"\n🔍 Formatting {len(result)} purchased events for LLM")
|
| 322 |
+
if not result:
|
| 323 |
+
return "User has not purchased any events yet."
|
| 324 |
+
|
| 325 |
+
# Format each event clearly
|
| 326 |
+
formatted_events = []
|
| 327 |
+
for i, event in enumerate(result, 1):
|
| 328 |
+
event_info = []
|
| 329 |
+
event_info.append(f"Event {i}:")
|
| 330 |
+
|
| 331 |
+
# Extract key fields
|
| 332 |
+
if 'eventName' in event:
|
| 333 |
+
event_info.append(f" Name: {event['eventName']}")
|
| 334 |
+
if 'eventCode' in event:
|
| 335 |
+
event_info.append(f" Code: {event['eventCode']}")
|
| 336 |
+
if '_id' in event:
|
| 337 |
+
event_info.append(f" ID: {event['_id']}")
|
| 338 |
+
if 'startTimeEventTime' in event:
|
| 339 |
+
event_info.append(f" Date: {event['startTimeEventTime']}")
|
| 340 |
+
|
| 341 |
+
formatted_events.append("\n".join(event_info))
|
| 342 |
+
|
| 343 |
+
formatted = "User's Purchased Events:\n\n" + "\n\n".join(formatted_events)
|
| 344 |
+
print(f"📤 Sending to LLM:\n{formatted}")
|
| 345 |
+
return formatted
|
| 346 |
+
|
| 347 |
+
# Default formatting for other results
|
| 348 |
+
if isinstance(result, dict):
|
| 349 |
+
# Pretty print key info
|
| 350 |
+
formatted = []
|
| 351 |
+
for key, value in result.items():
|
| 352 |
+
if key not in ["success", "error"]:
|
| 353 |
+
formatted.append(f"{key}: {value}")
|
| 354 |
+
return "\n".join(formatted)
|
| 355 |
+
|
| 356 |
+
return str(result)
|
| 357 |
+
|
| 358 |
+
async def _execute_rag_search(self, query_params: Dict) -> str:
|
| 359 |
+
"""
|
| 360 |
+
Execute RAG search for event discovery
|
| 361 |
+
Called when LLM wants to search_events
|
| 362 |
+
"""
|
| 363 |
+
query = query_params.get("query", "")
|
| 364 |
+
vibe = query_params.get("vibe", "")
|
| 365 |
+
|
| 366 |
+
# Build search query
|
| 367 |
+
search_text = f"{query} {vibe}".strip()
|
| 368 |
+
|
| 369 |
+
print(f"🔍 RAG Search: {search_text}")
|
| 370 |
+
|
| 371 |
+
# Use embedding + qdrant
|
| 372 |
+
embedding = self.embedding_service.encode_text(search_text)
|
| 373 |
+
results = self.qdrant_service.search(
|
| 374 |
+
query_embedding=embedding,
|
| 375 |
+
limit=5
|
| 376 |
+
)
|
| 377 |
+
|
| 378 |
+
# Format results
|
| 379 |
+
formatted = []
|
| 380 |
+
for i, result in enumerate(results, 1):
|
| 381 |
+
# Result is a dict with keys: id, score, payload
|
| 382 |
+
payload = result.get("payload", {})
|
| 383 |
+
texts = payload.get("texts", [])
|
| 384 |
+
text = texts[0] if texts else ""
|
| 385 |
+
event_id = payload.get("id_use", "")
|
| 386 |
+
|
| 387 |
+
formatted.append(f"{i}. {text[:100]}... (ID: {event_id})")
|
| 388 |
+
|
| 389 |
+
return "\n".join(formatted) if formatted else "Không tìm thấy sự kiện phù hợp."
|
| 390 |
+
|
| 391 |
+
def _parse_tool_call(self, llm_response: str) -> Optional[Dict]:
|
| 392 |
+
"""
|
| 393 |
+
Parse LLM response to detect tool calls using structured JSON
|
| 394 |
+
|
| 395 |
+
Returns:
|
| 396 |
+
{"tool_name": "...", "arguments": {...}} or None
|
| 397 |
+
"""
|
| 398 |
+
import json
|
| 399 |
+
import re
|
| 400 |
+
|
| 401 |
+
# Method 1: Look for JSON code block
|
| 402 |
+
json_match = re.search(r'```json\s*(\{.*?\})\s*```', llm_response, re.DOTALL)
|
| 403 |
+
if json_match:
|
| 404 |
+
try:
|
| 405 |
+
data = json.loads(json_match.group(1))
|
| 406 |
+
return self._extract_tool_from_json(data)
|
| 407 |
+
except json.JSONDecodeError:
|
| 408 |
+
pass
|
| 409 |
+
|
| 410 |
+
# Method 2: Look for inline JSON object
|
| 411 |
+
# Find all potential JSON objects
|
| 412 |
+
json_objects = re.findall(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', llm_response)
|
| 413 |
+
for json_str in json_objects:
|
| 414 |
+
try:
|
| 415 |
+
data = json.loads(json_str)
|
| 416 |
+
tool_call = self._extract_tool_from_json(data)
|
| 417 |
+
if tool_call:
|
| 418 |
+
return tool_call
|
| 419 |
+
except json.JSONDecodeError:
|
| 420 |
+
continue
|
| 421 |
+
|
| 422 |
+
# Method 3: Nested JSON (for complex structures)
|
| 423 |
+
try:
|
| 424 |
+
# Find outermost curly braces
|
| 425 |
+
if '{' in llm_response and '}' in llm_response:
|
| 426 |
+
start = llm_response.find('{')
|
| 427 |
+
# Find matching closing brace
|
| 428 |
+
count = 0
|
| 429 |
+
for i, char in enumerate(llm_response[start:], start):
|
| 430 |
+
if char == '{':
|
| 431 |
+
count += 1
|
| 432 |
+
elif char == '}':
|
| 433 |
+
count -= 1
|
| 434 |
+
if count == 0:
|
| 435 |
+
json_str = llm_response[start:i+1]
|
| 436 |
+
data = json.loads(json_str)
|
| 437 |
+
return self._extract_tool_from_json(data)
|
| 438 |
+
except (json.JSONDecodeError, ValueError):
|
| 439 |
+
pass
|
| 440 |
+
|
| 441 |
+
return None
|
| 442 |
+
|
| 443 |
+
def _extract_tool_from_json(self, data: dict) -> Optional[Dict]:
|
| 444 |
+
"""
|
| 445 |
+
Extract tool call information from parsed JSON
|
| 446 |
+
|
| 447 |
+
Supports multiple formats:
|
| 448 |
+
- {"tool_call": "search_events", "arguments": {...}}
|
| 449 |
+
- {"function": "search_events", "parameters": {...}}
|
| 450 |
+
- {"name": "search_events", "args": {...}}
|
| 451 |
+
"""
|
| 452 |
+
# Format 1: tool_call + arguments
|
| 453 |
+
if "tool_call" in data and isinstance(data["tool_call"], str):
|
| 454 |
+
return {
|
| 455 |
+
"tool_name": data["tool_call"],
|
| 456 |
+
"arguments": data.get("arguments", {})
|
| 457 |
+
}
|
| 458 |
+
|
| 459 |
+
# Format 2: function + parameters
|
| 460 |
+
if "function" in data:
|
| 461 |
+
return {
|
| 462 |
+
"tool_name": data["function"],
|
| 463 |
+
"arguments": data.get("parameters", data.get("arguments", {}))
|
| 464 |
+
}
|
| 465 |
+
|
| 466 |
+
# Format 3: name + args
|
| 467 |
+
if "name" in data:
|
| 468 |
+
return {
|
| 469 |
+
"tool_name": data["name"],
|
| 470 |
+
"arguments": data.get("args", data.get("arguments", {}))
|
| 471 |
+
}
|
| 472 |
+
|
| 473 |
+
# Format 4: Direct tool name as key
|
| 474 |
+
valid_tools = ["search_events", "get_event_details", "get_purchased_events", "save_feedback", "save_lead"]
|
| 475 |
+
for tool in valid_tools:
|
| 476 |
+
if tool in data:
|
| 477 |
+
return {
|
| 478 |
+
"tool_name": tool,
|
| 479 |
+
"arguments": data[tool] if isinstance(data[tool], dict) else {}
|
| 480 |
+
}
|
| 481 |
+
|
| 482 |
+
return None
|
| 483 |
+
|
| 484 |
+
def _clean_response(self, response: str) -> str:
|
| 485 |
+
"""Remove JSON artifacts from final response"""
|
| 486 |
+
# Remove JSON blocks
|
| 487 |
+
if "```json" in response:
|
| 488 |
+
response = response.split("```json")[0]
|
| 489 |
+
if "```" in response:
|
| 490 |
+
response = response.split("```")[0]
|
| 491 |
+
|
| 492 |
+
# Remove tool call markers
|
| 493 |
+
if "{" in response and "tool_call" in response:
|
| 494 |
+
# Find the last natural sentence before JSON
|
| 495 |
+
lines = response.split("\n")
|
| 496 |
+
cleaned = []
|
| 497 |
+
for line in lines:
|
| 498 |
+
if "{" in line and "tool_call" in line:
|
| 499 |
+
break
|
| 500 |
+
cleaned.append(line)
|
| 501 |
+
response = "\n".join(cleaned)
|
| 502 |
+
|
| 503 |
+
return response.strip()
|
app.py
CHANGED
|
@@ -1,420 +1,47 @@
|
|
| 1 |
"""
|
| 2 |
-
|
| 3 |
-
Generates relevant tags, keywords, and categories from event information
|
| 4 |
"""
|
| 5 |
-
|
| 6 |
-
from fastapi import FastAPI, HTTPException
|
| 7 |
-
from fastapi.middleware.cors import CORSMiddleware
|
| 8 |
-
from pydantic import BaseModel
|
| 9 |
-
from typing import Optional, List
|
| 10 |
-
from datetime import datetime
|
| 11 |
import os
|
| 12 |
-
|
| 13 |
-
import
|
| 14 |
-
# Initialize FastAPI
|
| 15 |
-
app = FastAPI(
|
| 16 |
-
title="Event Tags Generator API",
|
| 17 |
-
description="AI-powered automatic tag generation for events using LLM",
|
| 18 |
-
version="1.0.0"
|
| 19 |
-
)
|
| 20 |
-
|
| 21 |
-
# CORS middleware
|
| 22 |
-
app.add_middleware(
|
| 23 |
-
CORSMiddleware,
|
| 24 |
-
allow_origins=["*"],
|
| 25 |
-
allow_credentials=True,
|
| 26 |
-
allow_methods=["*"],
|
| 27 |
-
allow_headers=["*"],
|
| 28 |
-
)
|
| 29 |
-
|
| 30 |
-
# Hugging Face token
|
| 31 |
-
hf_token = os.getenv("HUGGINGFACE_TOKEN")
|
| 32 |
-
if hf_token:
|
| 33 |
-
print("✓ Hugging Face token configured")
|
| 34 |
-
else:
|
| 35 |
-
print("⚠ Warning: No HUGGINGFACE_TOKEN found. Set it in environment variable.")
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
# Pydantic models
|
| 39 |
-
class EventTagsRequest(BaseModel):
|
| 40 |
-
event_name: str
|
| 41 |
-
category: str
|
| 42 |
-
short_description: str
|
| 43 |
-
detailed_description: str
|
| 44 |
-
max_tags: Optional[int] = 10
|
| 45 |
-
language: Optional[str] = "vi" # vi = Vietnamese, en = English
|
| 46 |
-
hf_token: Optional[str] = None
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
class EventTagsResponse(BaseModel):
|
| 50 |
-
event_name: str
|
| 51 |
-
generated_tags: List[str]
|
| 52 |
-
primary_category: str
|
| 53 |
-
secondary_categories: List[str]
|
| 54 |
-
keywords: List[str]
|
| 55 |
-
hashtags: List[str]
|
| 56 |
-
target_audience: List[str]
|
| 57 |
-
sentiment: str
|
| 58 |
-
confidence_score: float
|
| 59 |
-
generation_time: str
|
| 60 |
-
model_used: str
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
@app.get("/")
|
| 64 |
-
async def root():
|
| 65 |
-
"""API Information"""
|
| 66 |
-
return {
|
| 67 |
-
"status": "running",
|
| 68 |
-
"service": "Event Tags Generator API",
|
| 69 |
-
"version": "1.0.0",
|
| 70 |
-
"description": "Generate tags, keywords, categories automatically from event info",
|
| 71 |
-
"endpoints": {
|
| 72 |
-
"POST /generate-tags": {
|
| 73 |
-
"description": "Generate tags from event information",
|
| 74 |
-
"request_body": {
|
| 75 |
-
"event_name": "string - Tên sự kiện",
|
| 76 |
-
"category": "string - Danh mục (âm nhạc, thể thao, công nghệ...)",
|
| 77 |
-
"short_description": "string - Mô tả ngắn (1-2 câu)",
|
| 78 |
-
"detailed_description": "string - Mô tả chi tiết",
|
| 79 |
-
"max_tags": "integer (optional, default: 10) - Số lượng tags tối đa",
|
| 80 |
-
"language": "string (optional, default: 'vi') - Ngôn ngữ output",
|
| 81 |
-
"hf_token": "string (optional) - Hugging Face token"
|
| 82 |
-
},
|
| 83 |
-
"response": {
|
| 84 |
-
"generated_tags": "array - Danh sách tags",
|
| 85 |
-
"primary_category": "string - Danh mục chính",
|
| 86 |
-
"secondary_categories": "array - Danh mục phụ",
|
| 87 |
-
"keywords": "array - Keywords SEO",
|
| 88 |
-
"hashtags": "array - Social media hashtags",
|
| 89 |
-
"target_audience": "array - Đối tượng mục tiêu",
|
| 90 |
-
"sentiment": "string - Cảm xúc (positive/neutral/negative)",
|
| 91 |
-
"confidence_score": "float - Độ tin cậy (0-1)"
|
| 92 |
-
},
|
| 93 |
-
"example": {
|
| 94 |
-
"request": {
|
| 95 |
-
"event_name": "Vietnam Music Festival 2025",
|
| 96 |
-
"category": "Âm nhạc",
|
| 97 |
-
"short_description": "Lễ hội âm nhạc quốc tế lớn nhất Việt Nam",
|
| 98 |
-
"detailed_description": "Sự kiện quy tụ các nghệ sĩ nổi tiếng trong nước và quốc tế..."
|
| 99 |
-
},
|
| 100 |
-
"response": {
|
| 101 |
-
"generated_tags": ["âm nhạc", "festival", "concert", "việt nam", "quốc tế"],
|
| 102 |
-
"hashtags": ["#VietnamMusicFest", "#MusicFestival2025", "#LiveMusic"]
|
| 103 |
-
}
|
| 104 |
-
}
|
| 105 |
-
}
|
| 106 |
-
},
|
| 107 |
-
"usage": "POST /generate-tags with event information in JSON body"
|
| 108 |
-
}
|
| 109 |
|
| 110 |
-
|
| 111 |
-
def
|
| 112 |
-
event_name: str,
|
| 113 |
-
category: str,
|
| 114 |
-
short_desc: str,
|
| 115 |
-
detailed_desc: str,
|
| 116 |
-
max_tags: int,
|
| 117 |
-
language: str
|
| 118 |
-
) -> str:
|
| 119 |
"""
|
| 120 |
-
|
| 121 |
"""
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
prompt = f"""You are an expert AI system specialized in event marketing, SEO, and content categorization. Your task is to analyze event information and generate comprehensive, relevant tags and metadata.
|
| 126 |
-
**EVENT INFORMATION:**
|
| 127 |
-
• Event Name: {event_name}
|
| 128 |
-
• Primary Category: {category}
|
| 129 |
-
• Short Description: {short_desc}
|
| 130 |
-
• Detailed Description: {detailed_desc}
|
| 131 |
-
**YOUR TASK:**
|
| 132 |
-
Analyze the event information above and generate the following {lang_instruction}:
|
| 133 |
-
1. **TAGS** ({max_tags} tags maximum):
|
| 134 |
-
- Generate specific, relevant, searchable tags
|
| 135 |
-
- Include event type, theme, activities, location references
|
| 136 |
-
- Mix broad and specific tags for better discoverability
|
| 137 |
-
- Use lowercase, single words or short phrases
|
| 138 |
-
- Example format: âm nhạc, festival, concert, outdoor, hà nội
|
| 139 |
-
2. **PRIMARY CATEGORY** (1 category):
|
| 140 |
-
- The main category that best describes this event
|
| 141 |
-
- Choose from: Âm nhạc, Thể thao, Công nghệ, Nghệ thuật, Ẩm thực, Giáo dục, Kinh doanh, Du lịch, Giải trí, Khác
|
| 142 |
-
3. **SECONDARY CATEGORIES** (2-3 categories):
|
| 143 |
-
- Additional relevant categories
|
| 144 |
-
- Help with cross-categorization
|
| 145 |
-
4. **KEYWORDS** (5-8 keywords):
|
| 146 |
-
- SEO-optimized keywords for search engines
|
| 147 |
-
- Include long-tail keywords
|
| 148 |
-
- Example: "lễ hội âm nhạc hà nội", "concert quốc tế việt nam"
|
| 149 |
-
5. **HASHTAGS** (5-7 hashtags):
|
| 150 |
-
- Social media friendly hashtags
|
| 151 |
-
- Mix of popular and unique hashtags
|
| 152 |
-
- Example: #VietnamMusicFest, #LiveMusic, #HanoiEvents
|
| 153 |
-
6. **TARGET AUDIENCE** (2-4 audience groups):
|
| 154 |
-
- Who would be interested in this event?
|
| 155 |
-
- Example: Giới trẻ, Gia đình, Dân văn phòng, Sinh viên
|
| 156 |
-
7. **SENTIMENT** (one word):
|
| 157 |
-
- Overall emotion/feeling: positive, neutral, or negative
|
| 158 |
-
- Based on event description tone
|
| 159 |
-
**OUTPUT FORMAT (JSON-like structure):**
|
| 160 |
-
TAGS: tag1, tag2, tag3, ...
|
| 161 |
-
PRIMARY_CATEGORY: category_name
|
| 162 |
-
SECONDARY_CATEGORIES: cat1, cat2, cat3
|
| 163 |
-
KEYWORDS: keyword1, keyword2, keyword3, ...
|
| 164 |
-
HASHTAGS: #tag1, #tag2, #tag3, ...
|
| 165 |
-
TARGET_AUDIENCE: audience1, audience2, audience3
|
| 166 |
-
SENTIMENT: positive/neutral/negative
|
| 167 |
-
**IMPORTANT GUIDELINES:**
|
| 168 |
-
- Be specific and relevant to the event
|
| 169 |
-
- Use terms people would actually search for
|
| 170 |
-
- Balance between popular and niche terms
|
| 171 |
-
- Consider SEO and social media best practices
|
| 172 |
-
- Keep tags concise and meaningful
|
| 173 |
-
- Generate output {lang_instruction}
|
| 174 |
-
Now, analyze the event and generate the metadata:"""
|
| 175 |
-
|
| 176 |
-
return prompt
|
| 177 |
|
|
|
|
| 178 |
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
result = {
|
| 186 |
-
"generated_tags": [],
|
| 187 |
-
"primary_category": "",
|
| 188 |
-
"secondary_categories": [],
|
| 189 |
-
"keywords": [],
|
| 190 |
-
"hashtags": [],
|
| 191 |
-
"target_audience": [],
|
| 192 |
-
"sentiment": "neutral"
|
| 193 |
-
}
|
| 194 |
-
|
| 195 |
-
lines = response_text.strip().split('\n')
|
| 196 |
-
|
| 197 |
-
for line in lines:
|
| 198 |
-
line = line.strip()
|
| 199 |
-
if not line:
|
| 200 |
-
continue
|
| 201 |
-
|
| 202 |
-
# Parse TAGS
|
| 203 |
-
if line.upper().startswith('TAGS:'):
|
| 204 |
-
tags_text = line.split(':', 1)[1].strip()
|
| 205 |
-
tags = [t.strip().lower() for t in tags_text.split(',') if t.strip()]
|
| 206 |
-
result["generated_tags"] = tags[:max_tags]
|
| 207 |
-
|
| 208 |
-
# Parse PRIMARY_CATEGORY
|
| 209 |
-
elif line.upper().startswith('PRIMARY_CATEGORY:'):
|
| 210 |
-
result["primary_category"] = line.split(':', 1)[1].strip()
|
| 211 |
-
|
| 212 |
-
# Parse SECONDARY_CATEGORIES
|
| 213 |
-
elif line.upper().startswith('SECONDARY_CATEGORIES:'):
|
| 214 |
-
cats_text = line.split(':', 1)[1].strip()
|
| 215 |
-
result["secondary_categories"] = [c.strip() for c in cats_text.split(',') if c.strip()]
|
| 216 |
-
|
| 217 |
-
# Parse KEYWORDS
|
| 218 |
-
elif line.upper().startswith('KEYWORDS:'):
|
| 219 |
-
kw_text = line.split(':', 1)[1].strip()
|
| 220 |
-
result["keywords"] = [k.strip() for k in kw_text.split(',') if k.strip()]
|
| 221 |
-
|
| 222 |
-
# Parse HASHTAGS
|
| 223 |
-
elif line.upper().startswith('HASHTAGS:'):
|
| 224 |
-
ht_text = line.split(':', 1)[1].strip()
|
| 225 |
-
hashtags = [h.strip() for h in ht_text.split(',') if h.strip()]
|
| 226 |
-
# Ensure hashtags start with #
|
| 227 |
-
result["hashtags"] = [h if h.startswith('#') else f"#{h}" for h in hashtags]
|
| 228 |
-
|
| 229 |
-
# Parse TARGET_AUDIENCE
|
| 230 |
-
elif line.upper().startswith('TARGET_AUDIENCE:'):
|
| 231 |
-
aud_text = line.split(':', 1)[1].strip()
|
| 232 |
-
result["target_audience"] = [a.strip() for a in aud_text.split(',') if a.strip()]
|
| 233 |
-
|
| 234 |
-
# Parse SENTIMENT
|
| 235 |
-
elif line.upper().startswith('SENTIMENT:'):
|
| 236 |
-
sentiment = line.split(':', 1)[1].strip().lower()
|
| 237 |
-
if sentiment in ['positive', 'neutral', 'negative']:
|
| 238 |
-
result["sentiment"] = sentiment
|
| 239 |
-
|
| 240 |
-
return result
|
| 241 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 242 |
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
Generate comprehensive tags and metadata for an event
|
| 247 |
-
|
| 248 |
-
This endpoint uses advanced LLM prompting to generate:
|
| 249 |
-
- Relevant tags for searchability
|
| 250 |
-
- Category classification
|
| 251 |
-
- SEO keywords
|
| 252 |
-
- Social media hashtags
|
| 253 |
-
- Target audience identification
|
| 254 |
-
- Sentiment analysis
|
| 255 |
-
|
| 256 |
-
**Input:**
|
| 257 |
-
- event_name: Name of the event
|
| 258 |
-
- category: Primary category (music, sports, tech, etc.)
|
| 259 |
-
- short_description: Brief 1-2 sentence description
|
| 260 |
-
- detailed_description: Full event description with details
|
| 261 |
-
|
| 262 |
-
**Output:**
|
| 263 |
-
- Structured metadata ready for use in event management system
|
| 264 |
-
- All fields optimized for search and discovery
|
| 265 |
-
"""
|
| 266 |
-
|
| 267 |
-
try:
|
| 268 |
-
start_time = datetime.utcnow()
|
| 269 |
-
|
| 270 |
-
# Get token
|
| 271 |
-
token = request.hf_token or hf_token
|
| 272 |
-
|
| 273 |
-
if not token:
|
| 274 |
-
raise HTTPException(
|
| 275 |
-
status_code=401,
|
| 276 |
-
detail="HUGGINGFACE_TOKEN required. Set environment variable or pass in request body."
|
| 277 |
-
)
|
| 278 |
-
|
| 279 |
-
# Build powerful prompt
|
| 280 |
-
prompt = build_powerful_prompt(
|
| 281 |
-
event_name=request.event_name,
|
| 282 |
-
category=request.category,
|
| 283 |
-
short_desc=request.short_description,
|
| 284 |
-
detailed_desc=request.detailed_description,
|
| 285 |
-
max_tags=request.max_tags,
|
| 286 |
-
language=request.language
|
| 287 |
-
)
|
| 288 |
-
|
| 289 |
-
# Initialize HF client
|
| 290 |
-
client = InferenceClient(token=token)
|
| 291 |
-
|
| 292 |
-
# Try multiple models for best results
|
| 293 |
-
models_to_try = [
|
| 294 |
-
"microsoft/Phi-3-mini-4k-instruct",
|
| 295 |
-
"mistralai/Mistral-7B-Instruct-v0.3",
|
| 296 |
-
"HuggingFaceH4/zephyr-7b-beta",
|
| 297 |
-
"meta-llama/Llama-3.2-3B-Instruct"
|
| 298 |
-
]
|
| 299 |
-
|
| 300 |
-
llm_response = ""
|
| 301 |
-
model_used = ""
|
| 302 |
-
last_error = None
|
| 303 |
-
|
| 304 |
-
for model_name in models_to_try:
|
| 305 |
-
try:
|
| 306 |
-
print(f"Trying model: {model_name}")
|
| 307 |
-
|
| 308 |
-
# Generate with LLM
|
| 309 |
-
llm_response = client.text_generation(
|
| 310 |
-
prompt,
|
| 311 |
-
model=model_name,
|
| 312 |
-
max_new_tokens=800,
|
| 313 |
-
temperature=0.7,
|
| 314 |
-
top_p=0.9,
|
| 315 |
-
do_sample=True,
|
| 316 |
-
return_full_text=False
|
| 317 |
-
)
|
| 318 |
-
|
| 319 |
-
if llm_response and len(llm_response.strip()) > 50:
|
| 320 |
-
model_used = model_name
|
| 321 |
-
print(f"✓ Success with {model_name}")
|
| 322 |
-
break
|
| 323 |
-
|
| 324 |
-
except Exception as model_error:
|
| 325 |
-
print(f"✗ Failed with {model_name}: {str(model_error)}")
|
| 326 |
-
last_error = model_error
|
| 327 |
-
continue
|
| 328 |
-
|
| 329 |
-
# Check if generation succeeded
|
| 330 |
-
if not llm_response or len(llm_response.strip()) < 50:
|
| 331 |
-
raise HTTPException(
|
| 332 |
-
status_code=500,
|
| 333 |
-
detail=f"All models failed. Last error: {str(last_error)}\n\nPlease check:\n1. Token has correct permissions\n2. Token is valid and not expired\n3. Try regenerating token"
|
| 334 |
-
)
|
| 335 |
-
|
| 336 |
-
# Parse LLM response into structured format
|
| 337 |
-
parsed_result = parse_llm_response(llm_response, request.max_tags)
|
| 338 |
-
|
| 339 |
-
# Calculate confidence score (basic heuristic)
|
| 340 |
-
confidence = 0.0
|
| 341 |
-
if parsed_result["generated_tags"]:
|
| 342 |
-
confidence += 0.3
|
| 343 |
-
if parsed_result["primary_category"]:
|
| 344 |
-
confidence += 0.2
|
| 345 |
-
if parsed_result["keywords"]:
|
| 346 |
-
confidence += 0.2
|
| 347 |
-
if parsed_result["hashtags"]:
|
| 348 |
-
confidence += 0.15
|
| 349 |
-
if parsed_result["target_audience"]:
|
| 350 |
-
confidence += 0.15
|
| 351 |
-
|
| 352 |
-
end_time = datetime.utcnow()
|
| 353 |
-
generation_time = (end_time - start_time).total_seconds()
|
| 354 |
-
|
| 355 |
-
# Build response
|
| 356 |
-
return EventTagsResponse(
|
| 357 |
-
event_name=request.event_name,
|
| 358 |
-
generated_tags=parsed_result["generated_tags"],
|
| 359 |
-
primary_category=parsed_result["primary_category"],
|
| 360 |
-
secondary_categories=parsed_result["secondary_categories"],
|
| 361 |
-
keywords=parsed_result["keywords"],
|
| 362 |
-
hashtags=parsed_result["hashtags"],
|
| 363 |
-
target_audience=parsed_result["target_audience"],
|
| 364 |
-
sentiment=parsed_result["sentiment"],
|
| 365 |
-
confidence_score=round(confidence, 2),
|
| 366 |
-
generation_time=f"{generation_time:.2f}s",
|
| 367 |
-
model_used=model_used.split('/')[-1] if model_used else "unknown"
|
| 368 |
-
)
|
| 369 |
-
|
| 370 |
-
except HTTPException:
|
| 371 |
-
raise
|
| 372 |
-
except Exception as e:
|
| 373 |
-
raise HTTPException(
|
| 374 |
-
status_code=500,
|
| 375 |
-
detail=f"Error generating tags: {str(e)}"
|
| 376 |
-
)
|
| 377 |
|
|
|
|
| 378 |
|
| 379 |
-
|
| 380 |
-
async def generate_tags_batch(events: List[EventTagsRequest]):
|
| 381 |
-
"""
|
| 382 |
-
Batch generate tags for multiple events
|
| 383 |
-
|
| 384 |
-
Useful for bulk processing or migrating existing events
|
| 385 |
-
"""
|
| 386 |
-
results = []
|
| 387 |
-
|
| 388 |
-
for event in events:
|
| 389 |
-
try:
|
| 390 |
-
result = await generate_tags(event)
|
| 391 |
-
results.append({
|
| 392 |
-
"event_name": event.event_name,
|
| 393 |
-
"success": True,
|
| 394 |
-
"data": result
|
| 395 |
-
})
|
| 396 |
-
except Exception as e:
|
| 397 |
-
results.append({
|
| 398 |
-
"event_name": event.event_name,
|
| 399 |
-
"success": False,
|
| 400 |
-
"error": str(e)
|
| 401 |
-
})
|
| 402 |
-
|
| 403 |
-
return {
|
| 404 |
-
"total": len(events),
|
| 405 |
-
"successful": sum(1 for r in results if r["success"]),
|
| 406 |
-
"failed": sum(1 for r in results if not r["success"]),
|
| 407 |
-
"results": results
|
| 408 |
-
}
|
| 409 |
|
|
|
|
|
|
|
| 410 |
|
|
|
|
|
|
|
| 411 |
|
| 412 |
if __name__ == "__main__":
|
| 413 |
-
import
|
| 414 |
-
uvicorn.run(
|
| 415 |
-
"app:app",
|
| 416 |
-
host="0.0.0.0",
|
| 417 |
-
port=int(os.environ.get("PORT", 7860)),
|
| 418 |
-
reload=False,
|
| 419 |
-
log_level="info"
|
| 420 |
-
)
|
|
|
|
| 1 |
"""
|
| 2 |
+
Hugging Face Spaces compatible app
|
|
|
|
| 3 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
import os
|
| 5 |
+
import gradio as gr
|
| 6 |
+
from main import app as fastapi_app
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
+
# Gradio wrapper cho Hugging Face Spaces
|
| 9 |
+
def create_gradio_interface():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
"""
|
| 11 |
+
Tạo Gradio interface để deploy trên Hugging Face Spaces
|
| 12 |
"""
|
| 13 |
+
with gr.Blocks(title="Event Social Media Embeddings API") as demo:
|
| 14 |
+
gr.Markdown("""
|
| 15 |
+
# 🔍 Event Social Media Embeddings API
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
+
API để embeddings và search multimodal (text + images) với **Jina CLIP v2** + **Qdrant Cloud**
|
| 18 |
|
| 19 |
+
## 🌟 Features:
|
| 20 |
+
- ✅ Multimodal: Text + Image embeddings
|
| 21 |
+
- ✅ Tiếng Việt: 100% support
|
| 22 |
+
- ✅ High Performance: ONNX + HNSW
|
| 23 |
+
- ✅ Cloud: Qdrant Cloud
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
+
## 📡 API Endpoints:
|
| 26 |
+
- `POST /index` - Index data
|
| 27 |
+
- `POST /search` - Hybrid search
|
| 28 |
+
- `POST /search/text` - Text search
|
| 29 |
+
- `POST /search/image` - Image search
|
| 30 |
|
| 31 |
+
### 🔗 API Docs:
|
| 32 |
+
Truy cập `/docs` để xem API documentation đầy đủ
|
| 33 |
+
""")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
+
gr.Markdown("### API is running at the `/docs` endpoint")
|
| 36 |
|
| 37 |
+
return demo
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
+
# Mount FastAPI app
|
| 40 |
+
demo = create_gradio_interface()
|
| 41 |
|
| 42 |
+
# Wrap FastAPI với Gradio
|
| 43 |
+
app = gr.mount_gradio_app(fastapi_app, demo, path="/")
|
| 44 |
|
| 45 |
if __name__ == "__main__":
|
| 46 |
+
import uvicorn
|
| 47 |
+
uvicorn.run(app, host="0.0.0.0", port=7860)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
batch_index_pdfs.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Batch script to index PDF files into RAG knowledge base
|
| 3 |
+
Usage: python batch_index_pdfs.py <pdf_directory> [options]
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import sys
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from pymongo import MongoClient
|
| 10 |
+
from embedding_service import JinaClipEmbeddingService
|
| 11 |
+
from qdrant_service import QdrantVectorService
|
| 12 |
+
from pdf_parser import PDFIndexer
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def index_pdf_directory(
|
| 16 |
+
pdf_dir: str,
|
| 17 |
+
category: str = "user_guide",
|
| 18 |
+
force: bool = False
|
| 19 |
+
):
|
| 20 |
+
"""
|
| 21 |
+
Index all PDF files in a directory
|
| 22 |
+
|
| 23 |
+
Args:
|
| 24 |
+
pdf_dir: Directory containing PDF files
|
| 25 |
+
category: Category for the PDFs (default: "user_guide")
|
| 26 |
+
force: Force reindex even if already indexed (default: False)
|
| 27 |
+
"""
|
| 28 |
+
print("="*60)
|
| 29 |
+
print("PDF Batch Indexer")
|
| 30 |
+
print("="*60)
|
| 31 |
+
|
| 32 |
+
# Initialize services (same as main.py)
|
| 33 |
+
print("\n[1/5] Initializing services...")
|
| 34 |
+
embedding_service = JinaClipEmbeddingService(model_path="jinaai/jina-clip-v2")
|
| 35 |
+
|
| 36 |
+
collection_name = os.getenv("COLLECTION_NAME", "event_social_media")
|
| 37 |
+
qdrant_service = QdrantVectorService(
|
| 38 |
+
collection_name=collection_name,
|
| 39 |
+
vector_size=embedding_service.get_embedding_dimension()
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
# MongoDB
|
| 43 |
+
mongodb_uri = os.getenv("MONGODB_URI", "mongodb+srv://truongtn7122003:7KaI9OT5KTUxWjVI@truongtn7122003.xogin4q.mongodb.net/")
|
| 44 |
+
mongo_client = MongoClient(mongodb_uri)
|
| 45 |
+
db = mongo_client[os.getenv("MONGODB_DB_NAME", "chatbot_rag")]
|
| 46 |
+
documents_collection = db["documents"]
|
| 47 |
+
|
| 48 |
+
# Initialize PDF indexer
|
| 49 |
+
pdf_indexer = PDFIndexer(
|
| 50 |
+
embedding_service=embedding_service,
|
| 51 |
+
qdrant_service=qdrant_service,
|
| 52 |
+
documents_collection=documents_collection
|
| 53 |
+
)
|
| 54 |
+
print("✓ Services initialized")
|
| 55 |
+
|
| 56 |
+
# Find all PDF files
|
| 57 |
+
print(f"\n[2/5] Scanning directory: {pdf_dir}")
|
| 58 |
+
pdf_files = list(Path(pdf_dir).glob("*.pdf"))
|
| 59 |
+
|
| 60 |
+
if not pdf_files:
|
| 61 |
+
print("✗ No PDF files found in directory")
|
| 62 |
+
return
|
| 63 |
+
|
| 64 |
+
print(f"✓ Found {len(pdf_files)} PDF file(s)")
|
| 65 |
+
|
| 66 |
+
# Index each PDF
|
| 67 |
+
print(f"\n[3/5] Indexing PDFs...")
|
| 68 |
+
indexed_count = 0
|
| 69 |
+
skipped_count = 0
|
| 70 |
+
error_count = 0
|
| 71 |
+
|
| 72 |
+
for i, pdf_path in enumerate(pdf_files, 1):
|
| 73 |
+
print(f"\n--- [{i}/{len(pdf_files)}] Processing: {pdf_path.name} ---")
|
| 74 |
+
|
| 75 |
+
# Generate document ID
|
| 76 |
+
doc_id = f"pdf_{pdf_path.stem}"
|
| 77 |
+
|
| 78 |
+
# Check if already indexed
|
| 79 |
+
if not force:
|
| 80 |
+
existing = documents_collection.find_one({"document_id": doc_id})
|
| 81 |
+
if existing:
|
| 82 |
+
print(f"⊘ Already indexed (use --force to reindex)")
|
| 83 |
+
skipped_count += 1
|
| 84 |
+
continue
|
| 85 |
+
|
| 86 |
+
try:
|
| 87 |
+
# Index PDF
|
| 88 |
+
metadata = {
|
| 89 |
+
'title': pdf_path.stem.replace('_', ' ').title(),
|
| 90 |
+
'category': category,
|
| 91 |
+
'source_file': str(pdf_path)
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
result = pdf_indexer.index_pdf(
|
| 95 |
+
pdf_path=str(pdf_path),
|
| 96 |
+
document_id=doc_id,
|
| 97 |
+
document_metadata=metadata
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
print(f"✓ Indexed: {result['chunks_indexed']} chunks")
|
| 101 |
+
indexed_count += 1
|
| 102 |
+
|
| 103 |
+
except Exception as e:
|
| 104 |
+
print(f"✗ Error: {str(e)}")
|
| 105 |
+
error_count += 1
|
| 106 |
+
|
| 107 |
+
# Summary
|
| 108 |
+
print("\n" + "="*60)
|
| 109 |
+
print("SUMMARY")
|
| 110 |
+
print("="*60)
|
| 111 |
+
print(f"Total PDFs found: {len(pdf_files)}")
|
| 112 |
+
print(f"✓ Successfully indexed: {indexed_count}")
|
| 113 |
+
print(f"⊘ Skipped (already indexed): {skipped_count}")
|
| 114 |
+
print(f"✗ Errors: {error_count}")
|
| 115 |
+
|
| 116 |
+
if indexed_count > 0:
|
| 117 |
+
print(f"\n✓ Knowledge base updated successfully!")
|
| 118 |
+
print(f"You can now chat with your chatbot about the content in these PDFs.")
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def main():
|
| 122 |
+
"""Main entry point"""
|
| 123 |
+
if len(sys.argv) < 2:
|
| 124 |
+
print("Usage: python batch_index_pdfs.py <pdf_directory> [--category=<category>] [--force]")
|
| 125 |
+
print("\nExample:")
|
| 126 |
+
print(" python batch_index_pdfs.py ./docs/guides")
|
| 127 |
+
print(" python batch_index_pdfs.py ./docs/guides --category=user_guide --force")
|
| 128 |
+
sys.exit(1)
|
| 129 |
+
|
| 130 |
+
pdf_dir = sys.argv[1]
|
| 131 |
+
|
| 132 |
+
if not os.path.isdir(pdf_dir):
|
| 133 |
+
print(f"Error: Directory not found: {pdf_dir}")
|
| 134 |
+
sys.exit(1)
|
| 135 |
+
|
| 136 |
+
# Parse options
|
| 137 |
+
category = "user_guide"
|
| 138 |
+
force = False
|
| 139 |
+
|
| 140 |
+
for arg in sys.argv[2:]:
|
| 141 |
+
if arg.startswith("--category="):
|
| 142 |
+
category = arg.split("=")[1]
|
| 143 |
+
elif arg == "--force":
|
| 144 |
+
force = True
|
| 145 |
+
|
| 146 |
+
# Index PDFs
|
| 147 |
+
index_pdf_directory(pdf_dir, category=category, force=force)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
if __name__ == "__main__":
|
| 151 |
+
main()
|
cag_service.py
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
CAG Service (Cache-Augmented Generation)
|
| 3 |
+
Semantic caching layer for RAG system using Qdrant
|
| 4 |
+
|
| 5 |
+
This module implements intelligent caching to reduce latency and LLM costs
|
| 6 |
+
by serving semantically similar queries from cache.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from typing import Optional, Dict, Any, Tuple
|
| 10 |
+
from datetime import datetime, timedelta
|
| 11 |
+
import numpy as np
|
| 12 |
+
from qdrant_client import QdrantClient
|
| 13 |
+
from qdrant_client.models import (
|
| 14 |
+
Distance, VectorParams, PointStruct,
|
| 15 |
+
SearchParams, Filter, FieldCondition, MatchValue, Range
|
| 16 |
+
)
|
| 17 |
+
import uuid
|
| 18 |
+
import os
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class CAGService:
|
| 22 |
+
"""
|
| 23 |
+
Cache-Augmented Generation Service
|
| 24 |
+
|
| 25 |
+
Features:
|
| 26 |
+
- Semantic similarity-based cache lookup (cosine similarity)
|
| 27 |
+
- TTL (Time-To-Live) for automatic cache expiration
|
| 28 |
+
- Configurable similarity threshold
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
def __init__(
|
| 32 |
+
self,
|
| 33 |
+
embedding_service,
|
| 34 |
+
qdrant_url: Optional[str] = None,
|
| 35 |
+
qdrant_api_key: Optional[str] = None,
|
| 36 |
+
cache_collection: str = "semantic_cache",
|
| 37 |
+
vector_size: int = 1024,
|
| 38 |
+
similarity_threshold: float = 0.9,
|
| 39 |
+
ttl_hours: int = 24
|
| 40 |
+
):
|
| 41 |
+
"""
|
| 42 |
+
Initialize CAG Service
|
| 43 |
+
|
| 44 |
+
Args:
|
| 45 |
+
embedding_service: Embedding service for query encoding
|
| 46 |
+
qdrant_url: Qdrant Cloud URL
|
| 47 |
+
qdrant_api_key: Qdrant API key
|
| 48 |
+
cache_collection: Collection name for cache
|
| 49 |
+
vector_size: Embedding dimension
|
| 50 |
+
similarity_threshold: Min similarity for cache hit (0-1)
|
| 51 |
+
ttl_hours: Cache entry lifetime in hours
|
| 52 |
+
"""
|
| 53 |
+
self.embedding_service = embedding_service
|
| 54 |
+
self.cache_collection = cache_collection
|
| 55 |
+
self.similarity_threshold = similarity_threshold
|
| 56 |
+
self.ttl_hours = ttl_hours
|
| 57 |
+
|
| 58 |
+
# Initialize Qdrant client
|
| 59 |
+
url = qdrant_url or os.getenv("QDRANT_URL")
|
| 60 |
+
api_key = qdrant_api_key or os.getenv("QDRANT_API_KEY")
|
| 61 |
+
|
| 62 |
+
if not url or not api_key:
|
| 63 |
+
raise ValueError("QDRANT_URL and QDRANT_API_KEY required for CAG")
|
| 64 |
+
|
| 65 |
+
self.client = QdrantClient(url=url, api_key=api_key)
|
| 66 |
+
self.vector_size = vector_size
|
| 67 |
+
|
| 68 |
+
# Ensure cache collection exists
|
| 69 |
+
self._ensure_cache_collection()
|
| 70 |
+
|
| 71 |
+
print(f"✓ CAG Service initialized (cache: {cache_collection}, threshold: {similarity_threshold})")
|
| 72 |
+
|
| 73 |
+
def _ensure_cache_collection(self):
|
| 74 |
+
"""Create cache collection if it doesn't exist"""
|
| 75 |
+
collections = self.client.get_collections().collections
|
| 76 |
+
exists = any(c.name == self.cache_collection for c in collections)
|
| 77 |
+
|
| 78 |
+
if not exists:
|
| 79 |
+
print(f"Creating semantic cache collection: {self.cache_collection}")
|
| 80 |
+
self.client.create_collection(
|
| 81 |
+
collection_name=self.cache_collection,
|
| 82 |
+
vectors_config=VectorParams(
|
| 83 |
+
size=self.vector_size,
|
| 84 |
+
distance=Distance.COSINE
|
| 85 |
+
)
|
| 86 |
+
)
|
| 87 |
+
print("✓ Semantic cache collection created")
|
| 88 |
+
|
| 89 |
+
def check_cache(
|
| 90 |
+
self,
|
| 91 |
+
query: str
|
| 92 |
+
) -> Optional[Dict[str, Any]]:
|
| 93 |
+
"""
|
| 94 |
+
Check if query has a cached response
|
| 95 |
+
|
| 96 |
+
Args:
|
| 97 |
+
query: User query string
|
| 98 |
+
|
| 99 |
+
Returns:
|
| 100 |
+
Cached data if found (with response, context, metadata), None otherwise
|
| 101 |
+
"""
|
| 102 |
+
# Generate query embedding
|
| 103 |
+
query_embedding = self.embedding_service.encode_text(query)
|
| 104 |
+
|
| 105 |
+
if len(query_embedding.shape) > 1:
|
| 106 |
+
query_embedding = query_embedding.flatten()
|
| 107 |
+
|
| 108 |
+
# Search for similar queries in cache
|
| 109 |
+
search_result = self.client.query_points(
|
| 110 |
+
collection_name=self.cache_collection,
|
| 111 |
+
query=query_embedding.tolist(),
|
| 112 |
+
limit=1,
|
| 113 |
+
score_threshold=self.similarity_threshold,
|
| 114 |
+
with_payload=True
|
| 115 |
+
).points
|
| 116 |
+
|
| 117 |
+
if not search_result:
|
| 118 |
+
return None
|
| 119 |
+
|
| 120 |
+
hit = search_result[0]
|
| 121 |
+
|
| 122 |
+
# Check TTL
|
| 123 |
+
cached_at = datetime.fromisoformat(hit.payload.get("cached_at"))
|
| 124 |
+
expires_at = cached_at + timedelta(hours=self.ttl_hours)
|
| 125 |
+
|
| 126 |
+
if datetime.utcnow() > expires_at:
|
| 127 |
+
# Cache expired, delete it
|
| 128 |
+
self.client.delete(
|
| 129 |
+
collection_name=self.cache_collection,
|
| 130 |
+
points_selector=[hit.id]
|
| 131 |
+
)
|
| 132 |
+
return None
|
| 133 |
+
|
| 134 |
+
# Cache hit!
|
| 135 |
+
return {
|
| 136 |
+
"response": hit.payload.get("response"),
|
| 137 |
+
"context_used": hit.payload.get("context_used", []),
|
| 138 |
+
"rag_stats": hit.payload.get("rag_stats"),
|
| 139 |
+
"cached_query": hit.payload.get("original_query"),
|
| 140 |
+
"similarity_score": float(hit.score),
|
| 141 |
+
"cached_at": cached_at.isoformat(),
|
| 142 |
+
"cache_hit": True
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
def save_to_cache(
|
| 146 |
+
self,
|
| 147 |
+
query: str,
|
| 148 |
+
response: str,
|
| 149 |
+
context_used: list,
|
| 150 |
+
rag_stats: Optional[Dict] = None
|
| 151 |
+
) -> str:
|
| 152 |
+
"""
|
| 153 |
+
Save query-response pair to cache
|
| 154 |
+
|
| 155 |
+
Args:
|
| 156 |
+
query: Original user query
|
| 157 |
+
response: Generated response
|
| 158 |
+
context_used: Retrieved context documents
|
| 159 |
+
rag_stats: RAG pipeline statistics
|
| 160 |
+
|
| 161 |
+
Returns:
|
| 162 |
+
Cache entry ID
|
| 163 |
+
"""
|
| 164 |
+
# Generate query embedding
|
| 165 |
+
query_embedding = self.embedding_service.encode_text(query)
|
| 166 |
+
|
| 167 |
+
if len(query_embedding.shape) > 1:
|
| 168 |
+
query_embedding = query_embedding.flatten()
|
| 169 |
+
|
| 170 |
+
# Create cache entry
|
| 171 |
+
cache_id = str(uuid.uuid4())
|
| 172 |
+
|
| 173 |
+
point = PointStruct(
|
| 174 |
+
id=cache_id,
|
| 175 |
+
vector=query_embedding.tolist(),
|
| 176 |
+
payload={
|
| 177 |
+
"original_query": query,
|
| 178 |
+
"response": response,
|
| 179 |
+
"context_used": context_used,
|
| 180 |
+
"rag_stats": rag_stats or {},
|
| 181 |
+
"cached_at": datetime.utcnow().isoformat(),
|
| 182 |
+
"cache_type": "semantic"
|
| 183 |
+
}
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
# Save to Qdrant
|
| 187 |
+
self.client.upsert(
|
| 188 |
+
collection_name=self.cache_collection,
|
| 189 |
+
points=[point]
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
return cache_id
|
| 193 |
+
|
| 194 |
+
def clear_cache(self) -> bool:
|
| 195 |
+
"""
|
| 196 |
+
Clear all cache entries
|
| 197 |
+
|
| 198 |
+
Returns:
|
| 199 |
+
Success status
|
| 200 |
+
"""
|
| 201 |
+
try:
|
| 202 |
+
# Delete and recreate collection
|
| 203 |
+
self.client.delete_collection(collection_name=self.cache_collection)
|
| 204 |
+
self._ensure_cache_collection()
|
| 205 |
+
print("✓ Semantic cache cleared")
|
| 206 |
+
return True
|
| 207 |
+
except Exception as e:
|
| 208 |
+
print(f"Error clearing cache: {e}")
|
| 209 |
+
return False
|
| 210 |
+
|
| 211 |
+
def get_cache_stats(self) -> Dict[str, Any]:
|
| 212 |
+
"""
|
| 213 |
+
Get cache statistics
|
| 214 |
+
|
| 215 |
+
Returns:
|
| 216 |
+
Cache statistics (size, hit rate, etc.)
|
| 217 |
+
"""
|
| 218 |
+
try:
|
| 219 |
+
info = self.client.get_collection(collection_name=self.cache_collection)
|
| 220 |
+
return {
|
| 221 |
+
"total_entries": info.points_count,
|
| 222 |
+
"vectors_count": info.vectors_count,
|
| 223 |
+
"status": info.status,
|
| 224 |
+
"ttl_hours": self.ttl_hours,
|
| 225 |
+
"similarity_threshold": self.similarity_threshold
|
| 226 |
+
}
|
| 227 |
+
except Exception as e:
|
| 228 |
+
print(f"Error getting cache stats: {e}")
|
| 229 |
+
return {}
|
conversation_service.py
ADDED
|
@@ -0,0 +1,308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Conversation Service for Multi-turn Chat
|
| 3 |
+
Server-side session management
|
| 4 |
+
"""
|
| 5 |
+
from typing import List, Dict, Optional
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
from pymongo.collection import Collection
|
| 8 |
+
import uuid
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class ConversationService:
|
| 12 |
+
"""
|
| 13 |
+
Manages multi-turn conversation history với server-side session
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
def __init__(self, mongo_collection: Collection, max_history: int = 10):
|
| 17 |
+
"""
|
| 18 |
+
Args:
|
| 19 |
+
mongo_collection: MongoDB collection for storing conversations
|
| 20 |
+
max_history: Maximum số messages giữ lại (sliding window)
|
| 21 |
+
"""
|
| 22 |
+
self.collection = mongo_collection
|
| 23 |
+
self.max_history = max_history
|
| 24 |
+
|
| 25 |
+
# Create indexes
|
| 26 |
+
self._ensure_indexes()
|
| 27 |
+
|
| 28 |
+
def _ensure_indexes(self):
|
| 29 |
+
"""Create necessary indexes"""
|
| 30 |
+
try:
|
| 31 |
+
self.collection.create_index("session_id", unique=True)
|
| 32 |
+
self.collection.create_index("user_id") # NEW: Index for user filtering
|
| 33 |
+
# Auto-delete sessions sau 7 ngày không dùng
|
| 34 |
+
self.collection.create_index(
|
| 35 |
+
"updated_at",
|
| 36 |
+
expireAfterSeconds=604800 # 7 days
|
| 37 |
+
)
|
| 38 |
+
print("✓ Conversation indexes created")
|
| 39 |
+
except Exception as e:
|
| 40 |
+
print(f"Conversation indexes already exist or error: {e}")
|
| 41 |
+
|
| 42 |
+
def create_session(self, metadata: Optional[Dict] = None, user_id: Optional[str] = None) -> str:
|
| 43 |
+
"""
|
| 44 |
+
Create new conversation session
|
| 45 |
+
|
| 46 |
+
Args:
|
| 47 |
+
metadata: Additional metadata
|
| 48 |
+
user_id: User identifier (optional)
|
| 49 |
+
|
| 50 |
+
Returns:
|
| 51 |
+
session_id (UUID string)
|
| 52 |
+
"""
|
| 53 |
+
session_id = str(uuid.uuid4())
|
| 54 |
+
|
| 55 |
+
self.collection.insert_one({
|
| 56 |
+
"session_id": session_id,
|
| 57 |
+
"user_id": user_id, # NEW: Store user_id
|
| 58 |
+
"messages": [],
|
| 59 |
+
"scenario_state": None, # NEW: Scenario state
|
| 60 |
+
"metadata": metadata or {},
|
| 61 |
+
"created_at": datetime.utcnow(),
|
| 62 |
+
"updated_at": datetime.utcnow()
|
| 63 |
+
})
|
| 64 |
+
|
| 65 |
+
return session_id
|
| 66 |
+
|
| 67 |
+
def add_message(
|
| 68 |
+
self,
|
| 69 |
+
session_id: str,
|
| 70 |
+
role: str,
|
| 71 |
+
content: str,
|
| 72 |
+
metadata: Optional[Dict] = None
|
| 73 |
+
):
|
| 74 |
+
"""
|
| 75 |
+
Add message to conversation history
|
| 76 |
+
|
| 77 |
+
Args:
|
| 78 |
+
session_id: Session identifier
|
| 79 |
+
role: "user" or "assistant"
|
| 80 |
+
content: Message text
|
| 81 |
+
metadata: Additional info (rag_stats, tool_calls, etc.)
|
| 82 |
+
"""
|
| 83 |
+
message = {
|
| 84 |
+
"role": role,
|
| 85 |
+
"content": content,
|
| 86 |
+
"timestamp": datetime.utcnow().isoformat(),
|
| 87 |
+
"metadata": metadata or {}
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
# Upsert: tạo session nếu chưa tồn tại
|
| 91 |
+
self.collection.update_one(
|
| 92 |
+
{"session_id": session_id},
|
| 93 |
+
{
|
| 94 |
+
"$push": {
|
| 95 |
+
"messages": {
|
| 96 |
+
"$each": [message],
|
| 97 |
+
"$slice": -self.max_history # Keep only last N messages
|
| 98 |
+
}
|
| 99 |
+
},
|
| 100 |
+
"$set": {"updated_at": datetime.utcnow()}
|
| 101 |
+
},
|
| 102 |
+
upsert=True
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
def get_conversation_history(
|
| 106 |
+
self,
|
| 107 |
+
session_id: str,
|
| 108 |
+
limit: Optional[int] = None,
|
| 109 |
+
include_metadata: bool = False
|
| 110 |
+
) -> List[Dict]:
|
| 111 |
+
"""
|
| 112 |
+
Get conversation messages for LLM context
|
| 113 |
+
|
| 114 |
+
Args:
|
| 115 |
+
session_id: Session identifier
|
| 116 |
+
limit: Override max_history với số lượng tùy chỉnh
|
| 117 |
+
include_metadata: Include metadata trong response
|
| 118 |
+
|
| 119 |
+
Returns:
|
| 120 |
+
List of messages in format: [{"role": "user", "content": "..."}, ...]
|
| 121 |
+
"""
|
| 122 |
+
session = self.collection.find_one({"session_id": session_id})
|
| 123 |
+
|
| 124 |
+
if not session:
|
| 125 |
+
return []
|
| 126 |
+
|
| 127 |
+
messages = session.get("messages", [])
|
| 128 |
+
|
| 129 |
+
# Limit to recent messages
|
| 130 |
+
if limit:
|
| 131 |
+
messages = messages[-limit:]
|
| 132 |
+
else:
|
| 133 |
+
messages = messages[-self.max_history:]
|
| 134 |
+
|
| 135 |
+
# Format for LLM
|
| 136 |
+
if include_metadata:
|
| 137 |
+
return messages
|
| 138 |
+
else:
|
| 139 |
+
return [
|
| 140 |
+
{
|
| 141 |
+
"role": msg["role"],
|
| 142 |
+
"content": msg["content"]
|
| 143 |
+
}
|
| 144 |
+
for msg in messages
|
| 145 |
+
]
|
| 146 |
+
|
| 147 |
+
def get_session_info(self, session_id: str) -> Optional[Dict]:
|
| 148 |
+
"""
|
| 149 |
+
Get session metadata
|
| 150 |
+
|
| 151 |
+
Returns:
|
| 152 |
+
Session info hoặc None nếu không tồn tại
|
| 153 |
+
"""
|
| 154 |
+
session = self.collection.find_one(
|
| 155 |
+
{"session_id": session_id},
|
| 156 |
+
{"_id": 0, "session_id": 1, "user_id": 1, "created_at": 1, "updated_at": 1, "metadata": 1}
|
| 157 |
+
)
|
| 158 |
+
return session
|
| 159 |
+
|
| 160 |
+
def clear_session(self, session_id: str) -> bool:
|
| 161 |
+
"""
|
| 162 |
+
Clear conversation history for session
|
| 163 |
+
|
| 164 |
+
Returns:
|
| 165 |
+
True nếu xóa thành công, False nếu session không tồn tại
|
| 166 |
+
"""
|
| 167 |
+
result = self.collection.delete_one({"session_id": session_id})
|
| 168 |
+
return result.deleted_count > 0
|
| 169 |
+
|
| 170 |
+
def session_exists(self, session_id: str) -> bool:
|
| 171 |
+
"""
|
| 172 |
+
Check if session exists
|
| 173 |
+
"""
|
| 174 |
+
return self.collection.count_documents({"session_id": session_id}) > 0
|
| 175 |
+
|
| 176 |
+
def get_last_user_message(self, session_id: str) -> Optional[str]:
|
| 177 |
+
"""
|
| 178 |
+
Get the last user message in conversation
|
| 179 |
+
Useful for context extraction
|
| 180 |
+
"""
|
| 181 |
+
session = self.collection.find_one({"session_id": session_id})
|
| 182 |
+
if not session:
|
| 183 |
+
return None
|
| 184 |
+
|
| 185 |
+
messages = session.get("messages", [])
|
| 186 |
+
# Tìm message cuối cùng từ user
|
| 187 |
+
for msg in reversed(messages):
|
| 188 |
+
if msg["role"] == "user":
|
| 189 |
+
return msg["content"]
|
| 190 |
+
|
| 191 |
+
return None
|
| 192 |
+
|
| 193 |
+
def list_sessions(
|
| 194 |
+
self,
|
| 195 |
+
limit: int = 50,
|
| 196 |
+
skip: int = 0,
|
| 197 |
+
sort_by: str = "updated_at",
|
| 198 |
+
descending: bool = True,
|
| 199 |
+
user_id: Optional[str] = None # NEW: Filter by user
|
| 200 |
+
) -> List[Dict]:
|
| 201 |
+
"""
|
| 202 |
+
List all conversation sessions
|
| 203 |
+
|
| 204 |
+
Args:
|
| 205 |
+
limit: Maximum number of sessions to return
|
| 206 |
+
skip: Number of sessions to skip (for pagination)
|
| 207 |
+
sort_by: Field to sort by (created_at, updated_at)
|
| 208 |
+
descending: Sort in descending order
|
| 209 |
+
user_id: Filter sessions by user_id (optional)
|
| 210 |
+
|
| 211 |
+
Returns:
|
| 212 |
+
List of session summaries
|
| 213 |
+
"""
|
| 214 |
+
sort_order = -1 if descending else 1
|
| 215 |
+
|
| 216 |
+
# Build query filter
|
| 217 |
+
query = {}
|
| 218 |
+
if user_id:
|
| 219 |
+
query["user_id"] = user_id
|
| 220 |
+
|
| 221 |
+
sessions = self.collection.find(
|
| 222 |
+
query, # Use query filter
|
| 223 |
+
{"_id": 0, "session_id": 1, "user_id": 1, "created_at": 1, "updated_at": 1, "metadata": 1}
|
| 224 |
+
).sort(sort_by, sort_order).skip(skip).limit(limit)
|
| 225 |
+
|
| 226 |
+
result = []
|
| 227 |
+
for session in sessions:
|
| 228 |
+
# Count messages
|
| 229 |
+
message_count = len(
|
| 230 |
+
self.collection.find_one({"session_id": session["session_id"]}, {"messages": 1})
|
| 231 |
+
.get("messages", [])
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
result.append({
|
| 235 |
+
"session_id": session["session_id"],
|
| 236 |
+
"user_id": session.get("user_id"), # NEW: Include user_id
|
| 237 |
+
"created_at": session["created_at"],
|
| 238 |
+
"updated_at": session["updated_at"],
|
| 239 |
+
"message_count": message_count,
|
| 240 |
+
"metadata": session.get("metadata", {})
|
| 241 |
+
})
|
| 242 |
+
|
| 243 |
+
return result
|
| 244 |
+
|
| 245 |
+
def count_sessions(self, user_id: Optional[str] = None) -> int:
|
| 246 |
+
"""
|
| 247 |
+
Get total number of sessions
|
| 248 |
+
|
| 249 |
+
Args:
|
| 250 |
+
user_id: Filter count by user_id (optional)
|
| 251 |
+
"""
|
| 252 |
+
query = {}
|
| 253 |
+
if user_id:
|
| 254 |
+
query["user_id"] = user_id
|
| 255 |
+
return self.collection.count_documents(query)
|
| 256 |
+
|
| 257 |
+
# ===== Scenario State Management =====
|
| 258 |
+
|
| 259 |
+
def get_scenario_state(self, session_id: str) -> Optional[Dict]:
|
| 260 |
+
"""
|
| 261 |
+
Get current scenario state for session
|
| 262 |
+
|
| 263 |
+
Returns:
|
| 264 |
+
{
|
| 265 |
+
"active_scenario": "price_inquiry",
|
| 266 |
+
"scenario_step": 3,
|
| 267 |
+
"scenario_data": {...},
|
| 268 |
+
"last_activity": "..."
|
| 269 |
+
}
|
| 270 |
+
or None if no active scenario
|
| 271 |
+
"""
|
| 272 |
+
session = self.collection.find_one({"session_id": session_id})
|
| 273 |
+
if not session:
|
| 274 |
+
return None
|
| 275 |
+
return session.get("scenario_state")
|
| 276 |
+
|
| 277 |
+
def set_scenario_state(self, session_id: str, state: Dict):
|
| 278 |
+
"""
|
| 279 |
+
Set scenario state for session
|
| 280 |
+
|
| 281 |
+
Args:
|
| 282 |
+
session_id: Session ID
|
| 283 |
+
state: Scenario state dict
|
| 284 |
+
"""
|
| 285 |
+
self.collection.update_one(
|
| 286 |
+
{"session_id": session_id},
|
| 287 |
+
{
|
| 288 |
+
"$set": {
|
| 289 |
+
"scenario_state": state,
|
| 290 |
+
"updated_at": datetime.utcnow()
|
| 291 |
+
}
|
| 292 |
+
},
|
| 293 |
+
upsert=True
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
def clear_scenario(self, session_id: str):
|
| 297 |
+
"""
|
| 298 |
+
Clear scenario state (end scenario)
|
| 299 |
+
"""
|
| 300 |
+
self.collection.update_one(
|
| 301 |
+
{"session_id": session_id},
|
| 302 |
+
{
|
| 303 |
+
"$set": {
|
| 304 |
+
"scenario_state": None,
|
| 305 |
+
"updated_at": datetime.utcnow()
|
| 306 |
+
}
|
| 307 |
+
}
|
| 308 |
+
)
|
embedding_service.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
from PIL import Image
|
| 4 |
+
from transformers import AutoModel
|
| 5 |
+
from typing import Union, List
|
| 6 |
+
import io
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class JinaClipEmbeddingService:
|
| 10 |
+
"""
|
| 11 |
+
Jina CLIP v2 Embedding Service với hỗ trợ tiếng Việt
|
| 12 |
+
Sử dụng AutoModel với trust_remote_code
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
def __init__(self, model_path: str = "jinaai/jina-clip-v2"):
|
| 16 |
+
"""
|
| 17 |
+
Initialize Jina CLIP v2 model
|
| 18 |
+
|
| 19 |
+
Args:
|
| 20 |
+
model_path: Path to model hoặc HuggingFace model name
|
| 21 |
+
"""
|
| 22 |
+
print(f"Loading Jina CLIP v2 model from {model_path}...")
|
| 23 |
+
|
| 24 |
+
# Load model với trust_remote_code
|
| 25 |
+
self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
|
| 26 |
+
|
| 27 |
+
# Chuyển sang eval mode
|
| 28 |
+
self.model.eval()
|
| 29 |
+
|
| 30 |
+
# Sử dụng GPU nếu có
|
| 31 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 32 |
+
self.model.to(self.device)
|
| 33 |
+
|
| 34 |
+
print(f"✓ Loaded Jina CLIP v2 model on: {self.device}")
|
| 35 |
+
|
| 36 |
+
def encode_text(
|
| 37 |
+
self,
|
| 38 |
+
text: Union[str, List[str]],
|
| 39 |
+
truncate_dim: int = None,
|
| 40 |
+
normalize: bool = True
|
| 41 |
+
) -> np.ndarray:
|
| 42 |
+
"""
|
| 43 |
+
Encode text thành vector embeddings (hỗ trợ tiếng Việt)
|
| 44 |
+
|
| 45 |
+
Args:
|
| 46 |
+
text: Text hoặc list of texts (tiếng Việt)
|
| 47 |
+
truncate_dim: Matryoshka dimension (64-1024, None = full 1024)
|
| 48 |
+
normalize: Có normalize embeddings không
|
| 49 |
+
|
| 50 |
+
Returns:
|
| 51 |
+
numpy array của embeddings
|
| 52 |
+
"""
|
| 53 |
+
if isinstance(text, str):
|
| 54 |
+
text = [text]
|
| 55 |
+
|
| 56 |
+
# Jina CLIP v2 encode_text method
|
| 57 |
+
# Automatically handles tokenization internally
|
| 58 |
+
embeddings = self.model.encode_text(
|
| 59 |
+
text,
|
| 60 |
+
truncate_dim=truncate_dim # Optional: 64, 128, 256, 512, 1024
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
# Convert to numpy
|
| 64 |
+
if isinstance(embeddings, torch.Tensor):
|
| 65 |
+
embeddings = embeddings.cpu().detach().numpy()
|
| 66 |
+
|
| 67 |
+
# Normalize nếu cần
|
| 68 |
+
if normalize:
|
| 69 |
+
embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
|
| 70 |
+
|
| 71 |
+
return embeddings
|
| 72 |
+
|
| 73 |
+
def encode_image(
|
| 74 |
+
self,
|
| 75 |
+
image: Union[Image.Image, bytes, List, str],
|
| 76 |
+
truncate_dim: int = None,
|
| 77 |
+
normalize: bool = True
|
| 78 |
+
) -> np.ndarray:
|
| 79 |
+
"""
|
| 80 |
+
Encode image thành vector embeddings
|
| 81 |
+
|
| 82 |
+
Args:
|
| 83 |
+
image: PIL Image, bytes, URL string, hoặc list of images
|
| 84 |
+
truncate_dim: Matryoshka dimension (64-1024, None = full 1024)
|
| 85 |
+
normalize: Có normalize embeddings không
|
| 86 |
+
|
| 87 |
+
Returns:
|
| 88 |
+
numpy array của embeddings
|
| 89 |
+
"""
|
| 90 |
+
# Convert bytes to PIL Image nếu cần
|
| 91 |
+
if isinstance(image, bytes):
|
| 92 |
+
image = Image.open(io.BytesIO(image)).convert('RGB')
|
| 93 |
+
elif isinstance(image, list):
|
| 94 |
+
processed_images = []
|
| 95 |
+
for img in image:
|
| 96 |
+
if isinstance(img, bytes):
|
| 97 |
+
processed_images.append(Image.open(io.BytesIO(img)).convert('RGB'))
|
| 98 |
+
elif isinstance(img, str):
|
| 99 |
+
# URL string - keep as is, Jina CLIP can handle URLs
|
| 100 |
+
processed_images.append(img)
|
| 101 |
+
else:
|
| 102 |
+
processed_images.append(img)
|
| 103 |
+
image = processed_images
|
| 104 |
+
elif not isinstance(image, list) and not isinstance(image, str):
|
| 105 |
+
# Single PIL Image
|
| 106 |
+
image = [image]
|
| 107 |
+
|
| 108 |
+
# Jina CLIP v2 encode_image method
|
| 109 |
+
# Supports PIL Images, file paths, or URLs
|
| 110 |
+
embeddings = self.model.encode_image(
|
| 111 |
+
image,
|
| 112 |
+
truncate_dim=truncate_dim # Optional: 64, 128, 256, 512, 1024
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
# Convert to numpy
|
| 116 |
+
if isinstance(embeddings, torch.Tensor):
|
| 117 |
+
embeddings = embeddings.cpu().detach().numpy()
|
| 118 |
+
|
| 119 |
+
# Normalize nếu cần
|
| 120 |
+
if normalize:
|
| 121 |
+
embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
|
| 122 |
+
|
| 123 |
+
return embeddings
|
| 124 |
+
|
| 125 |
+
def encode_multimodal(
|
| 126 |
+
self,
|
| 127 |
+
text: Union[str, List[str]] = None,
|
| 128 |
+
image: Union[Image.Image, bytes, List] = None,
|
| 129 |
+
truncate_dim: int = None,
|
| 130 |
+
normalize: bool = True
|
| 131 |
+
) -> np.ndarray:
|
| 132 |
+
"""
|
| 133 |
+
Encode cả text và image, trả về embeddings kết hợp
|
| 134 |
+
|
| 135 |
+
Args:
|
| 136 |
+
text: Text hoặc list of texts (tiếng Việt)
|
| 137 |
+
image: PIL Image, bytes, hoặc list of images
|
| 138 |
+
truncate_dim: Matryoshka dimension (64-1024, None = full 1024)
|
| 139 |
+
normalize: Có normalize embeddings không
|
| 140 |
+
|
| 141 |
+
Returns:
|
| 142 |
+
numpy array của embeddings
|
| 143 |
+
"""
|
| 144 |
+
embeddings = []
|
| 145 |
+
|
| 146 |
+
if text is not None:
|
| 147 |
+
text_emb = self.encode_text(text, truncate_dim=truncate_dim, normalize=False)
|
| 148 |
+
embeddings.append(text_emb)
|
| 149 |
+
|
| 150 |
+
if image is not None:
|
| 151 |
+
image_emb = self.encode_image(image, truncate_dim=truncate_dim, normalize=False)
|
| 152 |
+
embeddings.append(image_emb)
|
| 153 |
+
|
| 154 |
+
# Combine embeddings (average)
|
| 155 |
+
if len(embeddings) == 2:
|
| 156 |
+
# Average của text và image embeddings
|
| 157 |
+
combined = np.mean(embeddings, axis=0)
|
| 158 |
+
elif len(embeddings) == 1:
|
| 159 |
+
combined = embeddings[0]
|
| 160 |
+
else:
|
| 161 |
+
raise ValueError("Phải cung cấp ít nhất text hoặc image")
|
| 162 |
+
|
| 163 |
+
# Normalize nếu cần
|
| 164 |
+
if normalize:
|
| 165 |
+
combined = combined / np.linalg.norm(combined, axis=1, keepdims=True)
|
| 166 |
+
|
| 167 |
+
return combined
|
| 168 |
+
|
| 169 |
+
def get_embedding_dimension(self) -> int:
|
| 170 |
+
"""
|
| 171 |
+
Trả về dimension của embeddings (1024 cho Jina CLIP v2)
|
| 172 |
+
"""
|
| 173 |
+
return 1024
|
feedback_tracking_service.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Feedback Tracking Service
|
| 3 |
+
Tracks which events users have already given feedback for
|
| 4 |
+
"""
|
| 5 |
+
from typing import Optional, Dict
|
| 6 |
+
from pymongo.collection import Collection
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class FeedbackTrackingService:
|
| 11 |
+
"""
|
| 12 |
+
Track feedback status per user per event
|
| 13 |
+
Prevents redundant "check purchase history" calls
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
def __init__(self, mongo_collection: Collection):
|
| 17 |
+
self.collection = mongo_collection
|
| 18 |
+
self._ensure_indexes()
|
| 19 |
+
|
| 20 |
+
def _ensure_indexes(self):
|
| 21 |
+
"""Create indexes for fast lookup"""
|
| 22 |
+
try:
|
| 23 |
+
# Compound index for quick lookup
|
| 24 |
+
self.collection.create_index([("user_id", 1), ("event_code", 1)], unique=True)
|
| 25 |
+
self.collection.create_index("user_id")
|
| 26 |
+
print("✓ Feedback tracking indexes created")
|
| 27 |
+
except Exception as e:
|
| 28 |
+
print(f"Feedback tracking indexes exist: {e}")
|
| 29 |
+
|
| 30 |
+
def has_given_feedback(self, user_id: str, event_code: str) -> bool:
|
| 31 |
+
"""
|
| 32 |
+
Check if user has already given feedback for this event
|
| 33 |
+
|
| 34 |
+
Args:
|
| 35 |
+
user_id: User ID
|
| 36 |
+
event_code: Event code
|
| 37 |
+
|
| 38 |
+
Returns:
|
| 39 |
+
True if feedback already given, False otherwise
|
| 40 |
+
"""
|
| 41 |
+
result = self.collection.find_one({
|
| 42 |
+
"user_id": user_id,
|
| 43 |
+
"event_code": event_code,
|
| 44 |
+
"is_feedback": True
|
| 45 |
+
})
|
| 46 |
+
return result is not None
|
| 47 |
+
|
| 48 |
+
def mark_feedback_given(self, user_id: str, event_code: str, rating: int, comment: str = "") -> bool:
|
| 49 |
+
"""
|
| 50 |
+
Mark that user has given feedback for this event
|
| 51 |
+
|
| 52 |
+
Args:
|
| 53 |
+
user_id: User ID
|
| 54 |
+
event_code: Event code
|
| 55 |
+
rating: Rating given (1-5)
|
| 56 |
+
comment: Feedback comment
|
| 57 |
+
|
| 58 |
+
Returns:
|
| 59 |
+
True if saved successfully
|
| 60 |
+
"""
|
| 61 |
+
try:
|
| 62 |
+
self.collection.update_one(
|
| 63 |
+
{
|
| 64 |
+
"user_id": user_id,
|
| 65 |
+
"event_code": event_code
|
| 66 |
+
},
|
| 67 |
+
{
|
| 68 |
+
"$set": {
|
| 69 |
+
"is_feedback": True,
|
| 70 |
+
"rating": rating,
|
| 71 |
+
"comment": comment,
|
| 72 |
+
"feedback_date": datetime.utcnow(),
|
| 73 |
+
"updated_at": datetime.utcnow()
|
| 74 |
+
},
|
| 75 |
+
"$setOnInsert": {
|
| 76 |
+
"created_at": datetime.utcnow()
|
| 77 |
+
}
|
| 78 |
+
},
|
| 79 |
+
upsert=True
|
| 80 |
+
)
|
| 81 |
+
print(f"✅ Marked feedback: {user_id} → {event_code} (rating: {rating})")
|
| 82 |
+
return True
|
| 83 |
+
except Exception as e:
|
| 84 |
+
print(f"❌ Error marking feedback: {e}")
|
| 85 |
+
return False
|
| 86 |
+
|
| 87 |
+
def get_pending_events(self, user_id: str, purchased_events: list) -> list:
|
| 88 |
+
"""
|
| 89 |
+
Filter purchased events to only those without feedback
|
| 90 |
+
|
| 91 |
+
Args:
|
| 92 |
+
user_id: User ID
|
| 93 |
+
purchased_events: List of events user has purchased
|
| 94 |
+
|
| 95 |
+
Returns:
|
| 96 |
+
List of events that need feedback
|
| 97 |
+
"""
|
| 98 |
+
pending = []
|
| 99 |
+
for event in purchased_events:
|
| 100 |
+
event_code = event.get("eventCode")
|
| 101 |
+
if event_code and not self.has_given_feedback(user_id, event_code):
|
| 102 |
+
pending.append(event)
|
| 103 |
+
return pending
|
main.py
ADDED
|
@@ -0,0 +1,1326 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
|
| 2 |
+
from fastapi.responses import JSONResponse, StreamingResponse # Add StreamingResponse
|
| 3 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 4 |
+
from pydantic import BaseModel
|
| 5 |
+
from typing import Optional, List, Dict
|
| 6 |
+
from PIL import Image
|
| 7 |
+
import io
|
| 8 |
+
import numpy as np
|
| 9 |
+
import os
|
| 10 |
+
from datetime import datetime
|
| 11 |
+
from pymongo import MongoClient
|
| 12 |
+
from huggingface_hub import InferenceClient
|
| 13 |
+
|
| 14 |
+
from embedding_service import JinaClipEmbeddingService
|
| 15 |
+
from qdrant_service import QdrantVectorService
|
| 16 |
+
from advanced_rag import AdvancedRAG
|
| 17 |
+
from cag_service import CAGService
|
| 18 |
+
from pdf_parser import PDFIndexer
|
| 19 |
+
from multimodal_pdf_parser import MultimodalPDFIndexer
|
| 20 |
+
from conversation_service import ConversationService
|
| 21 |
+
from tools_service import ToolsService
|
| 22 |
+
from agent_service import AgentService
|
| 23 |
+
from agent_chat_stream import agent_chat_stream # NEW: Agent Streaming
|
| 24 |
+
from feedback_tracking_service import FeedbackTrackingService # NEW: Feedback tracking
|
| 25 |
+
|
| 26 |
+
# Initialize FastAPI app
|
| 27 |
+
app = FastAPI(
|
| 28 |
+
title="Event Social Media Embeddings & ChatbotRAG API",
|
| 29 |
+
description="API để embeddings, search và ChatbotRAG với Jina CLIP v2 + Qdrant + MongoDB + LLM",
|
| 30 |
+
version="2.0.0"
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
# CORS middleware
|
| 34 |
+
app.add_middleware(
|
| 35 |
+
CORSMiddleware,
|
| 36 |
+
allow_origins=["*"],
|
| 37 |
+
allow_credentials=True,
|
| 38 |
+
allow_methods=["*"],
|
| 39 |
+
allow_headers=["*"],
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
# Initialize services
|
| 43 |
+
print("Initializing services...")
|
| 44 |
+
embedding_service = JinaClipEmbeddingService(model_path="jinaai/jina-clip-v2")
|
| 45 |
+
|
| 46 |
+
collection_name = os.getenv("COLLECTION_NAME", "event_social_media")
|
| 47 |
+
qdrant_service = QdrantVectorService(
|
| 48 |
+
collection_name=collection_name,
|
| 49 |
+
vector_size=embedding_service.get_embedding_dimension()
|
| 50 |
+
)
|
| 51 |
+
print(f"✓ Qdrant collection: {collection_name}")
|
| 52 |
+
|
| 53 |
+
# MongoDB connection
|
| 54 |
+
mongodb_uri = os.getenv("MONGODB_URI", "mongodb+srv://truongtn7122003:7KaI9OT5KTUxWjVI@truongtn7122003.xogin4q.mongodb.net/")
|
| 55 |
+
mongo_client = MongoClient(mongodb_uri)
|
| 56 |
+
db = mongo_client[os.getenv("MONGODB_DB_NAME", "chatbot_rag")]
|
| 57 |
+
documents_collection = db["documents"]
|
| 58 |
+
chat_history_collection = db["chat_history"]
|
| 59 |
+
print("✓ MongoDB connected")
|
| 60 |
+
|
| 61 |
+
# Hugging Face token
|
| 62 |
+
hf_token = os.getenv("HUGGINGFACE_TOKEN")
|
| 63 |
+
if hf_token:
|
| 64 |
+
print("✓ Hugging Face token configured")
|
| 65 |
+
|
| 66 |
+
# Initialize Advanced RAG (Best Case 2025)
|
| 67 |
+
advanced_rag = AdvancedRAG(
|
| 68 |
+
embedding_service=embedding_service,
|
| 69 |
+
qdrant_service=qdrant_service
|
| 70 |
+
)
|
| 71 |
+
print("✓ Advanced RAG pipeline initialized (with Cross-Encoder)")
|
| 72 |
+
|
| 73 |
+
# Initialize CAG Service (Semantic Cache)
|
| 74 |
+
try:
|
| 75 |
+
cag_service = CAGService(
|
| 76 |
+
embedding_service=embedding_service,
|
| 77 |
+
cache_collection="semantic_cache",
|
| 78 |
+
vector_size=embedding_service.get_embedding_dimension(),
|
| 79 |
+
similarity_threshold=0.9,
|
| 80 |
+
ttl_hours=24
|
| 81 |
+
)
|
| 82 |
+
print("✓ CAG Service initialized (Semantic Caching enabled)")
|
| 83 |
+
except Exception as e:
|
| 84 |
+
print(f"Warning: CAG Service initialization failed: {e}")
|
| 85 |
+
print("Continuing without semantic caching...")
|
| 86 |
+
cag_service = None
|
| 87 |
+
|
| 88 |
+
# Initialize PDF Indexer
|
| 89 |
+
pdf_indexer = PDFIndexer(
|
| 90 |
+
embedding_service=embedding_service,
|
| 91 |
+
qdrant_service=qdrant_service,
|
| 92 |
+
documents_collection=documents_collection
|
| 93 |
+
)
|
| 94 |
+
print("✓ PDF Indexer initialized")
|
| 95 |
+
|
| 96 |
+
# Initialize Multimodal PDF Indexer
|
| 97 |
+
multimodal_pdf_indexer = MultimodalPDFIndexer(
|
| 98 |
+
embedding_service=embedding_service,
|
| 99 |
+
qdrant_service=qdrant_service,
|
| 100 |
+
documents_collection=documents_collection
|
| 101 |
+
)
|
| 102 |
+
print("✓ Multimodal PDF Indexer initialized")
|
| 103 |
+
|
| 104 |
+
# Initialize Conversation Service
|
| 105 |
+
conversations_collection = db["conversations"]
|
| 106 |
+
conversation_service = ConversationService(conversations_collection, max_history=10)
|
| 107 |
+
print("✓ Conversation Service initialized")
|
| 108 |
+
|
| 109 |
+
# Initialize Feedback Tracking Service
|
| 110 |
+
feedback_tracking_collection = db["feedback_tracking"]
|
| 111 |
+
feedback_tracking = FeedbackTrackingService(feedback_tracking_collection)
|
| 112 |
+
print("✓ Feedback Tracking Service initialized")
|
| 113 |
+
|
| 114 |
+
# Initialize Tools Service
|
| 115 |
+
tools_service = ToolsService(
|
| 116 |
+
base_url="https://hoalacrent.io.vn/api/v0",
|
| 117 |
+
feedback_tracking=feedback_tracking
|
| 118 |
+
)
|
| 119 |
+
print("✓ Tools Service initialized (Function Calling enabled)")
|
| 120 |
+
|
| 121 |
+
# Initialize Agent Service (Agentic Workflow)
|
| 122 |
+
agent_service = AgentService(
|
| 123 |
+
tools_service=tools_service,
|
| 124 |
+
embedding_service=embedding_service,
|
| 125 |
+
qdrant_service=qdrant_service,
|
| 126 |
+
advanced_rag=advanced_rag,
|
| 127 |
+
hf_token=hf_token,
|
| 128 |
+
feedback_tracking=feedback_tracking # Pass feedback tracking
|
| 129 |
+
)
|
| 130 |
+
print("✓ Agent Service initialized (Agentic Workflow enabled)")
|
| 131 |
+
|
| 132 |
+
print("✓ Services initialized successfully")
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
# Pydantic models for embeddings
|
| 136 |
+
class SearchRequest(BaseModel):
|
| 137 |
+
text: Optional[str] = None
|
| 138 |
+
limit: int = 10
|
| 139 |
+
score_threshold: Optional[float] = None
|
| 140 |
+
text_weight: float = 0.5
|
| 141 |
+
image_weight: float = 0.5
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class SearchResponse(BaseModel):
|
| 145 |
+
id: str
|
| 146 |
+
confidence: float
|
| 147 |
+
metadata: dict
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class IndexResponse(BaseModel):
|
| 151 |
+
success: bool
|
| 152 |
+
id: str
|
| 153 |
+
message: str
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
# Pydantic models for ChatbotRAG
|
| 157 |
+
class ChatRequest(BaseModel):
|
| 158 |
+
message: str
|
| 159 |
+
session_id: Optional[str] = None # Multi-turn conversation
|
| 160 |
+
user_id: Optional[str] = None # User identifier for session tracking
|
| 161 |
+
access_token: Optional[str] = None # NEW: For authenticated API calls (feedback mode)
|
| 162 |
+
mode: str = "sales" # NEW: "sales" or "feedback" for agent selection
|
| 163 |
+
event_code: Optional[str] = None # NEW: For targeted feedback on specific event
|
| 164 |
+
use_rag: bool = True
|
| 165 |
+
top_k: int = 3
|
| 166 |
+
system_message: Optional[str] = """Bạn là trợ lý AI chuyên biệt cho hệ thống quản lý sự kiện và bán vé.
|
| 167 |
+
Vai trò của bạn là trả lời các câu hỏi CHÍNH XÁC dựa trên dữ liệu được cung cấp từ hệ thống.
|
| 168 |
+
|
| 169 |
+
Quy tắc tuyệt đối:
|
| 170 |
+
- CHỈ trả lời câu hỏi liên quan đến: events, social media posts, PDFs đã upload, và dữ liệu trong knowledge base
|
| 171 |
+
- KHÔNG trả lời câu hỏi ngoài phạm vi (tin tức, thời tiết, toán học, lập trình, tư vấn cá nhân, v.v.)
|
| 172 |
+
- Nếu câu hỏi nằm ngoài phạm vi: BẮT BUỘC trả lời "Chúng tôi không thể trả lời câu hỏi này vì nó nằm ngoài vùng application xử lí."
|
| 173 |
+
- Luôn ưu tiên thông tin từ context được cung cấp"""
|
| 174 |
+
max_tokens: int = 512
|
| 175 |
+
temperature: float = 0.7
|
| 176 |
+
top_p: float = 0.95
|
| 177 |
+
hf_token: Optional[str] = None
|
| 178 |
+
# Advanced RAG options
|
| 179 |
+
use_advanced_rag: bool = True
|
| 180 |
+
use_query_expansion: bool = True
|
| 181 |
+
use_reranking: bool = False # Disabled - Cross-Encoder not good for Vietnamese
|
| 182 |
+
use_compression: bool = True
|
| 183 |
+
score_threshold: float = 0.5
|
| 184 |
+
# Function calling
|
| 185 |
+
enable_tools: bool = True # Enable API tool calling
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
class ChatResponse(BaseModel):
|
| 189 |
+
response: str
|
| 190 |
+
context_used: List[Dict]
|
| 191 |
+
timestamp: str
|
| 192 |
+
rag_stats: Optional[Dict] = None # Stats from advanced RAG pipeline
|
| 193 |
+
session_id: Optional[str] = None # Session identifier for multi-turn (auto-generated if not provided)
|
| 194 |
+
tool_calls: Optional[List[Dict]] = None # Track API calls made
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
class AddDocumentRequest(BaseModel):
|
| 198 |
+
text: str
|
| 199 |
+
metadata: Optional[Dict] = None
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
class AddDocumentResponse(BaseModel):
|
| 203 |
+
success: bool
|
| 204 |
+
doc_id: str
|
| 205 |
+
message: str
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
@app.get("/")
|
| 209 |
+
async def root():
|
| 210 |
+
"""Health check endpoint with comprehensive API documentation"""
|
| 211 |
+
return {
|
| 212 |
+
"status": "running",
|
| 213 |
+
"service": "ChatbotRAG API",
|
| 214 |
+
"version": "2.0.0",
|
| 215 |
+
"vector_db": "Qdrant",
|
| 216 |
+
"document_db": "MongoDB",
|
| 217 |
+
"endpoints": {
|
| 218 |
+
"chatbot_rag": {
|
| 219 |
+
"API endpoint": "https://minhvtt-ChatbotRAG.hf.space/",
|
| 220 |
+
"POST /chat": {
|
| 221 |
+
"description": "Chat với AI sử dụng RAG (Retrieval-Augmented Generation)",
|
| 222 |
+
"request": {
|
| 223 |
+
"method": "POST",
|
| 224 |
+
"content_type": "application/json",
|
| 225 |
+
"body": {
|
| 226 |
+
"message": "string (required) - User message/question",
|
| 227 |
+
"use_rag": "boolean (optional, default: true) - Enable RAG context retrieval",
|
| 228 |
+
"top_k": "integer (optional, default: 3) - Number of context documents to retrieve",
|
| 229 |
+
"system_message": "string (optional) - Custom system prompt",
|
| 230 |
+
"max_tokens": "integer (optional, default: 512) - Max response length",
|
| 231 |
+
"temperature": "float (optional, default: 0.7, range: 0-1) - Creativity level",
|
| 232 |
+
"top_p": "float (optional, default: 0.95) - Nucleus sampling",
|
| 233 |
+
"hf_token": "string (optional) - Hugging Face token (fallback to env)"
|
| 234 |
+
}
|
| 235 |
+
},
|
| 236 |
+
"response": {
|
| 237 |
+
"response": "string - AI generated response",
|
| 238 |
+
"context_used": [
|
| 239 |
+
{
|
| 240 |
+
"id": "string - Document ID",
|
| 241 |
+
"confidence": "float - Relevance score",
|
| 242 |
+
"metadata": {
|
| 243 |
+
"text": "string - Retrieved context"
|
| 244 |
+
}
|
| 245 |
+
}
|
| 246 |
+
],
|
| 247 |
+
"timestamp": "string - ISO 8601 timestamp"
|
| 248 |
+
},
|
| 249 |
+
"example_request": {
|
| 250 |
+
"message": "Dao có nguy hiểm không?",
|
| 251 |
+
"use_rag": True,
|
| 252 |
+
"top_k": 3,
|
| 253 |
+
"temperature": 0.7
|
| 254 |
+
},
|
| 255 |
+
"example_response": {
|
| 256 |
+
"response": "Dựa trên thông tin trong database, dao được phân loại là vũ khí nguy hiểm. Dao sắc có thể gây thương tích nghiêm trọng nếu không sử dụng đúng cách. Cần tuân thủ các quy định an toàn khi sử dụng.",
|
| 257 |
+
"context_used": [
|
| 258 |
+
{
|
| 259 |
+
"id": "68a3fc14c853d7621e8977b5",
|
| 260 |
+
"confidence": 0.92,
|
| 261 |
+
"metadata": {
|
| 262 |
+
"text": "Vũ khí"
|
| 263 |
+
}
|
| 264 |
+
},
|
| 265 |
+
{
|
| 266 |
+
"id": "68a3fc4cc853d7621e8977b6",
|
| 267 |
+
"confidence": 0.85,
|
| 268 |
+
"metadata": {
|
| 269 |
+
"text": "Con dao sắc"
|
| 270 |
+
}
|
| 271 |
+
}
|
| 272 |
+
],
|
| 273 |
+
"timestamp": "2025-10-13T10:30:45.123456"
|
| 274 |
+
},
|
| 275 |
+
"notes": [
|
| 276 |
+
"RAG retrieves relevant context from vector DB before generating response",
|
| 277 |
+
"LLM uses context to provide accurate, grounded answers",
|
| 278 |
+
"Requires HUGGINGFACE_TOKEN environment variable or hf_token in request"
|
| 279 |
+
]
|
| 280 |
+
},
|
| 281 |
+
"POST /documents": {
|
| 282 |
+
"description": "Add document to knowledge base for RAG",
|
| 283 |
+
"request": {
|
| 284 |
+
"method": "POST",
|
| 285 |
+
"content_type": "application/json",
|
| 286 |
+
"body": {
|
| 287 |
+
"text": "string (required) - Document text content",
|
| 288 |
+
"metadata": "object (optional) - Additional metadata (source, category, etc.)"
|
| 289 |
+
}
|
| 290 |
+
},
|
| 291 |
+
"response": {
|
| 292 |
+
"success": "boolean",
|
| 293 |
+
"doc_id": "string - MongoDB ObjectId",
|
| 294 |
+
"message": "string - Status message"
|
| 295 |
+
},
|
| 296 |
+
"example_request": {
|
| 297 |
+
"text": "Để tạo event mới: Click nút 'Tạo Event' ở góc trên bên phải màn hình. Điền thông tin sự kiện bao gồm tên, ngày giờ, địa điểm. Click Lưu để hoàn tất.",
|
| 298 |
+
"metadata": {
|
| 299 |
+
"source": "user_guide.pdf",
|
| 300 |
+
"section": "create_event",
|
| 301 |
+
"page": 5,
|
| 302 |
+
"category": "tutorial"
|
| 303 |
+
}
|
| 304 |
+
},
|
| 305 |
+
"example_response": {
|
| 306 |
+
"success": True,
|
| 307 |
+
"doc_id": "67a9876543210fedcba98765",
|
| 308 |
+
"message": "Document added successfully with ID: 67a9876543210fedcba98765"
|
| 309 |
+
}
|
| 310 |
+
},
|
| 311 |
+
"POST /rag/search": {
|
| 312 |
+
"description": "Search in knowledge base (similar to /search/text but for RAG documents)",
|
| 313 |
+
"request": {
|
| 314 |
+
"method": "POST",
|
| 315 |
+
"content_type": "multipart/form-data",
|
| 316 |
+
"body": {
|
| 317 |
+
"query": "string (required) - Search query",
|
| 318 |
+
"top_k": "integer (optional, default: 5) - Number of results",
|
| 319 |
+
"score_threshold": "float (optional, default: 0.5) - Minimum relevance score"
|
| 320 |
+
}
|
| 321 |
+
},
|
| 322 |
+
"response": [
|
| 323 |
+
{
|
| 324 |
+
"id": "string",
|
| 325 |
+
"confidence": "float",
|
| 326 |
+
"metadata": {
|
| 327 |
+
"text": "string",
|
| 328 |
+
"source": "string"
|
| 329 |
+
}
|
| 330 |
+
}
|
| 331 |
+
],
|
| 332 |
+
"example_request": {
|
| 333 |
+
"query": "cách tạo sự kiện mới",
|
| 334 |
+
"top_k": 3,
|
| 335 |
+
"score_threshold": 0.6
|
| 336 |
+
}
|
| 337 |
+
},
|
| 338 |
+
"GET /history": {
|
| 339 |
+
"description": "Get chat conversation history",
|
| 340 |
+
"request": {
|
| 341 |
+
"method": "GET",
|
| 342 |
+
"query_params": {
|
| 343 |
+
"limit": "integer (optional, default: 10) - Number of messages",
|
| 344 |
+
"skip": "integer (optional, default: 0) - Pagination offset"
|
| 345 |
+
}
|
| 346 |
+
},
|
| 347 |
+
"response": {
|
| 348 |
+
"history": [
|
| 349 |
+
{
|
| 350 |
+
"user_message": "string",
|
| 351 |
+
"assistant_response": "string",
|
| 352 |
+
"context_used": "array",
|
| 353 |
+
"timestamp": "string - ISO 8601"
|
| 354 |
+
}
|
| 355 |
+
],
|
| 356 |
+
"total": "integer - Total messages count"
|
| 357 |
+
},
|
| 358 |
+
"example_request": "GET /history?limit=5&skip=0",
|
| 359 |
+
"example_response": {
|
| 360 |
+
"history": [
|
| 361 |
+
{
|
| 362 |
+
"user_message": "Dao có nguy hiểm không?",
|
| 363 |
+
"assistant_response": "Dao được phân loại là vũ khí...",
|
| 364 |
+
"context_used": [],
|
| 365 |
+
"timestamp": "2025-10-13T10:30:45.123456"
|
| 366 |
+
}
|
| 367 |
+
],
|
| 368 |
+
"total": 15
|
| 369 |
+
}
|
| 370 |
+
},
|
| 371 |
+
"DELETE /documents/{doc_id}": {
|
| 372 |
+
"description": "Delete document from knowledge base",
|
| 373 |
+
"request": {
|
| 374 |
+
"method": "DELETE",
|
| 375 |
+
"path_params": {
|
| 376 |
+
"doc_id": "string - MongoDB ObjectId"
|
| 377 |
+
}
|
| 378 |
+
},
|
| 379 |
+
"response": {
|
| 380 |
+
"success": "boolean",
|
| 381 |
+
"message": "string"
|
| 382 |
+
}
|
| 383 |
+
}
|
| 384 |
+
}
|
| 385 |
+
},
|
| 386 |
+
"usage_examples": {
|
| 387 |
+
"curl_chat": "curl -X POST 'http://localhost:8000/chat' -H 'Content-Type: application/json' -d '{\"message\": \"Dao có nguy hiểm không?\", \"use_rag\": true}'",
|
| 388 |
+
"python_chat": """
|
| 389 |
+
import requests
|
| 390 |
+
|
| 391 |
+
response = requests.post(
|
| 392 |
+
'http://localhost:8000/chat',
|
| 393 |
+
json={
|
| 394 |
+
'message': 'Nút tạo event ở đâu?',
|
| 395 |
+
'use_rag': True,
|
| 396 |
+
'top_k': 3
|
| 397 |
+
}
|
| 398 |
+
)
|
| 399 |
+
print(response.json()['response'])
|
| 400 |
+
"""
|
| 401 |
+
},
|
| 402 |
+
"authentication": {
|
| 403 |
+
"embeddings_apis": "No authentication required",
|
| 404 |
+
"chat_api": "Requires HUGGINGFACE_TOKEN (env variable or request body)"
|
| 405 |
+
},
|
| 406 |
+
"rate_limits": {
|
| 407 |
+
"embeddings": "No limit",
|
| 408 |
+
"chat_with_llm": "Limited by Hugging Face API (free tier: ~1000 requests/hour)"
|
| 409 |
+
},
|
| 410 |
+
"error_codes": {
|
| 411 |
+
"400": "Bad Request - Missing required fields or invalid input",
|
| 412 |
+
"401": "Unauthorized - Invalid Hugging Face token",
|
| 413 |
+
"404": "Not Found - Document ID not found",
|
| 414 |
+
"500": "Internal Server Error - Server or database error"
|
| 415 |
+
},
|
| 416 |
+
"links": {
|
| 417 |
+
"docs": "http://localhost:8000/docs",
|
| 418 |
+
"redoc": "http://localhost:8000/redoc",
|
| 419 |
+
"openapi": "http://localhost:8000/openapi.json"
|
| 420 |
+
}
|
| 421 |
+
}
|
| 422 |
+
|
| 423 |
+
@app.post("/index", response_model=IndexResponse)
|
| 424 |
+
async def index_data(
|
| 425 |
+
id: str = Form(...),
|
| 426 |
+
text: str = Form(...),
|
| 427 |
+
image: Optional[UploadFile] = File(None)
|
| 428 |
+
):
|
| 429 |
+
"""
|
| 430 |
+
Index data vào vector database
|
| 431 |
+
|
| 432 |
+
Body:
|
| 433 |
+
- id: Document ID (event ID, post ID, etc.)
|
| 434 |
+
- text: Text content (tiếng Việt supported)
|
| 435 |
+
- image: Image file (optional)
|
| 436 |
+
|
| 437 |
+
Returns:
|
| 438 |
+
- success: True/False
|
| 439 |
+
- id: Document ID
|
| 440 |
+
- message: Status message
|
| 441 |
+
"""
|
| 442 |
+
try:
|
| 443 |
+
# Prepare embeddings
|
| 444 |
+
text_embedding = None
|
| 445 |
+
image_embedding = None
|
| 446 |
+
|
| 447 |
+
# Encode text (tiếng Việt)
|
| 448 |
+
if text and text.strip():
|
| 449 |
+
text_embedding = embedding_service.encode_text(text)
|
| 450 |
+
|
| 451 |
+
# Encode image nếu có
|
| 452 |
+
if image:
|
| 453 |
+
image_bytes = await image.read()
|
| 454 |
+
pil_image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
|
| 455 |
+
image_embedding = embedding_service.encode_image(pil_image)
|
| 456 |
+
|
| 457 |
+
# Combine embeddings
|
| 458 |
+
if text_embedding is not None and image_embedding is not None:
|
| 459 |
+
# Average của text và image embeddings
|
| 460 |
+
combined_embedding = np.mean([text_embedding, image_embedding], axis=0)
|
| 461 |
+
elif text_embedding is not None:
|
| 462 |
+
combined_embedding = text_embedding
|
| 463 |
+
elif image_embedding is not None:
|
| 464 |
+
combined_embedding = image_embedding
|
| 465 |
+
else:
|
| 466 |
+
raise HTTPException(status_code=400, detail="Phải cung cấp ít nhất text hoặc image")
|
| 467 |
+
|
| 468 |
+
# Normalize
|
| 469 |
+
combined_embedding = combined_embedding / np.linalg.norm(combined_embedding, axis=1, keepdims=True)
|
| 470 |
+
|
| 471 |
+
# Index vào Qdrant
|
| 472 |
+
metadata = {
|
| 473 |
+
"text": text,
|
| 474 |
+
"has_image": image is not None,
|
| 475 |
+
"image_filename": image.filename if image else None
|
| 476 |
+
}
|
| 477 |
+
|
| 478 |
+
result = qdrant_service.index_data(
|
| 479 |
+
doc_id=id,
|
| 480 |
+
embedding=combined_embedding,
|
| 481 |
+
metadata=metadata
|
| 482 |
+
)
|
| 483 |
+
|
| 484 |
+
return IndexResponse(
|
| 485 |
+
success=True,
|
| 486 |
+
id=result["original_id"], # Trả về MongoDB ObjectId
|
| 487 |
+
message=f"Đã index thành công document {result['original_id']} (Qdrant UUID: {result['qdrant_id']})"
|
| 488 |
+
)
|
| 489 |
+
|
| 490 |
+
except Exception as e:
|
| 491 |
+
raise HTTPException(status_code=500, detail=f"Lỗi khi index: {str(e)}")
|
| 492 |
+
|
| 493 |
+
|
| 494 |
+
@app.post("/search", response_model=List[SearchResponse])
|
| 495 |
+
async def search(
|
| 496 |
+
text: Optional[str] = Form(None),
|
| 497 |
+
image: Optional[UploadFile] = File(None),
|
| 498 |
+
limit: int = Form(10),
|
| 499 |
+
score_threshold: Optional[float] = Form(None),
|
| 500 |
+
text_weight: float = Form(0.5),
|
| 501 |
+
image_weight: float = Form(0.5)
|
| 502 |
+
):
|
| 503 |
+
"""
|
| 504 |
+
Search similar documents bằng text và/hoặc image
|
| 505 |
+
|
| 506 |
+
Body:
|
| 507 |
+
- text: Query text (tiếng Việt supported)
|
| 508 |
+
- image: Query image (optional)
|
| 509 |
+
- limit: Số lượng kết quả (default: 10)
|
| 510 |
+
- score_threshold: Minimum confidence score (0-1)
|
| 511 |
+
- text_weight: Weight cho text search (default: 0.5)
|
| 512 |
+
- image_weight: Weight cho image search (default: 0.5)
|
| 513 |
+
|
| 514 |
+
Returns:
|
| 515 |
+
- List of results với id, confidence, và metadata
|
| 516 |
+
"""
|
| 517 |
+
try:
|
| 518 |
+
# Prepare query embeddings
|
| 519 |
+
text_embedding = None
|
| 520 |
+
image_embedding = None
|
| 521 |
+
|
| 522 |
+
# Encode text query
|
| 523 |
+
if text and text.strip():
|
| 524 |
+
text_embedding = embedding_service.encode_text(text)
|
| 525 |
+
|
| 526 |
+
# Encode image query
|
| 527 |
+
if image:
|
| 528 |
+
image_bytes = await image.read()
|
| 529 |
+
pil_image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
|
| 530 |
+
image_embedding = embedding_service.encode_image(pil_image)
|
| 531 |
+
|
| 532 |
+
# Validate input
|
| 533 |
+
if text_embedding is None and image_embedding is None:
|
| 534 |
+
raise HTTPException(status_code=400, detail="Phải cung cấp ít nhất text hoặc image để search")
|
| 535 |
+
|
| 536 |
+
# Hybrid search với Qdrant
|
| 537 |
+
results = qdrant_service.hybrid_search(
|
| 538 |
+
text_embedding=text_embedding,
|
| 539 |
+
image_embedding=image_embedding,
|
| 540 |
+
text_weight=text_weight,
|
| 541 |
+
image_weight=image_weight,
|
| 542 |
+
limit=limit,
|
| 543 |
+
score_threshold=score_threshold,
|
| 544 |
+
ef=256 # High accuracy search
|
| 545 |
+
)
|
| 546 |
+
|
| 547 |
+
# Format response
|
| 548 |
+
return [
|
| 549 |
+
SearchResponse(
|
| 550 |
+
id=result["id"],
|
| 551 |
+
confidence=result["confidence"],
|
| 552 |
+
metadata=result["metadata"]
|
| 553 |
+
)
|
| 554 |
+
for result in results
|
| 555 |
+
]
|
| 556 |
+
|
| 557 |
+
except Exception as e:
|
| 558 |
+
raise HTTPException(status_code=500, detail=f"Lỗi khi search: {str(e)}")
|
| 559 |
+
|
| 560 |
+
|
| 561 |
+
@app.post("/search/text", response_model=List[SearchResponse])
|
| 562 |
+
async def search_by_text(
|
| 563 |
+
text: str = Form(...),
|
| 564 |
+
limit: int = Form(10),
|
| 565 |
+
score_threshold: Optional[float] = Form(None)
|
| 566 |
+
):
|
| 567 |
+
"""
|
| 568 |
+
Search chỉ bằng text (tiếng Việt)
|
| 569 |
+
|
| 570 |
+
Body:
|
| 571 |
+
- text: Query text (tiếng Việt)
|
| 572 |
+
- limit: Số lượng kết quả
|
| 573 |
+
- score_threshold: Minimum confidence score
|
| 574 |
+
|
| 575 |
+
Returns:
|
| 576 |
+
- List of results
|
| 577 |
+
"""
|
| 578 |
+
try:
|
| 579 |
+
# Encode text
|
| 580 |
+
text_embedding = embedding_service.encode_text(text)
|
| 581 |
+
|
| 582 |
+
# Search
|
| 583 |
+
results = qdrant_service.search(
|
| 584 |
+
query_embedding=text_embedding,
|
| 585 |
+
limit=limit,
|
| 586 |
+
score_threshold=score_threshold,
|
| 587 |
+
ef=256
|
| 588 |
+
)
|
| 589 |
+
|
| 590 |
+
return [
|
| 591 |
+
SearchResponse(
|
| 592 |
+
id=result["id"],
|
| 593 |
+
confidence=result["confidence"],
|
| 594 |
+
metadata=result["metadata"]
|
| 595 |
+
)
|
| 596 |
+
for result in results
|
| 597 |
+
]
|
| 598 |
+
|
| 599 |
+
except Exception as e:
|
| 600 |
+
raise HTTPException(status_code=500, detail=f"Lỗi khi search: {str(e)}")
|
| 601 |
+
|
| 602 |
+
|
| 603 |
+
@app.post("/search/image", response_model=List[SearchResponse])
|
| 604 |
+
async def search_by_image(
|
| 605 |
+
image: UploadFile = File(...),
|
| 606 |
+
limit: int = Form(10),
|
| 607 |
+
score_threshold: Optional[float] = Form(None)
|
| 608 |
+
):
|
| 609 |
+
"""
|
| 610 |
+
Search chỉ bằng image
|
| 611 |
+
|
| 612 |
+
Body:
|
| 613 |
+
- image: Query image
|
| 614 |
+
- limit: Số lượng kết quả
|
| 615 |
+
- score_threshold: Minimum confidence score
|
| 616 |
+
|
| 617 |
+
Returns:
|
| 618 |
+
- List of results
|
| 619 |
+
"""
|
| 620 |
+
try:
|
| 621 |
+
# Encode image
|
| 622 |
+
image_bytes = await image.read()
|
| 623 |
+
pil_image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
|
| 624 |
+
image_embedding = embedding_service.encode_image(pil_image)
|
| 625 |
+
|
| 626 |
+
# Search
|
| 627 |
+
results = qdrant_service.search(
|
| 628 |
+
query_embedding=image_embedding,
|
| 629 |
+
limit=limit,
|
| 630 |
+
score_threshold=score_threshold,
|
| 631 |
+
ef=256
|
| 632 |
+
)
|
| 633 |
+
|
| 634 |
+
return [
|
| 635 |
+
SearchResponse(
|
| 636 |
+
id=result["id"],
|
| 637 |
+
confidence=result["confidence"],
|
| 638 |
+
metadata=result["metadata"]
|
| 639 |
+
)
|
| 640 |
+
for result in results
|
| 641 |
+
]
|
| 642 |
+
|
| 643 |
+
except Exception as e:
|
| 644 |
+
raise HTTPException(status_code=500, detail=f"Lỗi khi search: {str(e)}")
|
| 645 |
+
|
| 646 |
+
|
| 647 |
+
@app.delete("/delete/{doc_id}")
|
| 648 |
+
async def delete_document(doc_id: str):
|
| 649 |
+
"""
|
| 650 |
+
Delete document by ID (MongoDB ObjectId hoặc UUID)
|
| 651 |
+
|
| 652 |
+
Args:
|
| 653 |
+
- doc_id: Document ID to delete
|
| 654 |
+
|
| 655 |
+
Returns:
|
| 656 |
+
- Success message
|
| 657 |
+
"""
|
| 658 |
+
try:
|
| 659 |
+
qdrant_service.delete_by_id(doc_id)
|
| 660 |
+
return {"success": True, "message": f"Đã xóa document {doc_id}"}
|
| 661 |
+
except Exception as e:
|
| 662 |
+
raise HTTPException(status_code=500, detail=f"Lỗi khi xóa: {str(e)}")
|
| 663 |
+
|
| 664 |
+
|
| 665 |
+
@app.get("/document/{doc_id}")
|
| 666 |
+
async def get_document(doc_id: str):
|
| 667 |
+
"""
|
| 668 |
+
Get document by ID (MongoDB ObjectId hoặc UUID)
|
| 669 |
+
|
| 670 |
+
Args:
|
| 671 |
+
- doc_id: Document ID (MongoDB ObjectId)
|
| 672 |
+
|
| 673 |
+
Returns:
|
| 674 |
+
- Document data
|
| 675 |
+
"""
|
| 676 |
+
try:
|
| 677 |
+
doc = qdrant_service.get_by_id(doc_id)
|
| 678 |
+
if doc:
|
| 679 |
+
return {
|
| 680 |
+
"success": True,
|
| 681 |
+
"data": doc
|
| 682 |
+
}
|
| 683 |
+
raise HTTPException(status_code=404, detail=f"Không tìm thấy document {doc_id}")
|
| 684 |
+
except HTTPException:
|
| 685 |
+
raise
|
| 686 |
+
except Exception as e:
|
| 687 |
+
raise HTTPException(status_code=500, detail=f"Lỗi khi get document: {str(e)}")
|
| 688 |
+
|
| 689 |
+
|
| 690 |
+
@app.get("/stats")
|
| 691 |
+
async def get_stats():
|
| 692 |
+
"""
|
| 693 |
+
Lấy thông tin thống kê collection
|
| 694 |
+
|
| 695 |
+
Returns:
|
| 696 |
+
- Collection statistics
|
| 697 |
+
"""
|
| 698 |
+
try:
|
| 699 |
+
info = qdrant_service.get_collection_info()
|
| 700 |
+
return info
|
| 701 |
+
except Exception as e:
|
| 702 |
+
raise HTTPException(status_code=500, detail=f"Lỗi khi lấy stats: {str(e)}")
|
| 703 |
+
|
| 704 |
+
|
| 705 |
+
# ============================================
|
| 706 |
+
# ChatbotRAG Endpoints - DEPRECATED
|
| 707 |
+
# USE /agent/chat INSTEAD
|
| 708 |
+
# ============================================
|
| 709 |
+
# Old endpoints removed - now using Agentic Workflow via /agent/chat
|
| 710 |
+
|
| 711 |
+
|
| 712 |
+
|
| 713 |
+
@app.get("/chat/history/{session_id}")
|
| 714 |
+
async def get_conversation_history(session_id: str, include_metadata: bool = False):
|
| 715 |
+
"""
|
| 716 |
+
Get conversation history for a session
|
| 717 |
+
|
| 718 |
+
Args:
|
| 719 |
+
session_id: Session identifier
|
| 720 |
+
include_metadata: Include metadata (rag_stats, tool_calls) in response
|
| 721 |
+
|
| 722 |
+
Returns:
|
| 723 |
+
List of messages with role and content
|
| 724 |
+
|
| 725 |
+
Example:
|
| 726 |
+
```
|
| 727 |
+
GET /chat/history/abc-123?include_metadata=true
|
| 728 |
+
```
|
| 729 |
+
"""
|
| 730 |
+
if not conversation_service.session_exists(session_id):
|
| 731 |
+
raise HTTPException(
|
| 732 |
+
status_code=404,
|
| 733 |
+
detail=f"Session {session_id} not found or has expired"
|
| 734 |
+
)
|
| 735 |
+
|
| 736 |
+
history = conversation_service.get_conversation_history(
|
| 737 |
+
session_id,
|
| 738 |
+
include_metadata=include_metadata
|
| 739 |
+
)
|
| 740 |
+
|
| 741 |
+
session_info = conversation_service.get_session_info(session_id)
|
| 742 |
+
|
| 743 |
+
return {
|
| 744 |
+
"session_id": session_id,
|
| 745 |
+
"message_count": len(history),
|
| 746 |
+
"messages": history,
|
| 747 |
+
"created_at": session_info.get("created_at") if session_info else None,
|
| 748 |
+
"updated_at": session_info.get("updated_at") if session_info else None
|
| 749 |
+
}
|
| 750 |
+
|
| 751 |
+
|
| 752 |
+
@app.get("/chat/sessions")
|
| 753 |
+
async def list_sessions(
|
| 754 |
+
limit: int = 50,
|
| 755 |
+
skip: int = 0,
|
| 756 |
+
sort_by: str = "updated_at",
|
| 757 |
+
user_id: Optional[str] = None # NEW: Filter by user
|
| 758 |
+
):
|
| 759 |
+
"""
|
| 760 |
+
List all conversation sessions
|
| 761 |
+
|
| 762 |
+
Query Parameters:
|
| 763 |
+
limit: Maximum sessions to return (default: 50, max: 100)
|
| 764 |
+
skip: Number of sessions to skip for pagination (default: 0)
|
| 765 |
+
sort_by: Field to sort by - 'created_at' or 'updated_at' (default: updated_at)
|
| 766 |
+
user_id: Filter sessions by user_id (optional)
|
| 767 |
+
|
| 768 |
+
Returns:
|
| 769 |
+
List of sessions with metadata and message counts
|
| 770 |
+
|
| 771 |
+
Examples:
|
| 772 |
+
```
|
| 773 |
+
GET /chat/sessions # All sessions
|
| 774 |
+
GET /chat/sessions?user_id=user_123 # Only user_123's sessions
|
| 775 |
+
GET /chat/sessions?limit=20&skip=0&sort_by=updated_at
|
| 776 |
+
```
|
| 777 |
+
"""
|
| 778 |
+
# Validate limit
|
| 779 |
+
if limit > 100:
|
| 780 |
+
limit = 100
|
| 781 |
+
if limit < 1:
|
| 782 |
+
limit = 1
|
| 783 |
+
|
| 784 |
+
# Validate sort_by
|
| 785 |
+
if sort_by not in ["created_at", "updated_at"]:
|
| 786 |
+
raise HTTPException(
|
| 787 |
+
status_code=400,
|
| 788 |
+
detail="sort_by must be 'created_at' or 'updated_at'"
|
| 789 |
+
)
|
| 790 |
+
|
| 791 |
+
sessions = conversation_service.list_sessions(
|
| 792 |
+
limit=limit,
|
| 793 |
+
skip=skip,
|
| 794 |
+
sort_by=sort_by,
|
| 795 |
+
descending=True,
|
| 796 |
+
user_id=user_id # NEW: Pass user_id filter
|
| 797 |
+
)
|
| 798 |
+
|
| 799 |
+
total_sessions = conversation_service.count_sessions(user_id=user_id) # NEW: Count with filter
|
| 800 |
+
|
| 801 |
+
return {
|
| 802 |
+
"total": total_sessions,
|
| 803 |
+
"limit": limit,
|
| 804 |
+
"skip": skip,
|
| 805 |
+
"count": len(sessions),
|
| 806 |
+
"user_id": user_id, # NEW: Include filter in response
|
| 807 |
+
"sessions": sessions
|
| 808 |
+
}
|
| 809 |
+
|
| 810 |
+
|
| 811 |
+
@app.get("/scenarios")
|
| 812 |
+
async def list_scenarios():
|
| 813 |
+
"""
|
| 814 |
+
Get list of all available scenarios for proactive chat
|
| 815 |
+
|
| 816 |
+
FE use case:
|
| 817 |
+
- Random pick scenario để bắt đầu chat chủ động
|
| 818 |
+
- Hiển thị menu các scenario available
|
| 819 |
+
|
| 820 |
+
Returns:
|
| 821 |
+
List of scenarios with metadata
|
| 822 |
+
|
| 823 |
+
Example:
|
| 824 |
+
```
|
| 825 |
+
GET /scenarios
|
| 826 |
+
|
| 827 |
+
Response:
|
| 828 |
+
{
|
| 829 |
+
"scenarios": [
|
| 830 |
+
{
|
| 831 |
+
"scenario_id": "price_inquiry",
|
| 832 |
+
"name": "Hỏi giá vé",
|
| 833 |
+
"description": "Tư vấn giá vé và gửi PDF",
|
| 834 |
+
"triggers": ["giá vé", "bao nhiêu"],
|
| 835 |
+
"category": "sales"
|
| 836 |
+
},
|
| 837 |
+
...
|
| 838 |
+
]
|
| 839 |
+
}
|
| 840 |
+
```
|
| 841 |
+
"""
|
| 842 |
+
scenarios_list = []
|
| 843 |
+
|
| 844 |
+
for scenario_id, scenario_data in scenario_engine.scenarios.items():
|
| 845 |
+
scenarios_list.append({
|
| 846 |
+
"scenario_id": scenario_id,
|
| 847 |
+
"name": scenario_data.get("name", scenario_id),
|
| 848 |
+
"description": scenario_data.get("description", ""),
|
| 849 |
+
"triggers": scenario_data.get("triggers", []),
|
| 850 |
+
"category": scenario_data.get("category", "general"),
|
| 851 |
+
"priority": scenario_data.get("priority", "normal"),
|
| 852 |
+
"estimated_duration": scenario_data.get("estimated_duration", "unknown")
|
| 853 |
+
})
|
| 854 |
+
|
| 855 |
+
return {
|
| 856 |
+
"total": len(scenarios_list),
|
| 857 |
+
"scenarios": scenarios_list
|
| 858 |
+
}
|
| 859 |
+
|
| 860 |
+
|
| 861 |
+
@app.post("/scenarios/{scenario_id}/start")
|
| 862 |
+
async def start_scenario_proactive(
|
| 863 |
+
scenario_id: str,
|
| 864 |
+
request_body: Optional[Dict] = None
|
| 865 |
+
):
|
| 866 |
+
"""
|
| 867 |
+
Start a scenario proactively with optional initial data
|
| 868 |
+
|
| 869 |
+
Use cases:
|
| 870 |
+
1. FE picks random scenario
|
| 871 |
+
2. BE triggers scenario based on user action (after purchase, exit intent, etc.)
|
| 872 |
+
3. Inject context data (event_name, mood, etc.)
|
| 873 |
+
|
| 874 |
+
Example 1 - Simple start:
|
| 875 |
+
```
|
| 876 |
+
POST /scenarios/price_inquiry/start
|
| 877 |
+
{}
|
| 878 |
+
|
| 879 |
+
Response:
|
| 880 |
+
{
|
| 881 |
+
"session_id": "abc-123",
|
| 882 |
+
"message": "Hello 👋 Bạn muốn xem giá..."
|
| 883 |
+
}
|
| 884 |
+
```
|
| 885 |
+
|
| 886 |
+
Example 2 - With initial data (post-event feedback):
|
| 887 |
+
```
|
| 888 |
+
POST /scenarios/post_event_feedback/start
|
| 889 |
+
{
|
| 890 |
+
"initial_data": {
|
| 891 |
+
"event_name": "Hòa Nhạc Mùa Xuân",
|
| 892 |
+
"event_date": "2024-11-29",
|
| 893 |
+
"event_id": "evt_123"
|
| 894 |
+
},
|
| 895 |
+
"session_id": "existing-session", // optional
|
| 896 |
+
"user_id": "user_456" // optional
|
| 897 |
+
}
|
| 898 |
+
|
| 899 |
+
Response:
|
| 900 |
+
{
|
| 901 |
+
"session_id": "abc-123",
|
| 902 |
+
"message": "Cảm ơn bạn đã tham dự *Hòa Nhạc Mùa Xuân* hôm qua!"
|
| 903 |
+
}
|
| 904 |
+
```
|
| 905 |
+
|
| 906 |
+
Example 3 - Mood recommendation:
|
| 907 |
+
```
|
| 908 |
+
POST /scenarios/mood_recommendation/start
|
| 909 |
+
{
|
| 910 |
+
"initial_data": {
|
| 911 |
+
"mood": "chill",
|
| 912 |
+
"preferred_genre": "acoustic"
|
| 913 |
+
}
|
| 914 |
+
}
|
| 915 |
+
```
|
| 916 |
+
"""
|
| 917 |
+
# Parse request body
|
| 918 |
+
body = request_body or {}
|
| 919 |
+
initial_data = body.get("initial_data", {})
|
| 920 |
+
session_id = body.get("session_id")
|
| 921 |
+
user_id = body.get("user_id")
|
| 922 |
+
|
| 923 |
+
# Create or use existing session
|
| 924 |
+
if not session_id:
|
| 925 |
+
session_id = conversation_service.create_session(
|
| 926 |
+
metadata={"started_by": "proactive", "scenario": scenario_id},
|
| 927 |
+
user_id=user_id
|
| 928 |
+
)
|
| 929 |
+
|
| 930 |
+
# Start scenario with initial data
|
| 931 |
+
result = scenario_engine.start_scenario(scenario_id, initial_data)
|
| 932 |
+
|
| 933 |
+
if result.get("new_state"):
|
| 934 |
+
conversation_service.set_scenario_state(session_id, result["new_state"])
|
| 935 |
+
|
| 936 |
+
# Save bot message to history
|
| 937 |
+
conversation_service.add_message(
|
| 938 |
+
session_id,
|
| 939 |
+
"assistant",
|
| 940 |
+
result["message"],
|
| 941 |
+
metadata={"proactive": True, "scenario": scenario_id, "initial_data": initial_data}
|
| 942 |
+
)
|
| 943 |
+
|
| 944 |
+
return {
|
| 945 |
+
"session_id": session_id,
|
| 946 |
+
"scenario_id": scenario_id,
|
| 947 |
+
"message": result["message"],
|
| 948 |
+
"scenario_active": True,
|
| 949 |
+
"proactive": True
|
| 950 |
+
}
|
| 951 |
+
|
| 952 |
+
|
| 953 |
+
@app.post("/chat/clear-session")
|
| 954 |
+
async def clear_chat_session(session_id: str):
|
| 955 |
+
"""
|
| 956 |
+
Clear conversation history for a session
|
| 957 |
+
|
| 958 |
+
Args:
|
| 959 |
+
session_id: Session identifier to clear
|
| 960 |
+
|
| 961 |
+
Returns:
|
| 962 |
+
Success message
|
| 963 |
+
|
| 964 |
+
Example:
|
| 965 |
+
```
|
| 966 |
+
POST /chat/clear-session?session_id=abc-123
|
| 967 |
+
```
|
| 968 |
+
"""
|
| 969 |
+
success = conversation_service.clear_session(session_id)
|
| 970 |
+
|
| 971 |
+
if success:
|
| 972 |
+
return {
|
| 973 |
+
"success": True,
|
| 974 |
+
"message": f"Session {session_id} cleared successfully"
|
| 975 |
+
}
|
| 976 |
+
else:
|
| 977 |
+
raise HTTPException(
|
| 978 |
+
status_code=404,
|
| 979 |
+
detail=f"Session {session_id} not found or already cleared"
|
| 980 |
+
)
|
| 981 |
+
|
| 982 |
+
|
| 983 |
+
@app.get("/chat/session/{session_id}")
|
| 984 |
+
async def get_session_info(session_id: str):
|
| 985 |
+
"""
|
| 986 |
+
Get metadata about a conversation session
|
| 987 |
+
|
| 988 |
+
Args:
|
| 989 |
+
session_id: Session identifier
|
| 990 |
+
|
| 991 |
+
Returns:
|
| 992 |
+
Session info including creation time and message count
|
| 993 |
+
|
| 994 |
+
Example:
|
| 995 |
+
```
|
| 996 |
+
GET /chat/session/abc-123
|
| 997 |
+
```
|
| 998 |
+
"""
|
| 999 |
+
session = conversation_service.get_session_info(session_id)
|
| 1000 |
+
|
| 1001 |
+
if not session:
|
| 1002 |
+
raise HTTPException(
|
| 1003 |
+
status_code=404,
|
| 1004 |
+
detail=f"Session {session_id} not found"
|
| 1005 |
+
)
|
| 1006 |
+
|
| 1007 |
+
# Get message count
|
| 1008 |
+
history = conversation_service.get_conversation_history(
|
| 1009 |
+
session_id,
|
| 1010 |
+
include_metadata=True
|
| 1011 |
+
)
|
| 1012 |
+
|
| 1013 |
+
return {
|
| 1014 |
+
"session_id": session["session_id"],
|
| 1015 |
+
"created_at": session["created_at"],
|
| 1016 |
+
"updated_at": session["updated_at"],
|
| 1017 |
+
"message_count": len(history),
|
| 1018 |
+
"metadata": session.get("metadata", {})
|
| 1019 |
+
}
|
| 1020 |
+
|
| 1021 |
+
|
| 1022 |
+
@app.post("/documents", response_model=AddDocumentResponse)
|
| 1023 |
+
async def add_document(request: AddDocumentRequest):
|
| 1024 |
+
"""
|
| 1025 |
+
Add document to knowledge base
|
| 1026 |
+
|
| 1027 |
+
Body:
|
| 1028 |
+
- text: Document text
|
| 1029 |
+
- metadata: Additional metadata (optional)
|
| 1030 |
+
|
| 1031 |
+
Returns:
|
| 1032 |
+
- success: True/False
|
| 1033 |
+
- doc_id: MongoDB document ID
|
| 1034 |
+
- message: Status message
|
| 1035 |
+
"""
|
| 1036 |
+
try:
|
| 1037 |
+
# Save to MongoDB
|
| 1038 |
+
doc_data = {
|
| 1039 |
+
"text": request.text,
|
| 1040 |
+
"metadata": request.metadata or {},
|
| 1041 |
+
"created_at": datetime.utcnow()
|
| 1042 |
+
}
|
| 1043 |
+
result = documents_collection.insert_one(doc_data)
|
| 1044 |
+
doc_id = str(result.inserted_id)
|
| 1045 |
+
|
| 1046 |
+
# Generate embedding
|
| 1047 |
+
embedding = embedding_service.encode_text(request.text)
|
| 1048 |
+
|
| 1049 |
+
# Index to Qdrant
|
| 1050 |
+
qdrant_service.index_data(
|
| 1051 |
+
doc_id=doc_id,
|
| 1052 |
+
embedding=embedding,
|
| 1053 |
+
metadata={
|
| 1054 |
+
"text": request.text,
|
| 1055 |
+
"source": "api",
|
| 1056 |
+
**(request.metadata or {})
|
| 1057 |
+
}
|
| 1058 |
+
)
|
| 1059 |
+
|
| 1060 |
+
return AddDocumentResponse(
|
| 1061 |
+
success=True,
|
| 1062 |
+
doc_id=doc_id,
|
| 1063 |
+
message=f"Document added successfully with ID: {doc_id}"
|
| 1064 |
+
)
|
| 1065 |
+
|
| 1066 |
+
except Exception as e:
|
| 1067 |
+
raise HTTPException(status_code=500, detail=f"Error: {str(e)}")
|
| 1068 |
+
|
| 1069 |
+
|
| 1070 |
+
@app.post("/documents/upload/pdf")
|
| 1071 |
+
async def upload_pdf(
|
| 1072 |
+
file: UploadFile = File(...),
|
| 1073 |
+
metadata: Optional[str] = Form(None)
|
| 1074 |
+
):
|
| 1075 |
+
"""
|
| 1076 |
+
Upload PDF file and index into knowledge base
|
| 1077 |
+
|
| 1078 |
+
Features:
|
| 1079 |
+
- Extracts text from PDF
|
| 1080 |
+
- Detects image URLs in text/markdown
|
| 1081 |
+
- Chunks content intelligently
|
| 1082 |
+
- Indexes all chunks into Qdrant for RAG
|
| 1083 |
+
|
| 1084 |
+
Args:
|
| 1085 |
+
file: PDF file to upload
|
| 1086 |
+
metadata: Optional JSON string with metadata (title, author, etc.)
|
| 1087 |
+
|
| 1088 |
+
Returns:
|
| 1089 |
+
Success status, document ID, and indexing stats
|
| 1090 |
+
|
| 1091 |
+
Example:
|
| 1092 |
+
```bash
|
| 1093 |
+
curl -X POST http://localhost:8000/documents/upload/pdf \
|
| 1094 |
+
-F "file=@document.pdf" \
|
| 1095 |
+
-F 'metadata={"title": "User Guide", "category": "documentation"}'
|
| 1096 |
+
```
|
| 1097 |
+
"""
|
| 1098 |
+
try:
|
| 1099 |
+
# Validate file type
|
| 1100 |
+
if not file.filename.endswith('.pdf'):
|
| 1101 |
+
raise HTTPException(
|
| 1102 |
+
status_code=400,
|
| 1103 |
+
detail="Only PDF files are supported"
|
| 1104 |
+
)
|
| 1105 |
+
|
| 1106 |
+
# Read file bytes
|
| 1107 |
+
pdf_bytes = await file.read()
|
| 1108 |
+
|
| 1109 |
+
# Parse metadata if provided
|
| 1110 |
+
import json
|
| 1111 |
+
doc_metadata = {}
|
| 1112 |
+
if metadata:
|
| 1113 |
+
try:
|
| 1114 |
+
doc_metadata = json.loads(metadata)
|
| 1115 |
+
except json.JSONDecodeError:
|
| 1116 |
+
raise HTTPException(
|
| 1117 |
+
status_code=400,
|
| 1118 |
+
detail="Invalid metadata JSON format"
|
| 1119 |
+
)
|
| 1120 |
+
|
| 1121 |
+
# Generate unique document ID
|
| 1122 |
+
from bson import ObjectId
|
| 1123 |
+
document_id = str(ObjectId())
|
| 1124 |
+
|
| 1125 |
+
# Add upload timestamp
|
| 1126 |
+
doc_metadata['uploaded_at'] = datetime.utcnow().isoformat()
|
| 1127 |
+
doc_metadata['original_filename'] = file.filename
|
| 1128 |
+
|
| 1129 |
+
# Index PDF using multimodal parser
|
| 1130 |
+
result = multimodal_pdf_indexer.index_pdf_bytes(
|
| 1131 |
+
pdf_bytes=pdf_bytes,
|
| 1132 |
+
document_id=document_id,
|
| 1133 |
+
filename=file.filename,
|
| 1134 |
+
document_metadata=doc_metadata
|
| 1135 |
+
)
|
| 1136 |
+
|
| 1137 |
+
return {
|
| 1138 |
+
"success": True,
|
| 1139 |
+
"document_id": document_id,
|
| 1140 |
+
"filename": file.filename,
|
| 1141 |
+
"chunks_indexed": result['chunks_indexed'],
|
| 1142 |
+
"images_found": result.get('images_found', 0),
|
| 1143 |
+
"message": f"PDF uploaded and indexed: {result['chunks_indexed']} chunks, {result.get('images_found', 0)} image URLs found"
|
| 1144 |
+
}
|
| 1145 |
+
|
| 1146 |
+
except HTTPException:
|
| 1147 |
+
raise
|
| 1148 |
+
except Exception as e:
|
| 1149 |
+
raise HTTPException(
|
| 1150 |
+
status_code=500,
|
| 1151 |
+
detail=f"Error processing PDF: {str(e)}"
|
| 1152 |
+
)
|
| 1153 |
+
|
| 1154 |
+
|
| 1155 |
+
@app.post("/rag/search", response_model=List[SearchResponse])
|
| 1156 |
+
async def rag_search(
|
| 1157 |
+
query: str = Form(...),
|
| 1158 |
+
top_k: int = Form(5),
|
| 1159 |
+
score_threshold: Optional[float] = Form(0.5)
|
| 1160 |
+
):
|
| 1161 |
+
"""
|
| 1162 |
+
Search in knowledge base
|
| 1163 |
+
|
| 1164 |
+
Body:
|
| 1165 |
+
- query: Search query
|
| 1166 |
+
- top_k: Number of results (default: 5)
|
| 1167 |
+
- score_threshold: Minimum score (default: 0.5)
|
| 1168 |
+
|
| 1169 |
+
Returns:
|
| 1170 |
+
- results: List of matching documents
|
| 1171 |
+
"""
|
| 1172 |
+
try:
|
| 1173 |
+
# Generate query embedding
|
| 1174 |
+
query_embedding = embedding_service.encode_text(query)
|
| 1175 |
+
|
| 1176 |
+
# Search in Qdrant
|
| 1177 |
+
results = qdrant_service.search(
|
| 1178 |
+
query_embedding=query_embedding,
|
| 1179 |
+
limit=top_k,
|
| 1180 |
+
score_threshold=score_threshold
|
| 1181 |
+
)
|
| 1182 |
+
|
| 1183 |
+
return [
|
| 1184 |
+
SearchResponse(
|
| 1185 |
+
id=result["id"],
|
| 1186 |
+
confidence=result["confidence"],
|
| 1187 |
+
metadata=result["metadata"]
|
| 1188 |
+
)
|
| 1189 |
+
for result in results
|
| 1190 |
+
]
|
| 1191 |
+
|
| 1192 |
+
except Exception as e:
|
| 1193 |
+
raise HTTPException(status_code=500, detail=f"Error: {str(e)}")
|
| 1194 |
+
|
| 1195 |
+
|
| 1196 |
+
@app.get("/history")
|
| 1197 |
+
async def get_history(limit: int = 10, skip: int = 0):
|
| 1198 |
+
"""
|
| 1199 |
+
Get chat history
|
| 1200 |
+
|
| 1201 |
+
Query params:
|
| 1202 |
+
- limit: Number of messages to return (default: 10)
|
| 1203 |
+
- skip: Number of messages to skip (default: 0)
|
| 1204 |
+
|
| 1205 |
+
Returns:
|
| 1206 |
+
- history: List of chat messages
|
| 1207 |
+
"""
|
| 1208 |
+
try:
|
| 1209 |
+
history = list(
|
| 1210 |
+
chat_history_collection
|
| 1211 |
+
.find({}, {"_id": 0})
|
| 1212 |
+
.sort("timestamp", -1)
|
| 1213 |
+
.skip(skip)
|
| 1214 |
+
.limit(limit)
|
| 1215 |
+
)
|
| 1216 |
+
|
| 1217 |
+
# Convert datetime to string
|
| 1218 |
+
for msg in history:
|
| 1219 |
+
if "timestamp" in msg:
|
| 1220 |
+
msg["timestamp"] = msg["timestamp"].isoformat()
|
| 1221 |
+
|
| 1222 |
+
return {
|
| 1223 |
+
"history": history,
|
| 1224 |
+
"total": chat_history_collection.count_documents({})
|
| 1225 |
+
}
|
| 1226 |
+
|
| 1227 |
+
except Exception as e:
|
| 1228 |
+
raise HTTPException(status_code=500, detail=f"Error: {str(e)}")
|
| 1229 |
+
|
| 1230 |
+
|
| 1231 |
+
@app.delete("/documents/{doc_id}")
|
| 1232 |
+
async def delete_document_from_kb(doc_id: str):
|
| 1233 |
+
"""
|
| 1234 |
+
Delete document from knowledge base
|
| 1235 |
+
|
| 1236 |
+
Args:
|
| 1237 |
+
- doc_id: Document ID (MongoDB ObjectId)
|
| 1238 |
+
|
| 1239 |
+
Returns:
|
| 1240 |
+
- success: True/False
|
| 1241 |
+
- message: Status message
|
| 1242 |
+
"""
|
| 1243 |
+
try:
|
| 1244 |
+
# Delete from MongoDB
|
| 1245 |
+
result = documents_collection.delete_one({"_id": doc_id})
|
| 1246 |
+
|
| 1247 |
+
# Delete from Qdrant
|
| 1248 |
+
if result.deleted_count > 0:
|
| 1249 |
+
qdrant_service.delete_by_id(doc_id)
|
| 1250 |
+
return {"success": True, "message": f"Document {doc_id} deleted from knowledge base"}
|
| 1251 |
+
else:
|
| 1252 |
+
raise HTTPException(status_code=404, detail=f"Document {doc_id} not found")
|
| 1253 |
+
|
| 1254 |
+
except HTTPException:
|
| 1255 |
+
raise
|
| 1256 |
+
except Exception as e:
|
| 1257 |
+
raise HTTPException(status_code=500, detail=f"Error: {str(e)}")
|
| 1258 |
+
|
| 1259 |
+
|
| 1260 |
+
# ===================================
|
| 1261 |
+
# AGENT CHAT STREAMING ENDPOINT (NEW)
|
| 1262 |
+
# ===================================
|
| 1263 |
+
|
| 1264 |
+
@app.post("/agent/chat")
|
| 1265 |
+
async def agent_chat(request: ChatRequest):
|
| 1266 |
+
"""
|
| 1267 |
+
🤖 **Agentic Chatbot với SSE Streaming**
|
| 1268 |
+
|
| 1269 |
+
**Modes:**
|
| 1270 |
+
- `sales`: Sales Agent - Tư vấn sự kiện, chốt sale
|
| 1271 |
+
- `feedback`: Feedback Agent - CSKH, thu thập đánh giá
|
| 1272 |
+
|
| 1273 |
+
**Features:**
|
| 1274 |
+
- ✅ LLM-driven conversation (no hard-coded scenarios)
|
| 1275 |
+
- ✅ Automatic tool calling (search, get_event_details, save_lead...)
|
| 1276 |
+
- ✅ Real-time SSE streaming
|
| 1277 |
+
- ✅ Purchase history check (for feedback mode)
|
| 1278 |
+
|
| 1279 |
+
**Example:**
|
| 1280 |
+
```
|
| 1281 |
+
POST /agent/chat
|
| 1282 |
+
{
|
| 1283 |
+
"message": "Tìm event cho tôi",
|
| 1284 |
+
"mode": "sales",
|
| 1285 |
+
"user_id": "user_123"
|
| 1286 |
+
}
|
| 1287 |
+
```
|
| 1288 |
+
|
| 1289 |
+
**SSE Stream:**
|
| 1290 |
+
```
|
| 1291 |
+
event: status
|
| 1292 |
+
data: Đang tư vấn...
|
| 1293 |
+
|
| 1294 |
+
event: token
|
| 1295 |
+
data: Hello
|
| 1296 |
+
|
| 1297 |
+
event: token
|
| 1298 |
+
data: 👋
|
| 1299 |
+
|
| 1300 |
+
event: done
|
| 1301 |
+
data: {"session_id": "...", "mode": "sales"}
|
| 1302 |
+
```
|
| 1303 |
+
"""
|
| 1304 |
+
return StreamingResponse(
|
| 1305 |
+
agent_chat_stream(
|
| 1306 |
+
request=request,
|
| 1307 |
+
agent_service=agent_service,
|
| 1308 |
+
conversation_service=conversation_service
|
| 1309 |
+
),
|
| 1310 |
+
media_type="text/event-stream",
|
| 1311 |
+
headers={
|
| 1312 |
+
"Cache-Control": "no-cache",
|
| 1313 |
+
"Connection": "keep-alive",
|
| 1314 |
+
"X-Accel-Buffering": "no"
|
| 1315 |
+
}
|
| 1316 |
+
)
|
| 1317 |
+
|
| 1318 |
+
|
| 1319 |
+
if __name__ == "__main__":
|
| 1320 |
+
import uvicorn
|
| 1321 |
+
uvicorn.run(
|
| 1322 |
+
app,
|
| 1323 |
+
host="0.0.0.0",
|
| 1324 |
+
port=8000,
|
| 1325 |
+
log_level="info"
|
| 1326 |
+
)
|
multimodal_pdf_parser.py
ADDED
|
@@ -0,0 +1,390 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Enhanced Multimodal PDF Parser for PDFs with Text + Image URLs
|
| 3 |
+
Extracts text, detects image URLs, and links them together
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import pypdfium2 as pdfium
|
| 7 |
+
from typing import List, Dict, Optional, Tuple
|
| 8 |
+
import re
|
| 9 |
+
from dataclasses import dataclass, field
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass
|
| 13 |
+
class MultimodalChunk:
|
| 14 |
+
"""Represents a chunk with text and associated images"""
|
| 15 |
+
text: str
|
| 16 |
+
page_number: int
|
| 17 |
+
chunk_index: int
|
| 18 |
+
image_urls: List[str] = field(default_factory=list)
|
| 19 |
+
metadata: Dict = field(default_factory=dict)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class MultimodalPDFParser:
|
| 23 |
+
"""
|
| 24 |
+
Enhanced PDF Parser that extracts text and image URLs
|
| 25 |
+
Perfect for user guides with screenshots and visual instructions
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
def __init__(
|
| 29 |
+
self,
|
| 30 |
+
chunk_size: int = 500,
|
| 31 |
+
chunk_overlap: int = 50,
|
| 32 |
+
min_chunk_size: int = 50,
|
| 33 |
+
extract_images: bool = True
|
| 34 |
+
):
|
| 35 |
+
self.chunk_size = chunk_size
|
| 36 |
+
self.chunk_overlap = chunk_overlap
|
| 37 |
+
self.min_chunk_size = min_chunk_size
|
| 38 |
+
self.extract_images = extract_images
|
| 39 |
+
|
| 40 |
+
# URL patterns
|
| 41 |
+
self.url_patterns = [
|
| 42 |
+
# Standard URLs
|
| 43 |
+
r'https?://[^\s<>"{}|\\^`\[\]]+',
|
| 44 |
+
# Markdown images: 
|
| 45 |
+
r'!\[.*?\]\((https?://[^\s)]+)\)',
|
| 46 |
+
# HTML images: <img src="url">
|
| 47 |
+
r'<img[^>]+src=["\']([^"\']+)["\']',
|
| 48 |
+
# Direct image extensions
|
| 49 |
+
r'https?://[^\s<>"{}|\\^`\[\]]+\.(?:jpg|jpeg|png|gif|bmp|svg|webp)',
|
| 50 |
+
]
|
| 51 |
+
|
| 52 |
+
def extract_image_urls(self, text: str) -> List[str]:
|
| 53 |
+
"""
|
| 54 |
+
Extract all image URLs from text
|
| 55 |
+
|
| 56 |
+
Args:
|
| 57 |
+
text: Text content
|
| 58 |
+
|
| 59 |
+
Returns:
|
| 60 |
+
List of image URLs found
|
| 61 |
+
"""
|
| 62 |
+
urls = []
|
| 63 |
+
|
| 64 |
+
for pattern in self.url_patterns:
|
| 65 |
+
matches = re.findall(pattern, text, re.IGNORECASE)
|
| 66 |
+
urls.extend(matches)
|
| 67 |
+
|
| 68 |
+
# Remove duplicates while preserving order
|
| 69 |
+
seen = set()
|
| 70 |
+
unique_urls = []
|
| 71 |
+
for url in urls:
|
| 72 |
+
if url not in seen:
|
| 73 |
+
seen.add(url)
|
| 74 |
+
unique_urls.append(url)
|
| 75 |
+
|
| 76 |
+
return unique_urls
|
| 77 |
+
|
| 78 |
+
def extract_text_from_pdf(self, pdf_path: str) -> Dict[int, Tuple[str, List[str]]]:
|
| 79 |
+
"""
|
| 80 |
+
Extract text and image URLs from PDF
|
| 81 |
+
|
| 82 |
+
Args:
|
| 83 |
+
pdf_path: Path to PDF file
|
| 84 |
+
|
| 85 |
+
Returns:
|
| 86 |
+
Dictionary mapping page number to (text, image_urls) tuple
|
| 87 |
+
"""
|
| 88 |
+
pdf_pages = {}
|
| 89 |
+
|
| 90 |
+
try:
|
| 91 |
+
pdf = pdfium.PdfDocument(pdf_path)
|
| 92 |
+
|
| 93 |
+
for page_num in range(len(pdf)):
|
| 94 |
+
page = pdf[page_num]
|
| 95 |
+
textpage = page.get_textpage()
|
| 96 |
+
text = textpage.get_text_range()
|
| 97 |
+
|
| 98 |
+
# Clean text
|
| 99 |
+
text = self._clean_text(text)
|
| 100 |
+
|
| 101 |
+
# Extract image URLs if enabled
|
| 102 |
+
image_urls = []
|
| 103 |
+
if self.extract_images:
|
| 104 |
+
image_urls = self.extract_image_urls(text)
|
| 105 |
+
|
| 106 |
+
pdf_pages[page_num + 1] = (text, image_urls)
|
| 107 |
+
|
| 108 |
+
return pdf_pages
|
| 109 |
+
|
| 110 |
+
except Exception as e:
|
| 111 |
+
raise Exception(f"Error reading PDF: {str(e)}")
|
| 112 |
+
|
| 113 |
+
def _clean_text(self, text: str) -> str:
|
| 114 |
+
"""Clean extracted text"""
|
| 115 |
+
# Remove excessive whitespace
|
| 116 |
+
text = re.sub(r'\s+', ' ', text)
|
| 117 |
+
# Remove special characters
|
| 118 |
+
text = text.replace('\x00', '')
|
| 119 |
+
return text.strip()
|
| 120 |
+
|
| 121 |
+
def chunk_text_with_images(
|
| 122 |
+
self,
|
| 123 |
+
text: str,
|
| 124 |
+
image_urls: List[str],
|
| 125 |
+
page_number: int
|
| 126 |
+
) -> List[MultimodalChunk]:
|
| 127 |
+
"""
|
| 128 |
+
Split text into chunks and associate images with relevant chunks
|
| 129 |
+
|
| 130 |
+
Args:
|
| 131 |
+
text: Text to chunk
|
| 132 |
+
image_urls: Image URLs from the page
|
| 133 |
+
page_number: Page number
|
| 134 |
+
|
| 135 |
+
Returns:
|
| 136 |
+
List of MultimodalChunk objects
|
| 137 |
+
"""
|
| 138 |
+
# Split into words
|
| 139 |
+
words = text.split()
|
| 140 |
+
|
| 141 |
+
if len(words) < self.min_chunk_size:
|
| 142 |
+
if len(words) > 0:
|
| 143 |
+
return [MultimodalChunk(
|
| 144 |
+
text=text,
|
| 145 |
+
page_number=page_number,
|
| 146 |
+
chunk_index=0,
|
| 147 |
+
image_urls=image_urls, # All images go to single chunk
|
| 148 |
+
metadata={'page': page_number, 'chunk': 0}
|
| 149 |
+
)]
|
| 150 |
+
return []
|
| 151 |
+
|
| 152 |
+
chunks = []
|
| 153 |
+
chunk_index = 0
|
| 154 |
+
start = 0
|
| 155 |
+
|
| 156 |
+
# Calculate how to distribute images across chunks
|
| 157 |
+
images_per_chunk = len(image_urls) // max(1, len(words) // self.chunk_size) if image_urls else 0
|
| 158 |
+
image_index = 0
|
| 159 |
+
|
| 160 |
+
while start < len(words):
|
| 161 |
+
end = min(start + self.chunk_size, len(words))
|
| 162 |
+
chunk_words = words[start:end]
|
| 163 |
+
chunk_text = ' '.join(chunk_words)
|
| 164 |
+
|
| 165 |
+
# Assign images to this chunk
|
| 166 |
+
chunk_images = []
|
| 167 |
+
if image_urls:
|
| 168 |
+
# Simple strategy: distribute images evenly
|
| 169 |
+
# or detect if URL appears in chunk text
|
| 170 |
+
for url in image_urls:
|
| 171 |
+
if url in chunk_text:
|
| 172 |
+
chunk_images.append(url)
|
| 173 |
+
|
| 174 |
+
# If no URLs found in text, distribute evenly
|
| 175 |
+
if not chunk_images and image_index < len(image_urls):
|
| 176 |
+
# Assign remaining images to chunks
|
| 177 |
+
num_imgs = min(images_per_chunk + 1, len(image_urls) - image_index)
|
| 178 |
+
chunk_images = image_urls[image_index:image_index + num_imgs]
|
| 179 |
+
image_index += num_imgs
|
| 180 |
+
|
| 181 |
+
chunks.append(MultimodalChunk(
|
| 182 |
+
text=chunk_text,
|
| 183 |
+
page_number=page_number,
|
| 184 |
+
chunk_index=chunk_index,
|
| 185 |
+
image_urls=chunk_images,
|
| 186 |
+
metadata={
|
| 187 |
+
'page': page_number,
|
| 188 |
+
'chunk': chunk_index,
|
| 189 |
+
'start_word': start,
|
| 190 |
+
'end_word': end,
|
| 191 |
+
'has_images': len(chunk_images) > 0,
|
| 192 |
+
'num_images': len(chunk_images)
|
| 193 |
+
}
|
| 194 |
+
))
|
| 195 |
+
|
| 196 |
+
chunk_index += 1
|
| 197 |
+
start = end - self.chunk_overlap
|
| 198 |
+
|
| 199 |
+
if start >= len(words) - self.min_chunk_size:
|
| 200 |
+
break
|
| 201 |
+
|
| 202 |
+
return chunks
|
| 203 |
+
|
| 204 |
+
def parse_pdf(
|
| 205 |
+
self,
|
| 206 |
+
pdf_path: str,
|
| 207 |
+
document_metadata: Optional[Dict] = None
|
| 208 |
+
) -> List[MultimodalChunk]:
|
| 209 |
+
"""
|
| 210 |
+
Parse PDF into multimodal chunks
|
| 211 |
+
|
| 212 |
+
Args:
|
| 213 |
+
pdf_path: Path to PDF file
|
| 214 |
+
document_metadata: Additional metadata
|
| 215 |
+
|
| 216 |
+
Returns:
|
| 217 |
+
List of MultimodalChunk objects
|
| 218 |
+
"""
|
| 219 |
+
pages_data = self.extract_text_from_pdf(pdf_path)
|
| 220 |
+
|
| 221 |
+
all_chunks = []
|
| 222 |
+
for page_num, (text, image_urls) in pages_data.items():
|
| 223 |
+
chunks = self.chunk_text_with_images(text, image_urls, page_num)
|
| 224 |
+
|
| 225 |
+
# Add document metadata
|
| 226 |
+
if document_metadata:
|
| 227 |
+
for chunk in chunks:
|
| 228 |
+
chunk.metadata.update(document_metadata)
|
| 229 |
+
|
| 230 |
+
all_chunks.extend(chunks)
|
| 231 |
+
|
| 232 |
+
return all_chunks
|
| 233 |
+
|
| 234 |
+
def parse_pdf_bytes(
|
| 235 |
+
self,
|
| 236 |
+
pdf_bytes: bytes,
|
| 237 |
+
document_metadata: Optional[Dict] = None
|
| 238 |
+
) -> List[MultimodalChunk]:
|
| 239 |
+
"""Parse PDF from bytes"""
|
| 240 |
+
import tempfile
|
| 241 |
+
import os
|
| 242 |
+
|
| 243 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp:
|
| 244 |
+
tmp.write(pdf_bytes)
|
| 245 |
+
tmp_path = tmp.name
|
| 246 |
+
|
| 247 |
+
try:
|
| 248 |
+
chunks = self.parse_pdf(tmp_path, document_metadata)
|
| 249 |
+
return chunks
|
| 250 |
+
finally:
|
| 251 |
+
if os.path.exists(tmp_path):
|
| 252 |
+
os.unlink(tmp_path)
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
class MultimodalPDFIndexer:
|
| 256 |
+
"""Index multimodal PDF chunks into RAG system"""
|
| 257 |
+
|
| 258 |
+
def __init__(self, embedding_service, qdrant_service, documents_collection):
|
| 259 |
+
self.embedding_service = embedding_service
|
| 260 |
+
self.qdrant_service = qdrant_service
|
| 261 |
+
self.documents_collection = documents_collection
|
| 262 |
+
self.parser = MultimodalPDFParser()
|
| 263 |
+
|
| 264 |
+
def index_pdf(
|
| 265 |
+
self,
|
| 266 |
+
pdf_path: str,
|
| 267 |
+
document_id: str,
|
| 268 |
+
document_metadata: Optional[Dict] = None
|
| 269 |
+
) -> Dict:
|
| 270 |
+
"""Index PDF with image URLs"""
|
| 271 |
+
chunks = self.parser.parse_pdf(pdf_path, document_metadata)
|
| 272 |
+
|
| 273 |
+
indexed_count = 0
|
| 274 |
+
chunk_ids = []
|
| 275 |
+
total_images = 0
|
| 276 |
+
|
| 277 |
+
for chunk in chunks:
|
| 278 |
+
chunk_id = f"{document_id}_p{chunk.page_number}_c{chunk.chunk_index}"
|
| 279 |
+
|
| 280 |
+
# Generate embedding (text-based)
|
| 281 |
+
embedding = self.embedding_service.encode_text(chunk.text)
|
| 282 |
+
|
| 283 |
+
# Prepare metadata with image URLs
|
| 284 |
+
metadata = {
|
| 285 |
+
'text': chunk.text,
|
| 286 |
+
'document_id': document_id,
|
| 287 |
+
'page': chunk.page_number,
|
| 288 |
+
'chunk_index': chunk.chunk_index,
|
| 289 |
+
'source': 'pdf',
|
| 290 |
+
'has_images': len(chunk.image_urls) > 0,
|
| 291 |
+
'image_urls': chunk.image_urls, # Store image URLs!
|
| 292 |
+
'num_images': len(chunk.image_urls),
|
| 293 |
+
**chunk.metadata
|
| 294 |
+
}
|
| 295 |
+
|
| 296 |
+
# Index to Qdrant
|
| 297 |
+
self.qdrant_service.index_data(
|
| 298 |
+
doc_id=chunk_id,
|
| 299 |
+
embedding=embedding,
|
| 300 |
+
metadata=metadata
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
chunk_ids.append(chunk_id)
|
| 304 |
+
indexed_count += 1
|
| 305 |
+
total_images += len(chunk.image_urls)
|
| 306 |
+
|
| 307 |
+
# Save document info
|
| 308 |
+
doc_info = {
|
| 309 |
+
'document_id': document_id,
|
| 310 |
+
'type': 'multimodal_pdf',
|
| 311 |
+
'file_path': pdf_path,
|
| 312 |
+
'num_chunks': indexed_count,
|
| 313 |
+
'total_images': total_images,
|
| 314 |
+
'chunk_ids': chunk_ids,
|
| 315 |
+
'metadata': document_metadata or {}
|
| 316 |
+
}
|
| 317 |
+
self.documents_collection.insert_one(doc_info)
|
| 318 |
+
|
| 319 |
+
return {
|
| 320 |
+
'success': True,
|
| 321 |
+
'document_id': document_id,
|
| 322 |
+
'chunks_indexed': indexed_count,
|
| 323 |
+
'images_found': total_images,
|
| 324 |
+
'chunk_ids': chunk_ids[:5]
|
| 325 |
+
}
|
| 326 |
+
|
| 327 |
+
def index_pdf_bytes(
|
| 328 |
+
self,
|
| 329 |
+
pdf_bytes: bytes,
|
| 330 |
+
document_id: str,
|
| 331 |
+
filename: str,
|
| 332 |
+
document_metadata: Optional[Dict] = None
|
| 333 |
+
) -> Dict:
|
| 334 |
+
"""Index PDF from bytes"""
|
| 335 |
+
metadata = document_metadata or {}
|
| 336 |
+
metadata['filename'] = filename
|
| 337 |
+
|
| 338 |
+
chunks = self.parser.parse_pdf_bytes(pdf_bytes, metadata)
|
| 339 |
+
|
| 340 |
+
indexed_count = 0
|
| 341 |
+
chunk_ids = []
|
| 342 |
+
total_images = 0
|
| 343 |
+
|
| 344 |
+
for chunk in chunks:
|
| 345 |
+
chunk_id = f"{document_id}_p{chunk.page_number}_c{chunk.chunk_index}"
|
| 346 |
+
|
| 347 |
+
embedding = self.embedding_service.encode_text(chunk.text)
|
| 348 |
+
|
| 349 |
+
metadata = {
|
| 350 |
+
'text': chunk.text,
|
| 351 |
+
'document_id': document_id,
|
| 352 |
+
'page': chunk.page_number,
|
| 353 |
+
'chunk_index': chunk.chunk_index,
|
| 354 |
+
'source': 'multimodal_pdf',
|
| 355 |
+
'filename': filename,
|
| 356 |
+
'has_images': len(chunk.image_urls) > 0,
|
| 357 |
+
'image_urls': chunk.image_urls,
|
| 358 |
+
'num_images': len(chunk.image_urls),
|
| 359 |
+
**chunk.metadata
|
| 360 |
+
}
|
| 361 |
+
|
| 362 |
+
self.qdrant_service.index_data(
|
| 363 |
+
doc_id=chunk_id,
|
| 364 |
+
embedding=embedding,
|
| 365 |
+
metadata=metadata
|
| 366 |
+
)
|
| 367 |
+
|
| 368 |
+
chunk_ids.append(chunk_id)
|
| 369 |
+
indexed_count += 1
|
| 370 |
+
total_images += len(chunk.image_urls)
|
| 371 |
+
|
| 372 |
+
doc_info = {
|
| 373 |
+
'document_id': document_id,
|
| 374 |
+
'type': 'multimodal_pdf',
|
| 375 |
+
'filename': filename,
|
| 376 |
+
'num_chunks': indexed_count,
|
| 377 |
+
'total_images': total_images,
|
| 378 |
+
'chunk_ids': chunk_ids,
|
| 379 |
+
'metadata': metadata
|
| 380 |
+
}
|
| 381 |
+
self.documents_collection.insert_one(doc_info)
|
| 382 |
+
|
| 383 |
+
return {
|
| 384 |
+
'success': True,
|
| 385 |
+
'document_id': document_id,
|
| 386 |
+
'filename': filename,
|
| 387 |
+
'chunks_indexed': indexed_count,
|
| 388 |
+
'images_found': total_images,
|
| 389 |
+
'chunk_ids': chunk_ids[:5]
|
| 390 |
+
}
|
pdf_parser.py
ADDED
|
@@ -0,0 +1,371 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PDF Parser Service for RAG Chatbot
|
| 3 |
+
Extracts text from PDF and splits into chunks for indexing
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import pypdfium2 as pdfium
|
| 7 |
+
from typing import List, Dict, Optional
|
| 8 |
+
import re
|
| 9 |
+
from dataclasses import dataclass
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass
|
| 13 |
+
class PDFChunk:
|
| 14 |
+
"""Represents a chunk of text from PDF"""
|
| 15 |
+
text: str
|
| 16 |
+
page_number: int
|
| 17 |
+
chunk_index: int
|
| 18 |
+
metadata: Dict
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class PDFParser:
|
| 22 |
+
"""Parse PDF files and prepare for RAG indexing"""
|
| 23 |
+
|
| 24 |
+
def __init__(
|
| 25 |
+
self,
|
| 26 |
+
chunk_size: int = 500, # words per chunk
|
| 27 |
+
chunk_overlap: int = 50, # words overlap between chunks
|
| 28 |
+
min_chunk_size: int = 50 # minimum words in a chunk
|
| 29 |
+
):
|
| 30 |
+
self.chunk_size = chunk_size
|
| 31 |
+
self.chunk_overlap = chunk_overlap
|
| 32 |
+
self.min_chunk_size = min_chunk_size
|
| 33 |
+
|
| 34 |
+
def extract_text_from_pdf(self, pdf_path: str) -> Dict[int, str]:
|
| 35 |
+
"""
|
| 36 |
+
Extract text from PDF file
|
| 37 |
+
|
| 38 |
+
Args:
|
| 39 |
+
pdf_path: Path to PDF file
|
| 40 |
+
|
| 41 |
+
Returns:
|
| 42 |
+
Dictionary mapping page number to text content
|
| 43 |
+
"""
|
| 44 |
+
pdf_text = {}
|
| 45 |
+
|
| 46 |
+
try:
|
| 47 |
+
pdf = pdfium.PdfDocument(pdf_path)
|
| 48 |
+
|
| 49 |
+
for page_num in range(len(pdf)):
|
| 50 |
+
page = pdf[page_num]
|
| 51 |
+
textpage = page.get_textpage()
|
| 52 |
+
text = textpage.get_text_range()
|
| 53 |
+
|
| 54 |
+
# Clean text
|
| 55 |
+
text = self._clean_text(text)
|
| 56 |
+
pdf_text[page_num + 1] = text # 1-indexed pages
|
| 57 |
+
|
| 58 |
+
return pdf_text
|
| 59 |
+
|
| 60 |
+
except Exception as e:
|
| 61 |
+
raise Exception(f"Error reading PDF: {str(e)}")
|
| 62 |
+
|
| 63 |
+
def _clean_text(self, text: str) -> str:
|
| 64 |
+
"""Clean extracted text"""
|
| 65 |
+
# Remove excessive whitespace
|
| 66 |
+
text = re.sub(r'\s+', ' ', text)
|
| 67 |
+
|
| 68 |
+
# Remove special characters that might cause issues
|
| 69 |
+
text = text.replace('\x00', '')
|
| 70 |
+
|
| 71 |
+
return text.strip()
|
| 72 |
+
|
| 73 |
+
def chunk_text(self, text: str, page_number: int) -> List[PDFChunk]:
|
| 74 |
+
"""
|
| 75 |
+
Split text into overlapping chunks
|
| 76 |
+
|
| 77 |
+
Args:
|
| 78 |
+
text: Text to chunk
|
| 79 |
+
page_number: Page number this text came from
|
| 80 |
+
|
| 81 |
+
Returns:
|
| 82 |
+
List of PDFChunk objects
|
| 83 |
+
"""
|
| 84 |
+
# Split into words
|
| 85 |
+
words = text.split()
|
| 86 |
+
|
| 87 |
+
if len(words) < self.min_chunk_size:
|
| 88 |
+
# Text too short, return as single chunk
|
| 89 |
+
if len(words) > 0:
|
| 90 |
+
return [PDFChunk(
|
| 91 |
+
text=text,
|
| 92 |
+
page_number=page_number,
|
| 93 |
+
chunk_index=0,
|
| 94 |
+
metadata={'page': page_number, 'chunk': 0}
|
| 95 |
+
)]
|
| 96 |
+
return []
|
| 97 |
+
|
| 98 |
+
chunks = []
|
| 99 |
+
chunk_index = 0
|
| 100 |
+
start = 0
|
| 101 |
+
|
| 102 |
+
while start < len(words):
|
| 103 |
+
# Get chunk
|
| 104 |
+
end = min(start + self.chunk_size, len(words))
|
| 105 |
+
chunk_words = words[start:end]
|
| 106 |
+
chunk_text = ' '.join(chunk_words)
|
| 107 |
+
|
| 108 |
+
chunks.append(PDFChunk(
|
| 109 |
+
text=chunk_text,
|
| 110 |
+
page_number=page_number,
|
| 111 |
+
chunk_index=chunk_index,
|
| 112 |
+
metadata={
|
| 113 |
+
'page': page_number,
|
| 114 |
+
'chunk': chunk_index,
|
| 115 |
+
'start_word': start,
|
| 116 |
+
'end_word': end
|
| 117 |
+
}
|
| 118 |
+
))
|
| 119 |
+
|
| 120 |
+
chunk_index += 1
|
| 121 |
+
|
| 122 |
+
# Move start position with overlap
|
| 123 |
+
start = end - self.chunk_overlap
|
| 124 |
+
|
| 125 |
+
# Avoid infinite loop
|
| 126 |
+
if start >= len(words) - self.min_chunk_size:
|
| 127 |
+
break
|
| 128 |
+
|
| 129 |
+
return chunks
|
| 130 |
+
|
| 131 |
+
def parse_pdf(
|
| 132 |
+
self,
|
| 133 |
+
pdf_path: str,
|
| 134 |
+
document_metadata: Optional[Dict] = None
|
| 135 |
+
) -> List[PDFChunk]:
|
| 136 |
+
"""
|
| 137 |
+
Parse entire PDF into chunks
|
| 138 |
+
|
| 139 |
+
Args:
|
| 140 |
+
pdf_path: Path to PDF file
|
| 141 |
+
document_metadata: Additional metadata for the document
|
| 142 |
+
|
| 143 |
+
Returns:
|
| 144 |
+
List of all chunks from the PDF
|
| 145 |
+
"""
|
| 146 |
+
# Extract text from all pages
|
| 147 |
+
pages_text = self.extract_text_from_pdf(pdf_path)
|
| 148 |
+
|
| 149 |
+
# Chunk each page
|
| 150 |
+
all_chunks = []
|
| 151 |
+
for page_num, text in pages_text.items():
|
| 152 |
+
chunks = self.chunk_text(text, page_num)
|
| 153 |
+
|
| 154 |
+
# Add document metadata
|
| 155 |
+
if document_metadata:
|
| 156 |
+
for chunk in chunks:
|
| 157 |
+
chunk.metadata.update(document_metadata)
|
| 158 |
+
|
| 159 |
+
all_chunks.extend(chunks)
|
| 160 |
+
|
| 161 |
+
return all_chunks
|
| 162 |
+
|
| 163 |
+
def parse_pdf_bytes(
|
| 164 |
+
self,
|
| 165 |
+
pdf_bytes: bytes,
|
| 166 |
+
document_metadata: Optional[Dict] = None
|
| 167 |
+
) -> List[PDFChunk]:
|
| 168 |
+
"""
|
| 169 |
+
Parse PDF from bytes (for uploaded files)
|
| 170 |
+
|
| 171 |
+
Args:
|
| 172 |
+
pdf_bytes: PDF file as bytes
|
| 173 |
+
document_metadata: Additional metadata
|
| 174 |
+
|
| 175 |
+
Returns:
|
| 176 |
+
List of chunks
|
| 177 |
+
"""
|
| 178 |
+
import tempfile
|
| 179 |
+
import os
|
| 180 |
+
|
| 181 |
+
# Save to temp file
|
| 182 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp:
|
| 183 |
+
tmp.write(pdf_bytes)
|
| 184 |
+
tmp_path = tmp.name
|
| 185 |
+
|
| 186 |
+
try:
|
| 187 |
+
chunks = self.parse_pdf(tmp_path, document_metadata)
|
| 188 |
+
return chunks
|
| 189 |
+
finally:
|
| 190 |
+
# Clean up temp file
|
| 191 |
+
if os.path.exists(tmp_path):
|
| 192 |
+
os.unlink(tmp_path)
|
| 193 |
+
|
| 194 |
+
def get_pdf_info(self, pdf_path: str) -> Dict:
|
| 195 |
+
"""
|
| 196 |
+
Get basic info about PDF
|
| 197 |
+
|
| 198 |
+
Args:
|
| 199 |
+
pdf_path: Path to PDF file
|
| 200 |
+
|
| 201 |
+
Returns:
|
| 202 |
+
Dictionary with PDF information
|
| 203 |
+
"""
|
| 204 |
+
try:
|
| 205 |
+
pdf = pdfium.PdfDocument(pdf_path)
|
| 206 |
+
|
| 207 |
+
info = {
|
| 208 |
+
'num_pages': len(pdf),
|
| 209 |
+
'file_path': pdf_path,
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
+
return info
|
| 213 |
+
|
| 214 |
+
except Exception as e:
|
| 215 |
+
raise Exception(f"Error reading PDF info: {str(e)}")
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
class PDFIndexer:
|
| 219 |
+
"""Index PDF chunks into RAG system"""
|
| 220 |
+
|
| 221 |
+
def __init__(self, embedding_service, qdrant_service, documents_collection):
|
| 222 |
+
self.embedding_service = embedding_service
|
| 223 |
+
self.qdrant_service = qdrant_service
|
| 224 |
+
self.documents_collection = documents_collection
|
| 225 |
+
self.parser = PDFParser()
|
| 226 |
+
|
| 227 |
+
def index_pdf(
|
| 228 |
+
self,
|
| 229 |
+
pdf_path: str,
|
| 230 |
+
document_id: str,
|
| 231 |
+
document_metadata: Optional[Dict] = None
|
| 232 |
+
) -> Dict:
|
| 233 |
+
"""
|
| 234 |
+
Index entire PDF into RAG system
|
| 235 |
+
|
| 236 |
+
Args:
|
| 237 |
+
pdf_path: Path to PDF file
|
| 238 |
+
document_id: Unique ID for this document
|
| 239 |
+
document_metadata: Additional metadata (title, author, etc.)
|
| 240 |
+
|
| 241 |
+
Returns:
|
| 242 |
+
Indexing results
|
| 243 |
+
"""
|
| 244 |
+
# Parse PDF
|
| 245 |
+
chunks = self.parser.parse_pdf(pdf_path, document_metadata)
|
| 246 |
+
|
| 247 |
+
# Index each chunk
|
| 248 |
+
indexed_count = 0
|
| 249 |
+
chunk_ids = []
|
| 250 |
+
|
| 251 |
+
for chunk in chunks:
|
| 252 |
+
# Generate unique ID for chunk
|
| 253 |
+
chunk_id = f"{document_id}_p{chunk.page_number}_c{chunk.chunk_index}"
|
| 254 |
+
|
| 255 |
+
# Generate embedding
|
| 256 |
+
embedding = self.embedding_service.encode_text(chunk.text)
|
| 257 |
+
|
| 258 |
+
# Prepare metadata
|
| 259 |
+
metadata = {
|
| 260 |
+
'text': chunk.text,
|
| 261 |
+
'document_id': document_id,
|
| 262 |
+
'page': chunk.page_number,
|
| 263 |
+
'chunk_index': chunk.chunk_index,
|
| 264 |
+
'source': 'pdf',
|
| 265 |
+
**chunk.metadata
|
| 266 |
+
}
|
| 267 |
+
|
| 268 |
+
# Index to Qdrant
|
| 269 |
+
self.qdrant_service.index_data(
|
| 270 |
+
doc_id=chunk_id,
|
| 271 |
+
embedding=embedding,
|
| 272 |
+
metadata=metadata
|
| 273 |
+
)
|
| 274 |
+
|
| 275 |
+
chunk_ids.append(chunk_id)
|
| 276 |
+
indexed_count += 1
|
| 277 |
+
|
| 278 |
+
# Save document info to MongoDB
|
| 279 |
+
doc_info = {
|
| 280 |
+
'document_id': document_id,
|
| 281 |
+
'type': 'pdf',
|
| 282 |
+
'file_path': pdf_path,
|
| 283 |
+
'num_chunks': indexed_count,
|
| 284 |
+
'chunk_ids': chunk_ids,
|
| 285 |
+
'metadata': document_metadata or {},
|
| 286 |
+
'pdf_info': self.parser.get_pdf_info(pdf_path)
|
| 287 |
+
}
|
| 288 |
+
self.documents_collection.insert_one(doc_info)
|
| 289 |
+
|
| 290 |
+
return {
|
| 291 |
+
'success': True,
|
| 292 |
+
'document_id': document_id,
|
| 293 |
+
'chunks_indexed': indexed_count,
|
| 294 |
+
'chunk_ids': chunk_ids[:5] # Return first 5 as sample
|
| 295 |
+
}
|
| 296 |
+
|
| 297 |
+
def index_pdf_bytes(
|
| 298 |
+
self,
|
| 299 |
+
pdf_bytes: bytes,
|
| 300 |
+
document_id: str,
|
| 301 |
+
filename: str,
|
| 302 |
+
document_metadata: Optional[Dict] = None
|
| 303 |
+
) -> Dict:
|
| 304 |
+
"""
|
| 305 |
+
Index PDF from bytes (for uploaded files)
|
| 306 |
+
|
| 307 |
+
Args:
|
| 308 |
+
pdf_bytes: PDF file as bytes
|
| 309 |
+
document_id: Unique ID for this document
|
| 310 |
+
filename: Original filename
|
| 311 |
+
document_metadata: Additional metadata
|
| 312 |
+
|
| 313 |
+
Returns:
|
| 314 |
+
Indexing results
|
| 315 |
+
"""
|
| 316 |
+
# Parse PDF
|
| 317 |
+
metadata = document_metadata or {}
|
| 318 |
+
metadata['filename'] = filename
|
| 319 |
+
|
| 320 |
+
chunks = self.parser.parse_pdf_bytes(pdf_bytes, metadata)
|
| 321 |
+
|
| 322 |
+
# Index each chunk
|
| 323 |
+
indexed_count = 0
|
| 324 |
+
chunk_ids = []
|
| 325 |
+
|
| 326 |
+
for chunk in chunks:
|
| 327 |
+
# Generate unique ID for chunk
|
| 328 |
+
chunk_id = f"{document_id}_p{chunk.page_number}_c{chunk.chunk_index}"
|
| 329 |
+
|
| 330 |
+
# Generate embedding
|
| 331 |
+
embedding = self.embedding_service.encode_text(chunk.text)
|
| 332 |
+
|
| 333 |
+
# Prepare metadata
|
| 334 |
+
metadata = {
|
| 335 |
+
'text': chunk.text,
|
| 336 |
+
'document_id': document_id,
|
| 337 |
+
'page': chunk.page_number,
|
| 338 |
+
'chunk_index': chunk.chunk_index,
|
| 339 |
+
'source': 'pdf',
|
| 340 |
+
'filename': filename,
|
| 341 |
+
**chunk.metadata
|
| 342 |
+
}
|
| 343 |
+
|
| 344 |
+
# Index to Qdrant
|
| 345 |
+
self.qdrant_service.index_data(
|
| 346 |
+
doc_id=chunk_id,
|
| 347 |
+
embedding=embedding,
|
| 348 |
+
metadata=metadata
|
| 349 |
+
)
|
| 350 |
+
|
| 351 |
+
chunk_ids.append(chunk_id)
|
| 352 |
+
indexed_count += 1
|
| 353 |
+
|
| 354 |
+
# Save document info to MongoDB
|
| 355 |
+
doc_info = {
|
| 356 |
+
'document_id': document_id,
|
| 357 |
+
'type': 'pdf',
|
| 358 |
+
'filename': filename,
|
| 359 |
+
'num_chunks': indexed_count,
|
| 360 |
+
'chunk_ids': chunk_ids,
|
| 361 |
+
'metadata': metadata
|
| 362 |
+
}
|
| 363 |
+
self.documents_collection.insert_one(doc_info)
|
| 364 |
+
|
| 365 |
+
return {
|
| 366 |
+
'success': True,
|
| 367 |
+
'document_id': document_id,
|
| 368 |
+
'filename': filename,
|
| 369 |
+
'chunks_indexed': indexed_count,
|
| 370 |
+
'chunk_ids': chunk_ids[:5]
|
| 371 |
+
}
|
prompts/feedback_agent.txt
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ROLE
|
| 2 |
+
Bạn là chuyên viên Chăm sóc khách hàng (CSKH) của nền tảng bán vé sự kiện.
|
| 3 |
+
Nhiệm vụ của bạn là lắng nghe phản hồi của khách hàng sau khi tham gia sự kiện và hỗ trợ họ.
|
| 4 |
+
|
| 5 |
+
# GOAL
|
| 6 |
+
1. Kiểm tra xem khách hàng đã tham gia sự kiện nào chưa.
|
| 7 |
+
2. Nếu CÓ: Xin đánh giá (feedback), cảm nhận để cải thiện dịch vụ.
|
| 8 |
+
3. Nếu KHÔNG (hoặc đã feedback xong): Giới thiệu các sự kiện mới hấp dẫn (chuyển sang vai trò Sales).
|
| 9 |
+
|
| 10 |
+
# CAPABILITIES (TOOLS)
|
| 11 |
+
1. `get_purchased_events(user_id)`: Kiểm tra lịch sử mua vé/tham gia sự kiện của khách hàng.
|
| 12 |
+
2. `save_feedback(event_id, rating, comment)`: Lưu đánh giá của khách hàng (rating 1-5 sao).
|
| 13 |
+
3. `search_events(...)`: Tìm sự kiện mới (nếu khách muốn đi tiếp).
|
| 14 |
+
|
| 15 |
+
# GUIDELINES
|
| 16 |
+
|
| 17 |
+
## Phase 1: Check History (Luôn thực hiện đầu tiên)
|
| 18 |
+
- Ngay khi bắt đầu hội thoại, hãy gọi `get_purchased_events(user_id)` ngầm (không cần hỏi khách).
|
| 19 |
+
- **Trường hợp A: Khách chưa từng đi sự kiện nào (hoặc API trả về rỗng)**
|
| 20 |
+
- Chuyển ngay sang mode tư vấn: "Chào bạn! Bạn đang tìm kiếm sự kiện gì thú vị cho tuần này không? Bên mình đang có nhiều show hay lắm! 🎉"
|
| 21 |
+
- (Sau đó hành xử như Sales Agent).
|
| 22 |
+
|
| 23 |
+
- **Trường hợp B: Khách ĐÃ đi sự kiện (ví dụ: "Show Hà Anh Tuấn")**
|
| 24 |
+
- Mở đầu bằng lời chào ấm áp: "Chào bạn! Cảm ơn bạn đã tham gia show **Hà Anh Tuấn** vừa rồi. Hy vọng bạn đã có những giây phút tuyệt vời! 🥰"
|
| 25 |
+
- Hỏi thăm cảm nhận: "Bạn thấy không khí hôm đó thế nào? Có điều gì làm bạn chưa hài lòng không?"
|
| 26 |
+
|
| 27 |
+
## Phase 2: Collect Feedback (Nếu khách đã đi)
|
| 28 |
+
- Lắng nghe khách chia sẻ.
|
| 29 |
+
- Nếu khách khen: "Tuyệt quá! Bạn chấm cho sự kiện mấy sao nè? (1-5 sao) ⭐"
|
| 30 |
+
- Nếu khách chê: Tỏ ra đồng cảm, xin lỗi và hứa cải thiện. "Dạ mình rất tiếc về trải nghiệm này. Mình sẽ ghi nhận ngay để BTC rút kinh nghiệm ạ."
|
| 31 |
+
- Sau khi khách chấm điểm/comment -> Gọi `save_feedback`.
|
| 32 |
+
|
| 33 |
+
## Phase 3: Transition to Sales (Sau khi feedback xong)
|
| 34 |
+
- Sau khi đã lưu feedback, hãy khéo léo giới thiệu sự kiện mới:
|
| 35 |
+
"Cảm ơn bạn đã góp ý nha! À, sắp tới bên mình có show **Mỹ Tâm** cũng vibe tương tự, bạn có muốn xem qua không?"
|
| 36 |
+
- Nếu khách quan tâm -> Dùng `search_events` và tư vấn tiếp.
|
| 37 |
+
|
| 38 |
+
# EXAMPLES
|
| 39 |
+
|
| 40 |
+
**Case 1: Có lịch sử đi event**
|
| 41 |
+
System: (User ID 123 -> get_purchased_events -> ["Show Rock Việt"])
|
| 42 |
+
Agent: "Chào bạn! Cảm ơn bạn đã cháy hết mình tại **Show Rock Việt** hôm qua! 🤘 Bạn thấy ban nhạc diễn có sung không?"
|
| 43 |
+
User: "Sung lắm, nhưng âm thanh hơi rè."
|
| 44 |
+
Agent: "Dạ mình ghi nhận góp ý về âm thanh ạ. Cảm ơn bạn nhiều. Bạn chấm show này mấy điểm trên thang 5 sao nè?"
|
| 45 |
+
User: "4 sao thôi."
|
| 46 |
+
Agent (Call Tool): save_feedback(event_id="rock_viet", rating=4, comment="Sung nhưng âm thanh rè")
|
| 47 |
+
Agent: "Dạ mình đã lưu lại rồi ạ. À sắp tới có **RockStorm** âm thanh xịn hơn, bạn có hóng không? 🔥"
|
| 48 |
+
|
| 49 |
+
**Case 2: Không có lịch sử**
|
| 50 |
+
System: (User ID 456 -> get_purchased_events -> [])
|
| 51 |
+
Agent: "Chào bạn! 👋 Cuối tuần này bạn đã có kế hoạch đi đâu chơi chưa? Bên mình đang có mấy show Acoustic chill lắm nè!"
|
prompts/sales_agent.txt
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ROLE
|
| 2 |
+
Bạn là một chuyên viên tư vấn sự kiện (Sales Agent) nhiệt tình, am hiểu và khéo léo của nền tảng bán vé sự kiện.
|
| 3 |
+
Tên bạn là: "TicketBot" (hoặc xưng là "mình"/"tớ").
|
| 4 |
+
|
| 5 |
+
# GOAL
|
| 6 |
+
Mục tiêu của bạn là giúp khách hàng tìm được sự kiện phù hợp nhất và khuyến khích họ mua vé (hoặc để lại thông tin liên hệ).
|
| 7 |
+
|
| 8 |
+
# CAPABILITIES (TOOLS)
|
| 9 |
+
Bạn có quyền truy cập các công cụ sau (hãy sử dụng chúng khi cần thiết):
|
| 10 |
+
1. `search_events(query, vibe, date)`: Tìm kiếm sự kiện theo từ khóa, tâm trạng (chill, sôi động...), hoặc thời gian.
|
| 11 |
+
2. `get_event_details(event_id)`: Lấy thông tin chi tiết (giá vé, địa điểm, nghệ sĩ, thời gian) của một sự kiện cụ thể.
|
| 12 |
+
3. `save_lead(email, phone, interest)`: Lưu thông tin khách hàng khi họ quan tâm hoặc muốn nhận tư vấn thêm.
|
| 13 |
+
|
| 14 |
+
# GUIDELINES
|
| 15 |
+
1. **Khơi gợi nhu cầu (Consultative Selling):**
|
| 16 |
+
- Đừng chỉ hỏi "Bạn muốn gì?". Hãy hỏi mở: "Cuối tuần này bạn rảnh không? Bạn đang mood muốn 'quẩy' hay chill nhẹ nhàng?"
|
| 17 |
+
- Nếu khách chưa rõ, hãy gợi ý dựa trên các vibe phổ biến: Hài kịch, Nhạc Indie, Workshop, EDM...
|
| 18 |
+
|
| 19 |
+
2. **Tư vấn thông minh:**
|
| 20 |
+
- Khi khách hỏi giá, đừng chỉ đưa con số. Hãy kèm giá trị: "Vé hạng A giá 500k nhưng view siêu đẹp, còn hạng B 300k thì tiết kiệm hơn."
|
| 21 |
+
- Luôn đề xuất thêm (Upsell/Cross-sell) nếu phù hợp: "Đi nhóm 4 người đang có combo giảm 10% đó ạ."
|
| 22 |
+
|
| 23 |
+
3. **Sử dụng Tools khéo léo:**
|
| 24 |
+
- Khi khách hỏi "có sự kiện gì?", HÃY gọi `search_events`. Đừng tự bịa ra sự kiện.
|
| 25 |
+
- Khi trả về danh sách sự kiện, hãy tóm tắt ngắn gọn điểm hấp dẫn nhất của từng cái.
|
| 26 |
+
|
| 27 |
+
4. **Chốt Deal (Closing):**
|
| 28 |
+
- Khi khách có vẻ ưng ý (hỏi chi tiết, giá, chỗ ngồi...), hãy khéo léo xin thông tin:
|
| 29 |
+
"Sự kiện này đang hot lắm, bạn cho mình xin email để mình gửi link đặt vé giữ chỗ ngay nhé?"
|
| 30 |
+
- Hoặc: "Mình gửi lịch diễn chi tiết qua Zalo/Email cho bạn tiện xem nha?" -> Gọi `save_lead`.
|
| 31 |
+
|
| 32 |
+
5. **Tone & Voice:**
|
| 33 |
+
- Thân thiện, trẻ trung, dùng emoji tự nhiên (😄, 🎉, 🔥).
|
| 34 |
+
- Không quá cứng nhắc như robot.
|
| 35 |
+
- Nếu khách hỏi ngoài lề (off-topic), hãy trả lời ngắn gọn rồi khéo léo lái về chủ đề sự kiện.
|
| 36 |
+
|
| 37 |
+
# EXAMPLES
|
| 38 |
+
|
| 39 |
+
User: "Cuối tuần này có gì chơi không?"
|
| 40 |
+
Agent (Thought): Khách chưa nói rõ sở thích. Cần hỏi thêm vibe.
|
| 41 |
+
Agent: "Cuối tuần này Sài Gòn nhiều show hay lắm! Bạn đang mood muốn 'quẩy' hết mình hay tìm một góc chill chill nghe nhạc? 🎶"
|
| 42 |
+
|
| 43 |
+
User: "Chill thôi, nghe nhạc acoustic."
|
| 44 |
+
Agent (Thought): Gọi tool search_events(vibe="chill", category="acoustic").
|
| 45 |
+
Agent (Call Tool): search_events(vibe="chill", category="acoustic")
|
| 46 |
+
... (Tool returns events) ...
|
| 47 |
+
Agent: "À, vậy thì **Mây Lang Thang** hôm thứ 7 này là chuẩn bài! Có Lê Hiếu hát, không gian cực lãng mạn. Hoặc **Lululola** thì view hoàng hôn đỉnh chóp. Bạn thích giọng ai hơn? 🎤"
|
qdrant_service.py
ADDED
|
@@ -0,0 +1,446 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from qdrant_client import QdrantClient
|
| 2 |
+
from qdrant_client.models import (
|
| 3 |
+
Distance, VectorParams, PointStruct,
|
| 4 |
+
SearchRequest, SearchParams, HnswConfigDiff,
|
| 5 |
+
OptimizersConfigDiff, ScalarQuantization,
|
| 6 |
+
ScalarQuantizationConfig, ScalarType,
|
| 7 |
+
QuantizationSearchParams
|
| 8 |
+
)
|
| 9 |
+
from typing import List, Dict, Any, Optional
|
| 10 |
+
import numpy as np
|
| 11 |
+
import uuid
|
| 12 |
+
import os
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class QdrantVectorService:
|
| 16 |
+
"""
|
| 17 |
+
Qdrant Cloud Vector Database Service với cấu hình tối ưu
|
| 18 |
+
- HNSW algorithm với parameters mạnh mẽ nhất
|
| 19 |
+
- Scalar Quantization để tối ưu memory và speed
|
| 20 |
+
- Hỗ trợ hybrid search (text + image)
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
def __init__(
|
| 24 |
+
self,
|
| 25 |
+
url: Optional[str] = None,
|
| 26 |
+
api_key: Optional[str] = None,
|
| 27 |
+
collection_name: str = "event_social_media",
|
| 28 |
+
vector_size: int = 1024, # Jina CLIP v2 dimension
|
| 29 |
+
):
|
| 30 |
+
"""
|
| 31 |
+
Initialize Qdrant Cloud client
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
url: Qdrant Cloud URL (từ env hoặc truyền vào)
|
| 35 |
+
api_key: Qdrant API key (từ env hoặc truyền vào)
|
| 36 |
+
collection_name: Tên collection
|
| 37 |
+
vector_size: Dimension của vectors (1024 cho Jina CLIP v2)
|
| 38 |
+
"""
|
| 39 |
+
# Lấy credentials từ env nếu không truyền vào
|
| 40 |
+
self.url = url or os.getenv("QDRANT_URL")
|
| 41 |
+
self.api_key = api_key or os.getenv("QDRANT_API_KEY")
|
| 42 |
+
|
| 43 |
+
if not self.url or not self.api_key:
|
| 44 |
+
raise ValueError("Cần cung cấp QDRANT_URL và QDRANT_API_KEY (qua env hoặc params)")
|
| 45 |
+
|
| 46 |
+
print(f"Connecting to Qdrant Cloud...")
|
| 47 |
+
|
| 48 |
+
# Initialize Qdrant Cloud client
|
| 49 |
+
self.client = QdrantClient(
|
| 50 |
+
url=self.url,
|
| 51 |
+
api_key=self.api_key,
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
self.collection_name = collection_name
|
| 55 |
+
self.vector_size = vector_size
|
| 56 |
+
|
| 57 |
+
# Create collection nếu chưa tồn tại
|
| 58 |
+
self._ensure_collection()
|
| 59 |
+
|
| 60 |
+
print(f"✓ Connected to Qdrant collection: {collection_name}")
|
| 61 |
+
|
| 62 |
+
def _ensure_collection(self):
|
| 63 |
+
"""
|
| 64 |
+
Tạo collection với HNSW config tối ưu nhất
|
| 65 |
+
"""
|
| 66 |
+
# Check nếu collection đã tồn tại
|
| 67 |
+
collections = self.client.get_collections().collections
|
| 68 |
+
collection_exists = any(c.name == self.collection_name for c in collections)
|
| 69 |
+
|
| 70 |
+
if not collection_exists:
|
| 71 |
+
print(f"Creating collection {self.collection_name} with optimal HNSW config...")
|
| 72 |
+
|
| 73 |
+
self.client.create_collection(
|
| 74 |
+
collection_name=self.collection_name,
|
| 75 |
+
vectors_config=VectorParams(
|
| 76 |
+
size=self.vector_size,
|
| 77 |
+
distance=Distance.COSINE, # Cosine similarity cho embeddings
|
| 78 |
+
hnsw_config=HnswConfigDiff(
|
| 79 |
+
m=64, # Số edges per node - cao nhất cho accuracy
|
| 80 |
+
ef_construct=512, # Search range khi build index - cao cho quality
|
| 81 |
+
full_scan_threshold=10000, # Threshold để switch sang full scan
|
| 82 |
+
max_indexing_threads=0, # Auto-detect số threads
|
| 83 |
+
on_disk=False, # Keep trong RAM cho speed (nếu đủ memory)
|
| 84 |
+
)
|
| 85 |
+
),
|
| 86 |
+
optimizers_config=OptimizersConfigDiff(
|
| 87 |
+
deleted_threshold=0.2,
|
| 88 |
+
vacuum_min_vector_number=1000,
|
| 89 |
+
default_segment_number=2,
|
| 90 |
+
max_segment_size=200000,
|
| 91 |
+
memmap_threshold=50000,
|
| 92 |
+
indexing_threshold=10000,
|
| 93 |
+
flush_interval_sec=5,
|
| 94 |
+
max_optimization_threads=0, # Auto-detect
|
| 95 |
+
),
|
| 96 |
+
# Sử dụng Scalar Quantization để tối ưu memory và speed
|
| 97 |
+
quantization_config=ScalarQuantization(
|
| 98 |
+
scalar=ScalarQuantizationConfig(
|
| 99 |
+
type=ScalarType.INT8,
|
| 100 |
+
quantile=0.99,
|
| 101 |
+
always_ram=True, # Keep quantized vectors trong RAM
|
| 102 |
+
)
|
| 103 |
+
)
|
| 104 |
+
)
|
| 105 |
+
print("✓ Collection created with optimal configuration")
|
| 106 |
+
else:
|
| 107 |
+
print("✓ Collection already exists")
|
| 108 |
+
|
| 109 |
+
def _convert_to_valid_id(self, doc_id: str) -> str:
|
| 110 |
+
"""
|
| 111 |
+
Convert bất kỳ string ID nào thành UUID hợp lệ cho Qdrant
|
| 112 |
+
|
| 113 |
+
Args:
|
| 114 |
+
doc_id: Original ID (có thể là MongoDB ObjectId, string, etc.)
|
| 115 |
+
|
| 116 |
+
Returns:
|
| 117 |
+
UUID string hợp lệ
|
| 118 |
+
"""
|
| 119 |
+
if not doc_id:
|
| 120 |
+
return str(uuid.uuid4())
|
| 121 |
+
|
| 122 |
+
# Nếu đã là UUID hợp lệ, giữ nguyên
|
| 123 |
+
try:
|
| 124 |
+
uuid.UUID(doc_id)
|
| 125 |
+
return doc_id
|
| 126 |
+
except ValueError:
|
| 127 |
+
pass
|
| 128 |
+
|
| 129 |
+
# Convert string sang UUID deterministic (cùng input = cùng UUID)
|
| 130 |
+
# Sử dụng UUID v5 với namespace DNS
|
| 131 |
+
return str(uuid.uuid5(uuid.NAMESPACE_DNS, doc_id))
|
| 132 |
+
|
| 133 |
+
def index_data(
|
| 134 |
+
self,
|
| 135 |
+
doc_id: str,
|
| 136 |
+
embedding: np.ndarray,
|
| 137 |
+
metadata: Dict[str, Any]
|
| 138 |
+
) -> Dict[str, str]:
|
| 139 |
+
"""
|
| 140 |
+
Index data vào Qdrant
|
| 141 |
+
|
| 142 |
+
Args:
|
| 143 |
+
doc_id: ID của document (MongoDB ObjectId, string, etc.)
|
| 144 |
+
embedding: Vector embedding từ Jina CLIP
|
| 145 |
+
metadata: Metadata (text, image_url, event_info, etc.)
|
| 146 |
+
|
| 147 |
+
Returns:
|
| 148 |
+
Dict với original_id và qdrant_id
|
| 149 |
+
"""
|
| 150 |
+
# Convert ID thành UUID hợp lệ
|
| 151 |
+
qdrant_id = self._convert_to_valid_id(doc_id)
|
| 152 |
+
|
| 153 |
+
# Lưu original ID vào metadata
|
| 154 |
+
metadata['original_id'] = doc_id
|
| 155 |
+
|
| 156 |
+
# Ensure embedding là 1D array
|
| 157 |
+
if len(embedding.shape) > 1:
|
| 158 |
+
embedding = embedding.flatten()
|
| 159 |
+
|
| 160 |
+
# Create point
|
| 161 |
+
point = PointStruct(
|
| 162 |
+
id=qdrant_id,
|
| 163 |
+
vector=embedding.tolist(),
|
| 164 |
+
payload=metadata
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
# Upsert vào collection
|
| 168 |
+
self.client.upsert(
|
| 169 |
+
collection_name=self.collection_name,
|
| 170 |
+
points=[point]
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
return {
|
| 174 |
+
"original_id": doc_id,
|
| 175 |
+
"qdrant_id": qdrant_id
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
def batch_index(
|
| 179 |
+
self,
|
| 180 |
+
doc_ids: List[str],
|
| 181 |
+
embeddings: np.ndarray,
|
| 182 |
+
metadata_list: List[Dict[str, Any]]
|
| 183 |
+
) -> List[Dict[str, str]]:
|
| 184 |
+
"""
|
| 185 |
+
Batch index nhiều documents cùng lúc
|
| 186 |
+
|
| 187 |
+
Args:
|
| 188 |
+
doc_ids: List of document IDs (MongoDB ObjectId, string, etc.)
|
| 189 |
+
embeddings: Numpy array of embeddings (n_samples, embedding_dim)
|
| 190 |
+
metadata_list: List of metadata dicts
|
| 191 |
+
|
| 192 |
+
Returns:
|
| 193 |
+
List of dicts với original_id và qdrant_id
|
| 194 |
+
"""
|
| 195 |
+
points = []
|
| 196 |
+
id_mappings = []
|
| 197 |
+
|
| 198 |
+
for i, (doc_id, embedding, metadata) in enumerate(zip(doc_ids, embeddings, metadata_list)):
|
| 199 |
+
# Convert to valid UUID
|
| 200 |
+
qdrant_id = self._convert_to_valid_id(doc_id)
|
| 201 |
+
|
| 202 |
+
# Lưu original ID vào metadata
|
| 203 |
+
metadata['original_id'] = doc_id
|
| 204 |
+
|
| 205 |
+
# Ensure embedding là 1D
|
| 206 |
+
if len(embedding.shape) > 1:
|
| 207 |
+
embedding = embedding.flatten()
|
| 208 |
+
|
| 209 |
+
points.append(PointStruct(
|
| 210 |
+
id=qdrant_id,
|
| 211 |
+
vector=embedding.tolist(),
|
| 212 |
+
payload=metadata
|
| 213 |
+
))
|
| 214 |
+
|
| 215 |
+
id_mappings.append({
|
| 216 |
+
"original_id": doc_id,
|
| 217 |
+
"qdrant_id": qdrant_id
|
| 218 |
+
})
|
| 219 |
+
|
| 220 |
+
# Batch upsert
|
| 221 |
+
self.client.upsert(
|
| 222 |
+
collection_name=self.collection_name,
|
| 223 |
+
points=points,
|
| 224 |
+
wait=True # Wait for indexing to complete
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
return id_mappings
|
| 228 |
+
|
| 229 |
+
def search(
|
| 230 |
+
self,
|
| 231 |
+
query_embedding: np.ndarray,
|
| 232 |
+
limit: int = 10,
|
| 233 |
+
score_threshold: Optional[float] = None,
|
| 234 |
+
filter_conditions: Optional[Dict] = None,
|
| 235 |
+
ef: int = 256 # Search quality parameter - cao hơn = accurate hơn
|
| 236 |
+
) -> List[Dict[str, Any]]:
|
| 237 |
+
"""
|
| 238 |
+
Search similar vectors trong Qdrant
|
| 239 |
+
|
| 240 |
+
Args:
|
| 241 |
+
query_embedding: Query embedding từ Jina CLIP
|
| 242 |
+
limit: Số lượng results trả về
|
| 243 |
+
score_threshold: Minimum similarity score (0-1)
|
| 244 |
+
filter_conditions: Qdrant filter conditions
|
| 245 |
+
ef: HNSW search parameter (128-512, cao hơn = accurate hơn)
|
| 246 |
+
|
| 247 |
+
Returns:
|
| 248 |
+
List of search results với id, score, và metadata
|
| 249 |
+
"""
|
| 250 |
+
# Ensure query embedding là 1D
|
| 251 |
+
if len(query_embedding.shape) > 1:
|
| 252 |
+
query_embedding = query_embedding.flatten()
|
| 253 |
+
|
| 254 |
+
# Search với HNSW parameters tối ưu (qdrant-client v1.16.0+)
|
| 255 |
+
search_result = self.client.query_points(
|
| 256 |
+
collection_name=self.collection_name,
|
| 257 |
+
query=query_embedding.tolist(),
|
| 258 |
+
limit=limit,
|
| 259 |
+
score_threshold=score_threshold,
|
| 260 |
+
query_filter=filter_conditions,
|
| 261 |
+
search_params=SearchParams(
|
| 262 |
+
hnsw_ef=ef, # Higher ef = more accurate search
|
| 263 |
+
exact=False, # Use HNSW (not exact search)
|
| 264 |
+
quantization=QuantizationSearchParams(
|
| 265 |
+
ignore=False, # Use quantization
|
| 266 |
+
rescore=True, # Rescore với original vectors
|
| 267 |
+
oversampling=2.0 # Oversample factor
|
| 268 |
+
)
|
| 269 |
+
),
|
| 270 |
+
with_payload=True,
|
| 271 |
+
).points
|
| 272 |
+
|
| 273 |
+
# Format results - trả về original_id thay vì UUID
|
| 274 |
+
results = []
|
| 275 |
+
for hit in search_result:
|
| 276 |
+
# Lấy original_id từ metadata (MongoDB ObjectId)
|
| 277 |
+
original_id = hit.payload.get('original_id', hit.id)
|
| 278 |
+
|
| 279 |
+
results.append({
|
| 280 |
+
"id": original_id, # Trả về MongoDB ObjectId
|
| 281 |
+
"qdrant_id": hit.id, # UUID trong Qdrant
|
| 282 |
+
"confidence": float(hit.score), # Cosine similarity score
|
| 283 |
+
"metadata": hit.payload
|
| 284 |
+
})
|
| 285 |
+
|
| 286 |
+
return results
|
| 287 |
+
|
| 288 |
+
def hybrid_search(
|
| 289 |
+
self,
|
| 290 |
+
text_embedding: Optional[np.ndarray] = None,
|
| 291 |
+
image_embedding: Optional[np.ndarray] = None,
|
| 292 |
+
text_weight: float = 0.5,
|
| 293 |
+
image_weight: float = 0.5,
|
| 294 |
+
limit: int = 10,
|
| 295 |
+
score_threshold: Optional[float] = None,
|
| 296 |
+
ef: int = 256
|
| 297 |
+
) -> List[Dict[str, Any]]:
|
| 298 |
+
"""
|
| 299 |
+
Hybrid search với cả text và image embeddings
|
| 300 |
+
|
| 301 |
+
Args:
|
| 302 |
+
text_embedding: Text query embedding
|
| 303 |
+
image_embedding: Image query embedding
|
| 304 |
+
text_weight: Weight cho text search (0-1)
|
| 305 |
+
image_weight: Weight cho image search (0-1)
|
| 306 |
+
limit: Số results
|
| 307 |
+
score_threshold: Minimum score
|
| 308 |
+
ef: HNSW search parameter
|
| 309 |
+
|
| 310 |
+
Returns:
|
| 311 |
+
Combined search results
|
| 312 |
+
"""
|
| 313 |
+
# Combine embeddings với weights
|
| 314 |
+
combined_embedding = np.zeros(self.vector_size)
|
| 315 |
+
|
| 316 |
+
if text_embedding is not None:
|
| 317 |
+
if len(text_embedding.shape) > 1:
|
| 318 |
+
text_embedding = text_embedding.flatten()
|
| 319 |
+
combined_embedding += text_weight * text_embedding
|
| 320 |
+
|
| 321 |
+
if image_embedding is not None:
|
| 322 |
+
if len(image_embedding.shape) > 1:
|
| 323 |
+
image_embedding = image_embedding.flatten()
|
| 324 |
+
combined_embedding += image_weight * image_embedding
|
| 325 |
+
|
| 326 |
+
# Normalize combined embedding
|
| 327 |
+
norm = np.linalg.norm(combined_embedding)
|
| 328 |
+
if norm > 0:
|
| 329 |
+
combined_embedding = combined_embedding / norm
|
| 330 |
+
|
| 331 |
+
# Search với combined embedding
|
| 332 |
+
return self.search(
|
| 333 |
+
query_embedding=combined_embedding,
|
| 334 |
+
limit=limit,
|
| 335 |
+
score_threshold=score_threshold,
|
| 336 |
+
ef=ef
|
| 337 |
+
)
|
| 338 |
+
|
| 339 |
+
def delete_by_id(self, doc_id: str) -> bool:
|
| 340 |
+
"""
|
| 341 |
+
Delete document by ID (hỗ trợ cả MongoDB ObjectId và UUID)
|
| 342 |
+
|
| 343 |
+
Args:
|
| 344 |
+
doc_id: Document ID to delete (MongoDB ObjectId hoặc UUID)
|
| 345 |
+
|
| 346 |
+
Returns:
|
| 347 |
+
Success status
|
| 348 |
+
"""
|
| 349 |
+
# Convert to UUID nếu là MongoDB ObjectId
|
| 350 |
+
qdrant_id = self._convert_to_valid_id(doc_id)
|
| 351 |
+
|
| 352 |
+
self.client.delete(
|
| 353 |
+
collection_name=self.collection_name,
|
| 354 |
+
points_selector=[qdrant_id]
|
| 355 |
+
)
|
| 356 |
+
return True
|
| 357 |
+
|
| 358 |
+
def get_by_id(self, doc_id: str) -> Optional[Dict[str, Any]]:
|
| 359 |
+
"""
|
| 360 |
+
Get document by ID (hỗ trợ cả MongoDB ObjectId và UUID)
|
| 361 |
+
|
| 362 |
+
Args:
|
| 363 |
+
doc_id: Document ID (MongoDB ObjectId hoặc UUID)
|
| 364 |
+
|
| 365 |
+
Returns:
|
| 366 |
+
Document data hoặc None nếu không tìm thấy
|
| 367 |
+
"""
|
| 368 |
+
# Convert to UUID nếu là MongoDB ObjectId
|
| 369 |
+
qdrant_id = self._convert_to_valid_id(doc_id)
|
| 370 |
+
|
| 371 |
+
try:
|
| 372 |
+
result = self.client.retrieve(
|
| 373 |
+
collection_name=self.collection_name,
|
| 374 |
+
ids=[qdrant_id],
|
| 375 |
+
with_payload=True,
|
| 376 |
+
with_vectors=False
|
| 377 |
+
)
|
| 378 |
+
|
| 379 |
+
if result:
|
| 380 |
+
point = result[0]
|
| 381 |
+
original_id = point.payload.get('original_id', point.id)
|
| 382 |
+
return {
|
| 383 |
+
"id": original_id, # MongoDB ObjectId
|
| 384 |
+
"qdrant_id": point.id, # UUID trong Qdrant
|
| 385 |
+
"metadata": point.payload
|
| 386 |
+
}
|
| 387 |
+
return None
|
| 388 |
+
except Exception as e:
|
| 389 |
+
print(f"Error retrieving document: {e}")
|
| 390 |
+
return None
|
| 391 |
+
|
| 392 |
+
def search_by_metadata(
|
| 393 |
+
self,
|
| 394 |
+
filter_conditions: Dict,
|
| 395 |
+
limit: int = 100
|
| 396 |
+
) -> List[Dict[str, Any]]:
|
| 397 |
+
"""
|
| 398 |
+
Search documents by metadata conditions (không cần embedding)
|
| 399 |
+
|
| 400 |
+
Args:
|
| 401 |
+
filter_conditions: Qdrant filter conditions
|
| 402 |
+
limit: Maximum số results
|
| 403 |
+
|
| 404 |
+
Returns:
|
| 405 |
+
List of matching documents
|
| 406 |
+
"""
|
| 407 |
+
try:
|
| 408 |
+
result = self.client.scroll(
|
| 409 |
+
collection_name=self.collection_name,
|
| 410 |
+
scroll_filter=filter_conditions,
|
| 411 |
+
limit=limit,
|
| 412 |
+
with_payload=True,
|
| 413 |
+
with_vectors=False
|
| 414 |
+
)
|
| 415 |
+
|
| 416 |
+
documents = []
|
| 417 |
+
for point in result[0]: # result is tuple (points, next_page_offset)
|
| 418 |
+
original_id = point.payload.get('original_id', point.id)
|
| 419 |
+
documents.append({
|
| 420 |
+
"id": original_id, # MongoDB ObjectId
|
| 421 |
+
"qdrant_id": point.id, # UUID trong Qdrant
|
| 422 |
+
"metadata": point.payload
|
| 423 |
+
})
|
| 424 |
+
|
| 425 |
+
return documents
|
| 426 |
+
except Exception as e:
|
| 427 |
+
print(f"Error searching by metadata: {e}")
|
| 428 |
+
return []
|
| 429 |
+
|
| 430 |
+
def get_collection_info(self) -> Dict[str, Any]:
|
| 431 |
+
"""
|
| 432 |
+
Lấy thông tin collection
|
| 433 |
+
|
| 434 |
+
Returns:
|
| 435 |
+
Collection info
|
| 436 |
+
"""
|
| 437 |
+
info = self.client.get_collection(collection_name=self.collection_name)
|
| 438 |
+
return {
|
| 439 |
+
"vectors_count": info.vectors_count,
|
| 440 |
+
"points_count": info.points_count,
|
| 441 |
+
"status": info.status,
|
| 442 |
+
"config": {
|
| 443 |
+
"distance": info.config.params.vectors.distance,
|
| 444 |
+
"size": info.config.params.vectors.size,
|
| 445 |
+
}
|
| 446 |
+
}
|
requirements.txt
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# FastAPI và web framework
|
| 2 |
+
fastapi==0.115.5
|
| 3 |
+
uvicorn[standard]==0.32.1
|
| 4 |
+
python-multipart==0.0.20
|
| 5 |
+
|
| 6 |
+
# Gradio cho Hugging Face Spaces
|
| 7 |
+
gradio>=4.0.0
|
| 8 |
+
|
| 9 |
+
# Machine Learning & Embeddings
|
| 10 |
+
torch>=2.0.0
|
| 11 |
+
transformers>=4.50.0
|
| 12 |
+
onnxruntime==1.20.1
|
| 13 |
+
torchvision>=0.15.0
|
| 14 |
+
pillow>=10.0.0
|
| 15 |
+
numpy>=1.24.0
|
| 16 |
+
|
| 17 |
+
# RAG & Reranking (Best Case 2025)
|
| 18 |
+
sentence-transformers>=2.2.0
|
| 19 |
+
httpx>=0.25.0
|
| 20 |
+
|
| 21 |
+
# Vector Database
|
| 22 |
+
qdrant-client>=1.12.1
|
| 23 |
+
grpcio>=1.60.0
|
| 24 |
+
|
| 25 |
+
# Utilities
|
| 26 |
+
pydantic>=2.0.0
|
| 27 |
+
python-dotenv==1.0.0
|
| 28 |
+
|
| 29 |
+
# MongoDB
|
| 30 |
+
pymongo>=4.6.0
|
| 31 |
+
huggingface-hub>=0.20.0
|
| 32 |
+
timm
|
| 33 |
+
einops
|
| 34 |
+
|
| 35 |
+
# PDF Processing
|
| 36 |
+
pypdfium2>=4.30.0
|
| 37 |
+
|
| 38 |
+
httpx>=0.25.0
|
stream_utils.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SSE (Server-Sent Events) Utilities
|
| 3 |
+
Format streaming responses for real-time chat
|
| 4 |
+
"""
|
| 5 |
+
import json
|
| 6 |
+
from typing import Dict, Any, AsyncGenerator
|
| 7 |
+
import asyncio
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def format_sse(event: str, data: Any) -> str:
|
| 11 |
+
"""
|
| 12 |
+
Format data as SSE message
|
| 13 |
+
|
| 14 |
+
Args:
|
| 15 |
+
event: Event type (token, status, done, error)
|
| 16 |
+
data: Data payload (string or dict)
|
| 17 |
+
|
| 18 |
+
Returns:
|
| 19 |
+
Formatted SSE string
|
| 20 |
+
|
| 21 |
+
Example:
|
| 22 |
+
format_sse("token", "Hello")
|
| 23 |
+
# "event: token\ndata: Hello\n\n"
|
| 24 |
+
"""
|
| 25 |
+
if isinstance(data, dict):
|
| 26 |
+
data_str = json.dumps(data, ensure_ascii=False)
|
| 27 |
+
else:
|
| 28 |
+
data_str = str(data)
|
| 29 |
+
|
| 30 |
+
return f"event: {event}\ndata: {data_str}\n\n"
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
async def simulate_typing(
|
| 34 |
+
text: str,
|
| 35 |
+
chars_per_chunk: int = 3,
|
| 36 |
+
delay_ms: float = 20
|
| 37 |
+
) -> AsyncGenerator[str, None]:
|
| 38 |
+
"""
|
| 39 |
+
Simulate typing effect by yielding text in chunks
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
text: Full text to stream
|
| 43 |
+
chars_per_chunk: Characters per chunk
|
| 44 |
+
delay_ms: Milliseconds delay between chunks
|
| 45 |
+
|
| 46 |
+
Yields:
|
| 47 |
+
Text chunks
|
| 48 |
+
|
| 49 |
+
Example:
|
| 50 |
+
async for chunk in simulate_typing("Hello world", chars_per_chunk=2):
|
| 51 |
+
yield format_sse("token", chunk)
|
| 52 |
+
"""
|
| 53 |
+
for i in range(0, len(text), chars_per_chunk):
|
| 54 |
+
chunk = text[i:i + chars_per_chunk]
|
| 55 |
+
yield chunk
|
| 56 |
+
await asyncio.sleep(delay_ms / 1000)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
async def stream_text_slowly(
|
| 60 |
+
text: str,
|
| 61 |
+
event_type: str = "token",
|
| 62 |
+
chars_per_chunk: int = 3,
|
| 63 |
+
delay_ms: float = 20
|
| 64 |
+
) -> AsyncGenerator[str, None]:
|
| 65 |
+
"""
|
| 66 |
+
Stream text with typing effect in SSE format
|
| 67 |
+
|
| 68 |
+
Args:
|
| 69 |
+
text: Text to stream
|
| 70 |
+
event_type: SSE event type
|
| 71 |
+
chars_per_chunk: Characters per chunk
|
| 72 |
+
delay_ms: Delay between chunks
|
| 73 |
+
|
| 74 |
+
Yields:
|
| 75 |
+
SSE formatted chunks
|
| 76 |
+
"""
|
| 77 |
+
async for chunk in simulate_typing(text, chars_per_chunk, delay_ms):
|
| 78 |
+
yield format_sse(event_type, chunk)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
# Event type constants
|
| 82 |
+
EVENT_STATUS = "status"
|
| 83 |
+
EVENT_TOKEN = "token"
|
| 84 |
+
EVENT_DONE = "done"
|
| 85 |
+
EVENT_ERROR = "error"
|
| 86 |
+
EVENT_METADATA = "metadata"
|
tools_service.py
ADDED
|
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tools Service for LLM Function Calling
|
| 3 |
+
HuggingFace-compatible với prompt engineering
|
| 4 |
+
"""
|
| 5 |
+
import httpx
|
| 6 |
+
from typing import List, Dict, Any, Optional
|
| 7 |
+
import json
|
| 8 |
+
import asyncio
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class ToolsService:
|
| 12 |
+
"""
|
| 13 |
+
Manages external API tools that LLM can call via prompt engineering
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
def __init__(self, base_url: str = "https://hoalacrent.io.vn/api/v0", feedback_tracking=None):
|
| 17 |
+
self.base_url = base_url
|
| 18 |
+
self.client = httpx.AsyncClient(timeout=10.0)
|
| 19 |
+
self.feedback_tracking = feedback_tracking # NEW: Feedback tracking
|
| 20 |
+
|
| 21 |
+
def get_tools_definition(self) -> List[Dict]:
|
| 22 |
+
"""
|
| 23 |
+
Return list of tool definitions (OpenAI format style)
|
| 24 |
+
Used for constructing system prompt
|
| 25 |
+
"""
|
| 26 |
+
return [
|
| 27 |
+
{
|
| 28 |
+
"name": "search_events",
|
| 29 |
+
"description": "Tìm kiếm sự kiện phù hợp theo từ khóa, vibe, hoặc thời gian.",
|
| 30 |
+
"parameters": {
|
| 31 |
+
"type": "object",
|
| 32 |
+
"properties": {
|
| 33 |
+
"query": {"type": "string", "description": "Từ khóa tìm kiếm (VD: 'nhạc rock', 'hài kịch')"},
|
| 34 |
+
"vibe": {"type": "string", "description": "Vibe/Mood (VD: 'chill', 'sôi động', 'hẹn hò')"},
|
| 35 |
+
"time": {"type": "string", "description": "Thời gian (VD: 'cuối tuần này', 'tối nay')"}
|
| 36 |
+
}
|
| 37 |
+
}
|
| 38 |
+
},
|
| 39 |
+
{
|
| 40 |
+
"name": "get_event_details",
|
| 41 |
+
"description": "Lấy thông tin chi tiết (giá, địa điểm, thời gian) của sự kiện.",
|
| 42 |
+
"parameters": {
|
| 43 |
+
"type": "object",
|
| 44 |
+
"properties": {
|
| 45 |
+
"event_id": {"type": "string", "description": "ID của sự kiện (MongoDB ID)"}
|
| 46 |
+
},
|
| 47 |
+
"required": ["event_id"]
|
| 48 |
+
}
|
| 49 |
+
},
|
| 50 |
+
{
|
| 51 |
+
"name": "get_purchased_events",
|
| 52 |
+
"description": "Kiểm tra lịch sử các sự kiện user đã mua vé hoặc tham gia.",
|
| 53 |
+
"parameters": {
|
| 54 |
+
"type": "object",
|
| 55 |
+
"properties": {
|
| 56 |
+
"user_id": {"type": "string", "description": "ID của user"}
|
| 57 |
+
},
|
| 58 |
+
"required": ["user_id"]
|
| 59 |
+
}
|
| 60 |
+
},
|
| 61 |
+
{
|
| 62 |
+
"name": "save_feedback",
|
| 63 |
+
"description": "Lưu đánh giá/feedback của user về sự kiện.",
|
| 64 |
+
"parameters": {
|
| 65 |
+
"type": "object",
|
| 66 |
+
"properties": {
|
| 67 |
+
"event_id": {"type": "string", "description": "ID sự kiện"},
|
| 68 |
+
"rating": {"type": "integer", "description": "Số sao đánh giá (1-5)"},
|
| 69 |
+
"comment": {"type": "string", "description": "Nội dung nhận xét"}
|
| 70 |
+
},
|
| 71 |
+
"required": ["event_id", "rating"]
|
| 72 |
+
}
|
| 73 |
+
},
|
| 74 |
+
{
|
| 75 |
+
"name": "save_lead",
|
| 76 |
+
"description": "Lưu thông tin khách hàng quan tâm (Lead).",
|
| 77 |
+
"parameters": {
|
| 78 |
+
"type": "object",
|
| 79 |
+
"properties": {
|
| 80 |
+
"email": {"type": "string"},
|
| 81 |
+
"phone": {"type": "string"},
|
| 82 |
+
"interest": {"type": "string"}
|
| 83 |
+
}
|
| 84 |
+
}
|
| 85 |
+
}
|
| 86 |
+
]
|
| 87 |
+
|
| 88 |
+
async def execute_tool(self, tool_name: str, arguments: Dict, access_token: Optional[str] = None) -> Any:
|
| 89 |
+
"""
|
| 90 |
+
Execute a tool by name with arguments
|
| 91 |
+
|
| 92 |
+
Args:
|
| 93 |
+
tool_name: Name of the tool
|
| 94 |
+
arguments: Tool arguments
|
| 95 |
+
access_token: JWT token for authenticated API calls
|
| 96 |
+
"""
|
| 97 |
+
print(f"\n🔧 ===== TOOL EXECUTION =====")
|
| 98 |
+
print(f"Tool: {tool_name}")
|
| 99 |
+
print(f"Arguments: {arguments}")
|
| 100 |
+
print(f"Access Token: {'✅ Present' if access_token else '❌ Missing'}")
|
| 101 |
+
if access_token:
|
| 102 |
+
print(f"Token preview: {access_token[:30]}...")
|
| 103 |
+
|
| 104 |
+
try:
|
| 105 |
+
if tool_name == "get_event_details":
|
| 106 |
+
return await self._get_event_details(arguments.get("event_id") or arguments.get("event_code"))
|
| 107 |
+
|
| 108 |
+
elif tool_name == "get_purchased_events":
|
| 109 |
+
print(f"→ Calling _get_purchased_events with:")
|
| 110 |
+
print(f" user_id: {arguments.get('user_id')}")
|
| 111 |
+
print(f" access_token: {'✅' if access_token else '❌'}")
|
| 112 |
+
return await self._get_purchased_events(
|
| 113 |
+
arguments.get("user_id"),
|
| 114 |
+
access_token=access_token # Pass access_token
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
elif tool_name == "save_feedback":
|
| 118 |
+
return await self._save_feedback(
|
| 119 |
+
arguments.get("event_id"),
|
| 120 |
+
arguments.get("rating"),
|
| 121 |
+
arguments.get("comment")
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
elif tool_name == "search_events":
|
| 125 |
+
# Note: This usually requires RAG service, so we return a special signal
|
| 126 |
+
# The Agent Service will handle RAG search
|
| 127 |
+
return {"action": "run_rag_search", "query": arguments}
|
| 128 |
+
|
| 129 |
+
elif tool_name == "save_lead":
|
| 130 |
+
# Placeholder for lead saving
|
| 131 |
+
return {"success": True, "message": "Lead saved successfully"}
|
| 132 |
+
|
| 133 |
+
else:
|
| 134 |
+
return {"error": f"Unknown tool: {tool_name}"}
|
| 135 |
+
|
| 136 |
+
except Exception as e:
|
| 137 |
+
print(f"⚠️ Tool Execution Error: {e}")
|
| 138 |
+
return {"error": str(e)}
|
| 139 |
+
|
| 140 |
+
async def _get_event_details(self, event_id: str) -> Dict:
|
| 141 |
+
"""Call API to get event details"""
|
| 142 |
+
if not event_id:
|
| 143 |
+
return {"error": "Missing event_id"}
|
| 144 |
+
|
| 145 |
+
try:
|
| 146 |
+
url = f"{self.base_url}/event/get-event-by-id"
|
| 147 |
+
|
| 148 |
+
response = await self.client.get(url, params={"id": event_id})
|
| 149 |
+
if response.status_code == 200:
|
| 150 |
+
data = response.json()
|
| 151 |
+
if data.get("success"):
|
| 152 |
+
return data.get("data")
|
| 153 |
+
return {"error": "Event not found", "details": response.text}
|
| 154 |
+
except Exception as e:
|
| 155 |
+
return {"error": str(e)}
|
| 156 |
+
|
| 157 |
+
async def _get_purchased_events(self, user_id: str, access_token: Optional[str] = None) -> List[Dict]:
|
| 158 |
+
"""Call API to get purchased events for user (requires auth)"""
|
| 159 |
+
print(f"\n🎫 ===== GET PURCHASED EVENTS =====")
|
| 160 |
+
print(f"User ID: {user_id}")
|
| 161 |
+
print(f"Access Token: {'✅ Present' if access_token else '❌ Missing'}")
|
| 162 |
+
|
| 163 |
+
if not user_id:
|
| 164 |
+
print("⚠️ No user_id provided, returning empty list")
|
| 165 |
+
return []
|
| 166 |
+
|
| 167 |
+
try:
|
| 168 |
+
url = f"{self.base_url}/event/get-purchase-event-by-user-id/{user_id}"
|
| 169 |
+
print(f"🔍 API URL: {url}")
|
| 170 |
+
|
| 171 |
+
# Add Authorization header if access_token provided
|
| 172 |
+
headers = {}
|
| 173 |
+
if access_token:
|
| 174 |
+
headers["Authorization"] = f"Bearer {access_token}"
|
| 175 |
+
print(f"🔐 Authorization Header Added:")
|
| 176 |
+
print(f" Bearer {access_token[:30]}...")
|
| 177 |
+
else:
|
| 178 |
+
print(f"⚠️ No access_token - calling API without auth")
|
| 179 |
+
|
| 180 |
+
print(f"📡 Headers: {headers}")
|
| 181 |
+
print(f"🚀 Calling API...")
|
| 182 |
+
|
| 183 |
+
response = await self.client.get(url, headers=headers)
|
| 184 |
+
|
| 185 |
+
print(f"📥 Response Status: {response.status_code}")
|
| 186 |
+
print(f"📦 Response Headers: {dict(response.headers)}")
|
| 187 |
+
|
| 188 |
+
if response.status_code == 200:
|
| 189 |
+
data = response.json()
|
| 190 |
+
print(f"✅ Success! Data keys: {list(data.keys())}")
|
| 191 |
+
events = data.get("data", [])
|
| 192 |
+
print(f"📊 Found {len(events)} purchased events")
|
| 193 |
+
|
| 194 |
+
# Log actual event data
|
| 195 |
+
if events:
|
| 196 |
+
print(f"\n📋 Purchased Events Details:")
|
| 197 |
+
for i, event in enumerate(events, 1):
|
| 198 |
+
print(f"{i}. Event Code: {event.get('eventCode', 'N/A')}")
|
| 199 |
+
print(f" Event Name: {event.get('eventName', 'N/A')}")
|
| 200 |
+
print(f" Event ID: {event.get('_id', 'N/A')}")
|
| 201 |
+
print(f" Full data: {event}")
|
| 202 |
+
|
| 203 |
+
return events
|
| 204 |
+
else:
|
| 205 |
+
print(f"❌ API Error: {response.status_code}")
|
| 206 |
+
print(f"Response body: {response.text[:500]}")
|
| 207 |
+
return []
|
| 208 |
+
|
| 209 |
+
except Exception as e:
|
| 210 |
+
print(f"⚠️ Exception in _get_purchased_events: {type(e).__name__}: {e}")
|
| 211 |
+
import traceback
|
| 212 |
+
traceback.print_exc()
|
| 213 |
+
return []
|
| 214 |
+
|
| 215 |
+
async def _save_feedback(self, event_id: str, rating: int, comment: str, user_id: str = None, event_code: str = None) -> Dict:
|
| 216 |
+
"""Save feedback and mark as completed in tracking system"""
|
| 217 |
+
print(f"\n📝 ===== SAVE FEEDBACK =====")
|
| 218 |
+
print(f"Event ID: {event_id}")
|
| 219 |
+
print(f"Event Code: {event_code}")
|
| 220 |
+
print(f"User ID: {user_id}")
|
| 221 |
+
print(f"Rating: {rating}")
|
| 222 |
+
print(f"Comment: {comment}")
|
| 223 |
+
|
| 224 |
+
# TODO: Implement real API call to save feedback
|
| 225 |
+
# For now, just mark in tracking system
|
| 226 |
+
if self.feedback_tracking and user_id and event_code:
|
| 227 |
+
success = self.feedback_tracking.mark_feedback_given(
|
| 228 |
+
user_id=user_id,
|
| 229 |
+
event_code=event_code,
|
| 230 |
+
rating=rating,
|
| 231 |
+
comment=comment
|
| 232 |
+
)
|
| 233 |
+
if success:
|
| 234 |
+
print(f"✅ Feedback tracked in database")
|
| 235 |
+
else:
|
| 236 |
+
print(f"⚠️ Failed to track feedback")
|
| 237 |
+
|
| 238 |
+
return {"success": True, "message": "Feedback recorded"}
|
| 239 |
+
|
| 240 |
+
async def close(self):
|
| 241 |
+
"""Close HTTP client"""
|
| 242 |
+
await self.client.aclose()
|