WilliamGazeley
commited on
Commit
·
7c1f337
1
Parent(s):
49c70c9
Implemene Azure search tool
Browse files- requirements.txt +3 -0
- src/config.py +7 -1
- src/functions.py +33 -0
requirements.txt
CHANGED
|
@@ -23,3 +23,6 @@ yfinance==0.2.36
|
|
| 23 |
transformers==4.40.2
|
| 24 |
langchain==0.1.9
|
| 25 |
accelerate==0.27.2
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
transformers==4.40.2
|
| 24 |
langchain==0.1.9
|
| 25 |
accelerate==0.27.2
|
| 26 |
+
azure-search-documents==11.6.0b1
|
| 27 |
+
azure-identity==1.16.0
|
| 28 |
+
|
src/config.py
CHANGED
|
@@ -5,7 +5,13 @@ class Config(BaseSettings):
|
|
| 5 |
hf_token: str = Field(...)
|
| 6 |
hf_model: str = Field("InvestmentResearchAI/LLM-ADE-dev")
|
| 7 |
headless: bool = Field(False, description="Run in headless mode.")
|
| 8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
chat_template: str = Field("chatml", description="Chat template for prompt formatting")
|
| 10 |
num_fewshot: int | None = Field(None, description="Option to use json mode examples")
|
| 11 |
load_in_4bit: str = Field("False", description="Option to load in 4bit with bitsandbytes")
|
|
|
|
| 5 |
hf_token: str = Field(...)
|
| 6 |
hf_model: str = Field("InvestmentResearchAI/LLM-ADE-dev")
|
| 7 |
headless: bool = Field(False, description="Run in headless mode.")
|
| 8 |
+
|
| 9 |
+
az_search_endpoint: str = Field("https://analysis-bank.search.windows.net")
|
| 10 |
+
az_search_api_key: str = Field(...)
|
| 11 |
+
az_search_idx_name: str = Field("analysis-index")
|
| 12 |
+
az_search_top_k: int = Field(2, description="Max number of results to retrun")
|
| 13 |
+
az_search_min_score: float = Field(12.0, description="Only results above this confidence score is used")
|
| 14 |
+
|
| 15 |
chat_template: str = Field("chatml", description="Chat template for prompt formatting")
|
| 16 |
num_fewshot: int | None = Field(None, description="Option to use json mode examples")
|
| 17 |
load_in_4bit: str = Field("False", description="Option to load in 4bit with bitsandbytes")
|
src/functions.py
CHANGED
|
@@ -11,6 +11,38 @@ from utils import inference_logger
|
|
| 11 |
from langchain.tools import tool
|
| 12 |
from langchain_core.utils.function_calling import convert_to_openai_tool
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
@tool
|
| 15 |
def google_search_and_scrape(query: str) -> dict:
|
| 16 |
"""
|
|
@@ -246,6 +278,7 @@ def get_company_profile(symbol: str) -> dict:
|
|
| 246 |
|
| 247 |
def get_openai_tools() -> List[dict]:
|
| 248 |
functions = [
|
|
|
|
| 249 |
google_search_and_scrape,
|
| 250 |
get_current_stock_price,
|
| 251 |
get_company_news,
|
|
|
|
| 11 |
from langchain.tools import tool
|
| 12 |
from langchain_core.utils.function_calling import convert_to_openai_tool
|
| 13 |
|
| 14 |
+
from azure.core.credentials import AzureKeyCredential
|
| 15 |
+
from azure.search.documents import SearchClient
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
az_creds = AzureKeyCredential(config.az_search_api_key)
|
| 19 |
+
az_search_client = SearchClient(config.az_search_endpoint, config.az_search_idx_name, az_creds)
|
| 20 |
+
|
| 21 |
+
@tool
|
| 22 |
+
def get_company_analysis(query: str) -> dict:
|
| 23 |
+
"""
|
| 24 |
+
Searches through your database of company and crypto analysis, retrieves top 2
|
| 25 |
+
pieces of analysis relevant to your query.
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
query (str): The search query
|
| 29 |
+
Returns:
|
| 30 |
+
list: A list of dictionaries containing the pieces of analysis.
|
| 31 |
+
"""
|
| 32 |
+
results = az_search_client.search(
|
| 33 |
+
query_type="simple",
|
| 34 |
+
search_text=query,
|
| 35 |
+
select="title,content",
|
| 36 |
+
include_total_count=True,
|
| 37 |
+
top=config.az_search_top_k
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
output = []
|
| 41 |
+
for x in results:
|
| 42 |
+
if x["@search.score"] >= config.az_search_min_score:
|
| 43 |
+
output.append({"title": x["title"], "content": x["content"]})
|
| 44 |
+
return output
|
| 45 |
+
|
| 46 |
@tool
|
| 47 |
def google_search_and_scrape(query: str) -> dict:
|
| 48 |
"""
|
|
|
|
| 278 |
|
| 279 |
def get_openai_tools() -> List[dict]:
|
| 280 |
functions = [
|
| 281 |
+
get_company_analysis,
|
| 282 |
google_search_and_scrape,
|
| 283 |
get_current_stock_price,
|
| 284 |
get_company_news,
|