donecase / rag_agent /tools /rag_query.py
sariskiat's picture
Update rag_agent/tools/rag_query.py
a57a975 verified
"""
Tool for querying Vertex AI RAG corpora and retrieving relevant information.
"""
from vertexai import rag
import configparser
from openai import OpenAI
from pymilvus import MilvusClient
from pymilvus import DataType
import json
# cfp = configparser.RawConfigParser()
# cfp.read('config.ini')
# milvus_uri = cfp.get('example', 'uri')
# token = cfp.get('example', 'token')
# api_key = cfp.get('example', 'api_key')
milvus_uri = "https://in03-6df450fd51eb74c.serverless.gcp-us-west1.cloud.zilliz.com"
token = "48c398b27df543cf49608f5399d157f40346d4b7b6d93aa592210cd4356d55183f1dd11516b41666c57775a267839b8c6dd6b1b6"
api_key = "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"
milvus_client = MilvusClient(uri=milvus_uri, token=token)
# client = MilvusClient(
# uri=milvus_uri,
# user="db_6df450fd51eb74c",
# password="Hv0../Lue;2T(-*/",
# )
print(f"Connected to DB: {milvus_uri} successfully")
client = OpenAI(api_key=api_key)
def get_embedding(text, model="text-embedding-3-large"):
return client.embeddings.create(input = [text], model=model).data[0].embedding
from ..config import (
DEFAULT_DISTANCE_THRESHOLD,
DEFAULT_BUSINESS_TOP_K,
DEFAULT_PRODUCT_TOP_K,
DEFAULT_SERVICE_TOP_K
)
def rag_query(
query: str,
type: str,
) -> dict:
search_params = {"metric_type": "L2", "params": {"level": 2}}
search_vectors = [get_embedding(query)]
faq_response = milvus_client.search(
"faq",
data=search_vectors,
limit=3,
search_params=search_params,
anns_field="question",
output_fields=["text"]
)
page_summary = "\nBitcast คือผู้จัดจำหน่ายอุปกรณ์ Hardware Wallet และอุปกรณ์รักษาความปลอดภัยด้านคริปโต ที่คัดสรรเฉพาะสินค้ามาตรฐานระดับสากลจากแบรนด์ชั้นนำ เช่น OneKey, Trezor, Blockstream, Coinkite, Foundation โดยมุ่งเน้นให้ผู้ใช้งานทุกระดับ—from มือใหม่จนถึงผู้ใช้ระดับโปร—สามารถปกป้องสินทรัพย์ดิจิทัลของตนได้อย่าง ปลอดภัย, ง่าย, และเชื่อถือได้\n\nบริษัทนำเสนอสินค้าที่ครอบคลุมตั้งแต่\n- Hardware Wallet แบบหน้าจอสัมผัส,\n- รุ่นจอขาวดำราคาประหยัด,\n- รุ่น Bluetooth / Air-gap / NFC,\n- ไปจนถึง อุปกรณ์สาย Bitcoin-only ที่เน้นความปลอดภัยสูงสุด\n\nทุกผลิตภัณฑ์ถูกคัดเลือกโดยเน้น 3 ปัจจัยหลัก:\nความปลอดภัยระดับสูง (Security), ความง่ายในการใช้งาน (Usability), และความเข้ากันได้กับอุปกรณ์หลากหลาย (Compatibility)\n\nBitcast ยังให้ความสำคัญกับ การให้ข้อมูลการใช้งาน เช่น คู่มือ, วิดีโอสอน, และหน้าแนะนำการตั้งค่า เพื่อให้ผู้ใช้เริ่มต้นได้อย่างมั่นใจ แม้ไม่เคยใช้ Hardware Wallet มาก่อน\n"
if type == "business":
search_vectors = [get_embedding(query)]
response = milvus_client.search(
"business_info",
data=search_vectors,
limit=DEFAULT_BUSINESS_TOP_K,
search_params=search_params,
anns_field="business_info",
output_fields=["text"]
)
print(response)
elif type == "product":
search_vectors = [get_embedding(query)]
response = milvus_client.search(
"product_info",
data=search_vectors,
limit=DEFAULT_PRODUCT_TOP_K,
search_params=search_params,
anns_field="product_info",
output_fields=["text"]
)
print(response)
elif type == "service":
search_vectors = [get_embedding(query)]
response = milvus_client.search(
"service_info",
data=search_vectors,
limit=DEFAULT_SERVICE_TOP_K,
search_params=search_params,
anns_field="service_info",
output_fields=["text"]
)
print(response)
faq_results = []
for hits in faq_response:
for hit in hits:
item = {
"id": hit.get("id"),
"text": hit.get("entity", {}).get("text", ""),
"score": hit.get("distance"),
}
faq_results.append(item)
results = []
for hits in response:
for hit in hits:
item = {
"id": hit.get("id"),
"text": hit.get("entity", {}).get("text", ""),
"score": hit.get("distance"),
}
results.append(item)
print(str(page_summary) + "\nFAQ result:\n" + str(faq_results) + "\nRAG result:\n" + str(results))
return str(page_summary) + "\nFAQ result:\n" + str(faq_results) + "\nRAG result:\n" + str(results)