Kotta committed on
Commit
8189d77
·
1 Parent(s): 244655c

feature(#9): rebased with main and confirmed sendNotification with its params.

Browse files
Brain/src/rising_plugin/guardrails-config/actions/actions.py CHANGED
@@ -17,16 +17,14 @@ import os
17
  import json
18
  import numpy as np
19
 
20
- from langchain.embeddings.openai import OpenAIEmbeddings
21
- from langchain.vectorstores import utils
22
- from langchain.document_loaders.csv_loader import CSVLoader
23
  from langchain.docstore.document import Document
24
 
25
  from Brain.src.common.brain_exception import BrainException
26
  from Brain.src.common.utils import (
27
- OPENAI_API_KEY,
28
- COMMAND_SMS_INDEXS,
29
  COMMAND_BROWSER_OPEN,
 
30
  DEFAULT_GPT_MODEL,
31
  )
32
  from Brain.src.model.req_model import ReqModel
@@ -34,6 +32,7 @@ from Brain.src.model.requests.request_model import BasicReq
34
  from Brain.src.rising_plugin.image_embedding import (
35
  query_image_text,
36
  )
 
37
 
38
  from nemoguardrails.actions import action
39
 
@@ -48,6 +47,11 @@ from Brain.src.rising_plugin.llm.llms import (
48
  GPT_LLM_MODELS,
49
  )
50
 
 
 
 
 
 
51
  """
52
  query is json string with below format
53
  {
@@ -74,7 +78,11 @@ query is json string with below format
74
 
75
  @action()
76
  async def general_question(query):
 
 
77
  """step 0: convert string to json"""
 
 
78
  try:
79
  json_query = json.loads(query)
80
  except Exception as ex:
@@ -87,35 +95,31 @@ async def general_question(query):
87
  setting = ReqModel(json_query["setting"])
88
 
89
  """step 1: handle with gpt-4"""
90
- file_path = os.path.dirname(os.path.abspath(__file__))
91
 
92
- with open(f"{file_path}/phone.json", "r") as infile:
93
- data = json.load(infile)
94
- embeddings = OpenAIEmbeddings(openai_api_key=OPENAI_API_KEY)
95
-
96
- query_result = embeddings.embed_query(query)
97
- doc_list = utils.maximal_marginal_relevance(np.array(query_result), data, k=1)
98
- loader = CSVLoader(file_path=f"{file_path}/phone.csv", encoding="utf8")
99
- csv_text = loader.load()
100
 
 
 
 
101
  docs = []
 
 
102
 
103
- for res in doc_list:
104
- docs.append(
105
- Document(
106
- page_content=csv_text[res].page_content, metadata=csv_text[res].metadata
107
- )
 
 
 
108
  )
109
-
110
- chain_data = get_llm_chain(model=model).run(input_documents=docs, question=query)
111
- # test
112
- # if model == GPT_3_5_TURBO or model == GPT_4 or model == GPT_4_32K:
113
- # gpt_llm = GptLLM(model=model)
114
- # chain_data = gpt_llm.get_chain().run(input_documents=docs, question=query)
115
- # elif model == FALCON_7B:
116
- # falcon_llm = FalconLLM()
117
- # chain_data = falcon_llm.get_chain().run(question=query)
118
- falcon_llm = FalconLLM()
119
  try:
120
  result = json.loads(chain_data)
121
  # check image query with only its text
@@ -124,17 +128,16 @@ async def general_question(query):
124
  result["content"] = {
125
  "image_name": query_image_text(result["content"], "", uuid)
126
  }
127
-
128
- # else:
129
- # return result
130
- """check program is message to handle it with falcon llm"""
131
  if result["program"] == "message":
132
- result["content"] = falcon_llm.query(question=query)
 
133
  return str(result)
134
  except ValueError as e:
135
  # Check sms and browser query
136
- if doc_list[0] in COMMAND_SMS_INDEXS:
137
  return str({"program": "sms", "content": chain_data})
138
- elif doc_list[0] in COMMAND_BROWSER_OPEN:
139
  return str({"program": "browser", "content": "https://google.com"})
 
140
  return str({"program": "message", "content": falcon_llm.query(question=query)})
 
17
  import json
18
  import numpy as np
19
 
20
+ from Brain.src.service.train_service import TrainService
 
 
21
  from langchain.docstore.document import Document
22
 
23
  from Brain.src.common.brain_exception import BrainException
24
  from Brain.src.common.utils import (
25
+ COMMAND_SMS_INDEXES,
 
26
  COMMAND_BROWSER_OPEN,
27
+ PINECONE_INDEX_NAME,
28
  DEFAULT_GPT_MODEL,
29
  )
30
  from Brain.src.model.req_model import ReqModel
 
32
  from Brain.src.rising_plugin.image_embedding import (
33
  query_image_text,
34
  )
35
+ from Brain.src.rising_plugin.csv_embed import get_embed
36
 
37
  from nemoguardrails.actions import action
38
 
 
47
  GPT_LLM_MODELS,
48
  )
49
 
50
+ from Brain.src.rising_plugin.pinecone_engine import (
51
+ get_pinecone_index_namespace,
52
+ init_pinecone,
53
+ )
54
+
55
  """
56
  query is json string with below format
57
  {
 
78
 
79
  @action()
80
  async def general_question(query):
81
+ """init falcon model"""
82
+ falcon_llm = FalconLLM()
83
  """step 0: convert string to json"""
84
+ index = init_pinecone(PINECONE_INDEX_NAME)
85
+ train_service = TrainService()
86
  try:
87
  json_query = json.loads(query)
88
  except Exception as ex:
 
95
  setting = ReqModel(json_query["setting"])
96
 
97
  """step 1: handle with gpt-4"""
 
98
 
99
+ query_result = get_embed(query)
100
+ relatedness_data = index.query(
101
+ vector=query_result,
102
+ top_k=1,
103
+ include_values=False,
104
+ namespace=train_service.get_pinecone_index_train_namespace(),
105
+ )
 
106
 
107
+ if len(relatedness_data["matches"]) == 0:
108
+ return str({"program": "message", "content": ""})
109
+ document_id = relatedness_data["matches"][0]["id"]
110
  docs = []
111
+ document = train_service.read_one_document(document_id)
112
+ docs.append(Document(page_content=document["page_content"], metadata=""))
113
 
114
+ """ 1. calling gpt model to categorize for all message"""
115
+ if model in GPT_LLM_MODELS:
116
+ chain_data = get_llm_chain(model=model, setting=setting).run(
117
+ input_documents=docs, question=query
118
+ )
119
+ else:
120
+ chain_data = get_llm_chain(model=DEFAULT_GPT_MODEL, setting=setting).run(
121
+ input_documents=docs, question=query
122
  )
 
 
 
 
 
 
 
 
 
 
123
  try:
124
  result = json.loads(chain_data)
125
  # check image query with only its text
 
128
  result["content"] = {
129
  "image_name": query_image_text(result["content"], "", uuid)
130
  }
131
+ """ 2. check program is message to handle it with falcon llm """
 
 
 
132
  if result["program"] == "message":
133
+ if model == FALCON_7B:
134
+ result["content"] = falcon_llm.query(question=query)
135
  return str(result)
136
  except ValueError as e:
137
  # Check sms and browser query
138
+ if document_id in COMMAND_SMS_INDEXES:
139
  return str({"program": "sms", "content": chain_data})
140
+ elif document_id in COMMAND_BROWSER_OPEN:
141
  return str({"program": "browser", "content": "https://google.com"})
142
+
143
  return str({"program": "message", "content": falcon_llm.query(question=query)})
app.py CHANGED
@@ -15,9 +15,7 @@ app.include_router(
15
  construct_blueprint_browser_api(), prefix="/browser", tags=["ai_browser"]
16
  )
17
 
18
- app.include_router(
19
- construct_blueprint_train_api(), prefix="/train", tags=["ai_train"]
20
- )
21
 
22
 
23
  if __name__ == "__main__":
 
15
  construct_blueprint_browser_api(), prefix="/browser", tags=["ai_browser"]
16
  )
17
 
18
+ app.include_router(construct_blueprint_train_api(), prefix="/train", tags=["ai_train"])
 
 
19
 
20
 
21
  if __name__ == "__main__":