jupiter0913 commited on
Commit
27b5f6d
·
1 Parent(s): a27c5a5

feature(#35): implement train logic.

Browse files
Brain/src/common/utils.py CHANGED
@@ -30,8 +30,8 @@ DEFAULT_GPT_MODEL = "gpt-4"
30
  AGENT_NAME = "RisingBrain Assistant"
31
 
32
  # indexes of relatedness of embedding
33
- COMMAND_SMS_INDEXS = [4, 5]
34
- COMMAND_BROWSER_OPEN = [10]
35
 
36
  # Twilio
37
  ACCOUNT_SID = os.getenv("TWILIO_ACCOUNT_SID")
@@ -86,8 +86,10 @@ def parseJsonFromCompletion(data: str) -> json:
86
  if index == len(result) - 3:
87
  result = result[:index] + replacement + result[index + len(substring):]
88
  # fmt: on
89
- result = json.loads(result.replace("':", '":'))
90
- return result
 
 
91
 
92
 
93
  def parseUrlFromStr(text: str) -> str:
 
30
  AGENT_NAME = "RisingBrain Assistant"
31
 
32
  # indexes of relatedness of embedding
33
+ COMMAND_SMS_INDEXES = ["pWDrks5DO1bEPLlUtQ1f", "LEpAhmFi8tAOQUE7LHZZ"] # 4, 5
34
+ COMMAND_BROWSER_OPEN = ["taVNeDINonUqJWXBlESU"] # 10
35
 
36
  # Twilio
37
  ACCOUNT_SID = os.getenv("TWILIO_ACCOUNT_SID")
 
86
  if index == len(result) - 3:
87
  result = result[:index] + replacement + result[index + len(substring):]
88
  # fmt: on
89
+ try:
90
+ return json.loads(result)
91
+ except Exception as e:
92
+ return result
93
 
94
 
95
  def parseUrlFromStr(text: str) -> str:
Brain/src/model/requests/request_model.py CHANGED
@@ -130,6 +130,14 @@ class TrainContacts(BasicReq):
130
  contacts: list[ContactReq]
131
 
132
 
 
 
 
 
 
 
 
 
133
  """endpoint /browser/item"""
134
 
135
 
@@ -140,3 +148,16 @@ class BrowserItem(BasicReq):
140
 
141
  items: list[ItemReq]
142
  prompt: str
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  contacts: list[ContactReq]
131
 
132
 
133
+ """endpoint: /document"""
134
+
135
+
136
+ class Document(BasicReq):
137
+ document_id: str
138
+ page_content: str
139
+
140
+
141
  """endpoint /browser/item"""
142
 
143
 
 
148
 
149
  items: list[ItemReq]
150
  prompt: str
151
+
152
+
153
+ """endpoint /train"""
154
+
155
+
156
+ class Train(BasicReq):
157
+ class TrainData(BaseModel):
158
+ page_content: str
159
+ timestamp: float
160
+
161
+ id: str
162
+ data: TrainData
163
+ status: str
Brain/src/model/train_model.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """train model:
2
+ {
3
+ "id": "String",
4
+ "data": [{"page_content": "String", "timestamp": 0}],
5
+ "status": "created | updated | deleted",
6
+ }"""
7
+
8
+ from Brain.src.model.requests.request_model import Train
9
+
10
+
11
+ class TrainModel:
12
+ def __init__(self, train_data: Train):
13
+ self.id = train_data.id
14
+ self.data = train_data.data
15
+ self.status = TrainStatus.UPDATED
16
+
17
+
18
+ """train status: created | updated | deleted"""
19
+
20
+
21
+ class TrainStatus:
22
+ CREATED = "created"
23
+ UPDATED = "updated"
24
+ DELETED = "deleted"
Brain/src/rising_plugin/guardrails-config/actions/actions.py CHANGED
@@ -13,20 +13,18 @@
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
 
16
- import os
17
  import json
18
- import numpy as np
19
 
20
  from langchain.embeddings.openai import OpenAIEmbeddings
21
- from langchain.vectorstores import utils
22
- from langchain.document_loaders.csv_loader import CSVLoader
23
  from langchain.docstore.document import Document
24
 
25
  from Brain.src.common.brain_exception import BrainException
26
  from Brain.src.common.utils import (
27
  OPENAI_API_KEY,
28
- COMMAND_SMS_INDEXS,
29
  COMMAND_BROWSER_OPEN,
 
30
  )
31
  from Brain.src.rising_plugin.image_embedding import (
32
  query_image_text,
@@ -44,20 +42,20 @@ from Brain.src.rising_plugin.llm.llms import (
44
  FALCON_7B,
45
  )
46
 
47
- """
48
- query is json string with below format
49
- {
50
- "query": string,
51
- "model": string,
52
- "uuid": string,
53
- "image_search": bool,
54
- }
55
- """
56
 
57
 
58
  @action()
59
  async def general_question(query):
60
  """step 0: convert string to json"""
 
61
  try:
62
  json_query = json.loads(query)
63
  except Exception as ex:
@@ -69,25 +67,26 @@ async def general_question(query):
69
  image_search = json_query["image_search"]
70
 
71
  """step 1: handle with gpt-4"""
72
- file_path = os.path.dirname(os.path.abspath(__file__))
73
 
74
- with open(f"{file_path}/phone.json", "r") as infile:
75
- data = json.load(infile)
76
  embeddings = OpenAIEmbeddings(openai_api_key=OPENAI_API_KEY)
77
 
78
  query_result = embeddings.embed_query(query)
79
- doc_list = utils.maximal_marginal_relevance(np.array(query_result), data, k=1)
80
- loader = CSVLoader(file_path=f"{file_path}/phone.csv", encoding="utf8")
81
- csv_text = loader.load()
 
 
 
 
 
 
82
 
83
  docs = []
84
-
85
- for res in doc_list:
86
- docs.append(
87
- Document(
88
- page_content=csv_text[res].page_content, metadata=csv_text[res].metadata
89
- )
90
- )
91
 
92
  chain_data = get_llm_chain(model=model).run(input_documents=docs, question=query)
93
  # test
@@ -115,8 +114,8 @@ async def general_question(query):
115
  return str(result)
116
  except ValueError as e:
117
  # Check sms and browser query
118
- if doc_list[0] in COMMAND_SMS_INDEXS:
119
  return str({"program": "sms", "content": chain_data})
120
- elif doc_list[0] in COMMAND_BROWSER_OPEN:
121
  return str({"program": "browser", "content": "https://google.com"})
122
  return str({"program": "message", "content": falcon_llm.query(question=query)})
 
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
 
 
16
  import json
 
17
 
18
  from langchain.embeddings.openai import OpenAIEmbeddings
19
+ from Brain.src.service.train_service import TrainService
 
20
  from langchain.docstore.document import Document
21
 
22
  from Brain.src.common.brain_exception import BrainException
23
  from Brain.src.common.utils import (
24
  OPENAI_API_KEY,
25
+ COMMAND_SMS_INDEXES,
26
  COMMAND_BROWSER_OPEN,
27
+ PINECONE_INDEX_NAME,
28
  )
29
  from Brain.src.rising_plugin.image_embedding import (
30
  query_image_text,
 
42
  FALCON_7B,
43
  )
44
 
45
+ from Brain.src.rising_plugin.pinecone_engine import (
46
+ get_pinecone_index_namespace,
47
+ init_pinecone,
48
+ )
49
+
50
+
51
+ def get_pinecone_index_train_namespace() -> str:
52
+ return get_pinecone_index_namespace(f"trains")
 
53
 
54
 
55
  @action()
56
  async def general_question(query):
57
  """step 0: convert string to json"""
58
+ index = init_pinecone(PINECONE_INDEX_NAME)
59
  try:
60
  json_query = json.loads(query)
61
  except Exception as ex:
 
67
  image_search = json_query["image_search"]
68
 
69
  """step 1: handle with gpt-4"""
 
70
 
 
 
71
  embeddings = OpenAIEmbeddings(openai_api_key=OPENAI_API_KEY)
72
 
73
  query_result = embeddings.embed_query(query)
74
+ relatedness_data = index.query(
75
+ vector=query_result,
76
+ top_k=3,
77
+ include_values=False,
78
+ namespace=get_pinecone_index_train_namespace(),
79
+ )
80
+ documentId = ''
81
+ if len(relatedness_data["matches"]) > 0:
82
+ documentId = relatedness_data["matches"][0]["id"]
83
 
84
  docs = []
85
+ train_service = TrainService()
86
+ documents = train_service.read_all_documents()
87
+ for document in documents:
88
+ if document["document_id"] == documentId:
89
+ docs.append(Document(page_content=document["page_content"], metadata=""))
 
 
90
 
91
  chain_data = get_llm_chain(model=model).run(input_documents=docs, question=query)
92
  # test
 
114
  return str(result)
115
  except ValueError as e:
116
  # Check sms and browser query
117
+ if documentId in COMMAND_SMS_INDEXES:
118
  return str({"program": "sms", "content": chain_data})
119
+ elif documentId in COMMAND_BROWSER_OPEN:
120
  return str({"program": "browser", "content": "https://google.com"})
121
  return str({"program": "message", "content": falcon_llm.query(question=query)})
Brain/src/router/api.py CHANGED
@@ -24,7 +24,6 @@ from Brain.src.rising_plugin.risingplugin import (
24
  handle_chat_completion,
25
  )
26
  from Brain.src.firebase.cloudmessage import send_message, get_tokens
27
- from Brain.src.rising_plugin.csv_embed import csv_embed
28
  from Brain.src.rising_plugin.image_embedding import embed_image_text, query_image_text
29
 
30
  from Brain.src.logs import logger
@@ -162,16 +161,6 @@ def construct_blueprint_api() -> APIRouter:
162
  result={"program": "image", "content": image_response},
163
  )
164
 
165
- """@generator.response(
166
- status_code=200, schema={"message": "message", "result": "test_result"}
167
- )"""
168
-
169
- @router.get("/training")
170
- def csv_training():
171
- csv_embed()
172
-
173
- return assembler.to_response(200, "trained successfully", "")
174
-
175
  """@generator.request_body(
176
  {
177
  "token": "test_token",
 
24
  handle_chat_completion,
25
  )
26
  from Brain.src.firebase.cloudmessage import send_message, get_tokens
 
27
  from Brain.src.rising_plugin.image_embedding import embed_image_text, query_image_text
28
 
29
  from Brain.src.logs import logger
 
161
  result={"program": "image", "content": image_response},
162
  )
163
 
 
 
 
 
 
 
 
 
 
 
164
  """@generator.request_body(
165
  {
166
  "token": "test_token",
Brain/src/router/train_router.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter
2
+
3
+ from Brain.src.common.assembler import Assembler
4
+ from Brain.src.model.requests.request_model import (
5
+ Document,
6
+ )
7
+ from Brain.src.service.train_service import TrainService
8
+
9
+ router = APIRouter()
10
+
11
+
12
+ def construct_blueprint_train_api() -> APIRouter:
13
+ # Assembler
14
+ assembler = Assembler()
15
+
16
+ # Services
17
+ train_service = TrainService()
18
+
19
+ """@generator.response(
20
+ status_code=200, schema={"message": "message", "result": "test_result"}
21
+ )"""
22
+
23
+ @router.get("")
24
+ def read_all_documents():
25
+ try:
26
+ result = train_service.read_all_documents()
27
+ except Exception as e:
28
+ return assembler.to_response(400, "failed to get all documents", "")
29
+ return assembler.to_response(200, "Get all documents list successfully", result)
30
+
31
+ """@generator.response( status_code=200, schema={"message": "message", "result": {"document_id": "document_id",
32
+ "page_content":"page_content"}} )"""
33
+
34
+ @router.get("/{document_id}")
35
+ def read_one_document(document_id: str):
36
+ try:
37
+ result = train_service.read_one_document(document_id)
38
+ except Exception as e:
39
+ return assembler.to_response(400, "fail to get one document", "")
40
+ return assembler.to_response(200, "Get one document successfully", result)
41
+
42
+ """@generator.request_body(
43
+ {
44
+ "token": "test_token",
45
+ "uuid": "test_uuid",
46
+ "page_content": "string",
47
+ }
48
+ )
49
+ @generator.response( status_code=200, schema={"message": "message", "result": {"document_id": "document_id",
50
+ "page_content":"page_content"}} )"""
51
+
52
+ @router.post("")
53
+ def create_document_train(data: Document):
54
+ try:
55
+ result = train_service.create_one_document(data.page_content)
56
+ except Exception as e:
57
+ return assembler.to_response(400, "failed to create one document", "")
58
+ return assembler.to_response(
59
+ 200, "created one document and trained it successfully", result
60
+ )
61
+
62
+ """@generator.request_body(
63
+ {
64
+ "token": "test_token",
65
+ "uuid": "test_uuid",
66
+ "document_id": "string",
67
+ "page_content": "string",
68
+ }
69
+ )
70
+ @generator.response( status_code=200, schema={"message": "message", "result": {"document_id": "document_id",
71
+ "page_content":"page_content"}} )"""
72
+
73
+ @router.put("")
74
+ def update_one_document(data: Document):
75
+ try:
76
+ result = train_service.update_one_document(
77
+ data.document_id, data.page_content
78
+ )
79
+ except Exception as e:
80
+ return assembler.to_response(400, "fail to update one document", "")
81
+ return assembler.to_response(
82
+ 200, "updated one document and trained it successfully", result
83
+ )
84
+
85
+ """@generator.request_body(
86
+ {
87
+ "token": "test_token",
88
+ "uuid": "test_uuid",
89
+ "document_id": "string",
90
+ }
91
+ )
92
+ @generator.response( status_code=200, schema={"message": "message", "result": {"document_id": "document_id"}} )"""
93
+
94
+ @router.delete("/{document_id}")
95
+ def delete_one_document(document_id: str):
96
+ try:
97
+ result = train_service.delete_one_document(document_id)
98
+ except Exception as e:
99
+ return assembler.to_response(400, "fail to delete one train", "")
100
+ return assembler.to_response(
101
+ 200, "deleted one document and train data successfully", result
102
+ )
103
+
104
+ return router
Brain/src/service/train_service.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """service to manage trains"""
2
+ from typing import List, Any
3
+
4
+ from Brain.src.rising_plugin.csv_embed import get_embed
5
+ from Brain.src.rising_plugin.pinecone_engine import (
6
+ get_pinecone_index_namespace,
7
+ update_pinecone,
8
+ init_pinecone,
9
+ delete_pinecone,
10
+ add_pinecone,
11
+ delete_all_pinecone,
12
+ )
13
+
14
+ from firebase_admin import firestore
15
+ import datetime
16
+
17
+
18
+ def to_json(page_content: str):
19
+ return {
20
+ "page_content": page_content,
21
+ "timestamp": datetime.datetime.now().timestamp(),
22
+ }
23
+
24
+
25
+ class TrainService:
26
+ """train (getting embedding) and update pinecone with embeddings by train_id
27
+ train datatype:
28
+ key: id
29
+ values: {id},{data}, {status}"""
30
+
31
+ def __init__(self):
32
+ self.db = firestore.client()
33
+ self.documents_ref = self.db.collection("documents")
34
+
35
+ """read all documents from firestore"""
36
+
37
+ def read_all_documents(self):
38
+ query = self.documents_ref.order_by("timestamp")
39
+ docs = query.stream()
40
+ result = []
41
+ for item in docs:
42
+ item_data = item.to_dict()
43
+ result.append(
44
+ {"document_id": item.id, "page_content": item_data["page_content"]}
45
+ )
46
+ return result
47
+
48
+ """read one document from firestore"""
49
+
50
+ def read_one_document(self, document_id: str):
51
+ doc = self.documents_ref.document(document_id).get()
52
+ if doc.exists:
53
+ return {
54
+ "document_id": document_id,
55
+ "page_content": doc.to_dict()["page_content"],
56
+ }
57
+ else:
58
+ return None
59
+
60
+ """create a new document and train it"""
61
+
62
+ def create_one_document(self, page_content: str):
63
+ # Auto-generate document ID
64
+ auto_generated_doc_ref = self.documents_ref.document()
65
+ auto_generated_doc_ref.set(to_json(page_content))
66
+ auto_generated_document_id = auto_generated_doc_ref.id
67
+ self.train_one_document(auto_generated_document_id, page_content)
68
+ return {"document_id": auto_generated_document_id, "page_content": page_content}
69
+
70
+ """update a document by using id and train it"""
71
+
72
+ def update_one_document(self, document_id: str, page_content: str):
73
+ self.documents_ref.document(document_id).update(to_json(page_content))
74
+ self.train_one_document(document_id, page_content)
75
+ return {"document_id": document_id, "page_content": page_content}
76
+
77
+ """delete a document by using document_id"""
78
+
79
+ def delete_one_document(self, document_id: str):
80
+ self.documents_ref.document(document_id).delete()
81
+ self.delete_one_pinecone(document_id)
82
+ return {"document_id": document_id}
83
+
84
+ def train_all_documents(self) -> str:
85
+ documents = self.read_all_documents()
86
+ result = list()
87
+ pinecone_namespace = self.get_pinecone_index_namespace()
88
+ for item in documents:
89
+ query_result = get_embed(item["page_content"])
90
+ result.append(query_result)
91
+ key = item["document_id"]
92
+ value = f'{item["page_content"]}, {query_result}'
93
+ # get vectoring data(embedding data)
94
+ vectoring_values = get_embed(value)
95
+ add_pinecone(namespace=pinecone_namespace, key=key, value=vectoring_values)
96
+
97
+ return "trained all documents successfully"
98
+
99
+ def train_one_document(self, document_id: str, page_content: str) -> None:
100
+ pinecone_namespace = self.get_pinecone_index_namespace()
101
+ result = list()
102
+ query_result = get_embed(page_content)
103
+ result.append(query_result)
104
+ key = document_id
105
+ value = f"{page_content}, {query_result}"
106
+ # get vectoring data(embedding data)
107
+ vectoring_values = get_embed(value)
108
+ add_pinecone(namespace=pinecone_namespace, key=key, value=vectoring_values)
109
+
110
+ def delete_all(self) -> Any:
111
+ return delete_all_pinecone(self.get_pinecone_index_namespace())
112
+
113
+ def delete_one_pinecone(self, document_id: str) -> Any:
114
+ return delete_pinecone(self.get_pinecone_index_namespace(), document_id)
115
+
116
+ def get_pinecone_index_namespace(self) -> str:
117
+ return get_pinecone_index_namespace(f"trains")
app.py CHANGED
@@ -3,6 +3,7 @@ from fastapi import Depends, FastAPI
3
  import uvicorn
4
 
5
  from Brain.src.router.browser_router import construct_blueprint_browser_api
 
6
 
7
  initialize_app()
8
 
@@ -14,6 +15,10 @@ app.include_router(
14
  construct_blueprint_browser_api(), prefix="/browser", tags=["ai_browser"]
15
  )
16
 
 
 
 
 
17
 
18
  if __name__ == "__main__":
19
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
3
  import uvicorn
4
 
5
  from Brain.src.router.browser_router import construct_blueprint_browser_api
6
+ from Brain.src.router.train_router import construct_blueprint_train_api
7
 
8
  initialize_app()
9
 
 
15
  construct_blueprint_browser_api(), prefix="/browser", tags=["ai_browser"]
16
  )
17
 
18
+ app.include_router(
19
+ construct_blueprint_train_api(), prefix="/train", tags=["ai_train"]
20
+ )
21
+
22
 
23
  if __name__ == "__main__":
24
  uvicorn.run(app, host="0.0.0.0", port=7860)