Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -29,6 +29,7 @@ from sklearn.metrics.pairwise import cosine_similarity
|
|
| 29 |
import torch
|
| 30 |
import gradio as gr
|
| 31 |
from transformers import pipeline
|
|
|
|
| 32 |
|
| 33 |
audio_model_id = "Sandiago21/whisper-large-v2-greek" # update with your model id
|
| 34 |
audio_pipe = pipeline("automatic-speech-recognition", model=audio_model_id)
|
|
@@ -823,6 +824,9 @@ def tool_executor(state: AgentState):
|
|
| 823 |
"""
|
| 824 |
|
| 825 |
try:
|
|
|
|
|
|
|
|
|
|
| 826 |
webpage_result = ""
|
| 827 |
action = Action.model_validate(state["proposed_action"])
|
| 828 |
|
|
@@ -838,7 +842,7 @@ def tool_executor(state: AgentState):
|
|
| 838 |
query_arg_embeddings = sentence_transformer_model.encode_query(state["proposed_action"]["args"]["query"]).reshape(1, -1)
|
| 839 |
score = float(cosine_similarity(query_embeddings, query_arg_embeddings)[0][0])
|
| 840 |
|
| 841 |
-
if score > 0.80:
|
| 842 |
results = web_search(**action.args)
|
| 843 |
else:
|
| 844 |
logger.info(f"Overwriting user query because the Agent suggested query had score: {state["proposed_action"]["args"]["query"]} - {score}")
|
|
|
|
| 29 |
import torch
|
| 30 |
import gradio as gr
|
| 31 |
from transformers import pipeline
|
| 32 |
+
from langdetect import detect
|
| 33 |
|
| 34 |
audio_model_id = "Sandiago21/whisper-large-v2-greek" # update with your model id
|
| 35 |
audio_pipe = pipeline("automatic-speech-recognition", model=audio_model_id)
|
|
|
|
| 824 |
"""
|
| 825 |
|
| 826 |
try:
|
| 827 |
+
user_input = state["messages"][-1].content
|
| 828 |
+
useer_input_language = detect(user_input)
|
| 829 |
+
|
| 830 |
webpage_result = ""
|
| 831 |
action = Action.model_validate(state["proposed_action"])
|
| 832 |
|
|
|
|
| 842 |
query_arg_embeddings = sentence_transformer_model.encode_query(state["proposed_action"]["args"]["query"]).reshape(1, -1)
|
| 843 |
score = float(cosine_similarity(query_embeddings, query_arg_embeddings)[0][0])
|
| 844 |
|
| 845 |
+
if score > 0.80 or useer_input_language != "en":
|
| 846 |
results = web_search(**action.args)
|
| 847 |
else:
|
| 848 |
logger.info(f"Overwriting user query because the Agent suggested query had score: {state["proposed_action"]["args"]["query"]} - {score}")
|