Spaces:
Sleeping
Sleeping
OnlyBiggg
commited on
Commit
·
3661274
1
Parent(s):
8519780
fix ner
Browse files
app/dialogflow/api/v1/dialogflow.py
CHANGED
|
@@ -816,12 +816,12 @@ async def check_exist_user_info(request: Request) -> Response:
|
|
| 816 |
session_info = body.get("sessionInfo", {})
|
| 817 |
parameters = session_info.get("parameters")
|
| 818 |
|
| 819 |
-
is_exist_user_info = dialog_service.check_exist_user_info()
|
| 820 |
|
| 821 |
user_info = {}
|
| 822 |
|
| 823 |
if is_exist_user_info:
|
| 824 |
-
user_info = dialog_service.get_user_info()
|
| 825 |
|
| 826 |
user_name = user_info.get("name")
|
| 827 |
phone_number = user_info.get("phone_number")
|
|
@@ -837,7 +837,7 @@ async def check_exist_user_info(request: Request) -> Response:
|
|
| 837 |
return DialogFlowResponseAPI(parameters=parameters)
|
| 838 |
|
| 839 |
@router.post('/trip/extract-user-name')
|
| 840 |
-
async def
|
| 841 |
body = await request.json()
|
| 842 |
session_info = body.get("sessionInfo", {})
|
| 843 |
parameters = session_info.get("parameters")
|
|
@@ -846,7 +846,7 @@ async def extract_user_name(request: Request) -> Response:
|
|
| 846 |
|
| 847 |
ner: NER = request.app.state.ner
|
| 848 |
|
| 849 |
-
user_name = dialog_service.extract_user_name(text=raw_text_user_name, ner=ner)
|
| 850 |
|
| 851 |
parameters = {
|
| 852 |
"user_name": user_name
|
|
|
|
| 816 |
session_info = body.get("sessionInfo", {})
|
| 817 |
parameters = session_info.get("parameters")
|
| 818 |
|
| 819 |
+
is_exist_user_info = await dialog_service.check_exist_user_info()
|
| 820 |
|
| 821 |
user_info = {}
|
| 822 |
|
| 823 |
if is_exist_user_info:
|
| 824 |
+
user_info = await dialog_service.get_user_info()
|
| 825 |
|
| 826 |
user_name = user_info.get("name")
|
| 827 |
phone_number = user_info.get("phone_number")
|
|
|
|
| 837 |
return DialogFlowResponseAPI(parameters=parameters)
|
| 838 |
|
| 839 |
@router.post('/trip/extract-user-name')
|
| 840 |
+
async def get_user_name(request: Request) -> Response:
|
| 841 |
body = await request.json()
|
| 842 |
session_info = body.get("sessionInfo", {})
|
| 843 |
parameters = session_info.get("parameters")
|
|
|
|
| 846 |
|
| 847 |
ner: NER = request.app.state.ner
|
| 848 |
|
| 849 |
+
user_name = await dialog_service.extract_user_name(text=raw_text_user_name, ner=ner)
|
| 850 |
|
| 851 |
parameters = {
|
| 852 |
"user_name": user_name
|
app/dialogflow/services/dialog_service.py
CHANGED
|
@@ -396,12 +396,12 @@ class DialogService:
|
|
| 396 |
except Exception as e:
|
| 397 |
logger.error(f"Error fetching user info: {e}")
|
| 398 |
return None
|
| 399 |
-
|
| 400 |
-
def extract_user_name(text: str, ner: NER):
|
| 401 |
if text is None:
|
| 402 |
return None
|
| 403 |
|
| 404 |
-
user_name_pred = ner.predict(text=text, entity_tag="PERSON")
|
| 405 |
|
| 406 |
if user_name_pred:
|
| 407 |
user_name = user_name_pred[0]
|
|
|
|
| 396 |
except Exception as e:
|
| 397 |
logger.error(f"Error fetching user info: {e}")
|
| 398 |
return None
|
| 399 |
+
@staticmethod
|
| 400 |
+
async def extract_user_name(text: str, ner: NER):
|
| 401 |
if text is None:
|
| 402 |
return None
|
| 403 |
|
| 404 |
+
user_name_pred = await ner.predict(text=text, entity_tag="PERSON")
|
| 405 |
|
| 406 |
if user_name_pred:
|
| 407 |
user_name = user_name_pred[0]
|
app/ner/services/ner.py
CHANGED
|
@@ -19,10 +19,10 @@ class NER:
|
|
| 19 |
tokenizer=self.tokenizer,
|
| 20 |
device=settings.DEVICE)
|
| 21 |
|
| 22 |
-
def predict(self, text: str, entity_tag: str = None):
|
| 23 |
if self.pipeline is None:
|
| 24 |
raise ValueError("Model not loaded. Please call load_model() first.")
|
| 25 |
-
pred = self.pipeline(text)
|
| 26 |
if entity_tag:
|
| 27 |
return self.extract_entities(pred, entity_tag)
|
| 28 |
return pred
|
|
|
|
| 19 |
tokenizer=self.tokenizer,
|
| 20 |
device=settings.DEVICE)
|
| 21 |
|
| 22 |
+
async def predict(self, text: str, entity_tag: str = None):
|
| 23 |
if self.pipeline is None:
|
| 24 |
raise ValueError("Model not loaded. Please call load_model() first.")
|
| 25 |
+
pred = await self.pipeline(text)
|
| 26 |
if entity_tag:
|
| 27 |
return self.extract_entities(pred, entity_tag)
|
| 28 |
return pred
|