Spaces:
Sleeping
Sleeping
Commit ·
73c8596
1
Parent(s): 828a9b4
changed model
Browse files- app/embedding/embeder.py +8 -8
- app/schemas/request_models.py +4 -2
- app/utils/model_loader.py +11 -6
app/embedding/embeder.py
CHANGED
|
@@ -11,12 +11,12 @@ def get_query_embedding(query_spec: QuerySpec, embedding_model):
|
|
| 11 |
q = query_spec.raw_query
|
| 12 |
e_main = embedding_model.embed_query(q)
|
| 13 |
expansions = []
|
| 14 |
-
if "procedure" in query_spec.entities:
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
return e_main,expansions
|
|
|
|
| 11 |
q = query_spec.raw_query
|
| 12 |
e_main = embedding_model.embed_query(q)
|
| 13 |
expansions = []
|
| 14 |
+
# if "procedure" in query_spec.entities:
|
| 15 |
+
# procedure_value = query_spec.entities['procedure']
|
| 16 |
+
# # Handle both string and list values
|
| 17 |
+
# if isinstance(procedure_value, list):
|
| 18 |
+
# procedure_str = ", ".join(procedure_value)
|
| 19 |
+
# else:
|
| 20 |
+
# procedure_str = procedure_value
|
| 21 |
+
# expansions.append(f"{q} OR {procedure_str} procedures related")
|
| 22 |
return e_main,expansions
|
app/schemas/request_models.py
CHANGED
|
@@ -4,8 +4,10 @@ import json
|
|
| 4 |
class QuerySpec(BaseModel):
|
| 5 |
raw_query: str
|
| 6 |
intent: str
|
| 7 |
-
entities: Dict[str, Union[str, List[str]]]
|
| 8 |
-
constraints : Dict[str, Any]
|
|
|
|
|
|
|
| 9 |
answer_type: str
|
| 10 |
followups: Optional[List[str]] = []
|
| 11 |
|
|
|
|
| 4 |
class QuerySpec(BaseModel):
|
| 5 |
raw_query: str
|
| 6 |
intent: str
|
| 7 |
+
# entities: Dict[str, Union[str, List[str]]]
|
| 8 |
+
# constraints : Dict[str, Any]
|
| 9 |
+
entities: Optional[Any] = None
|
| 10 |
+
constraints: Optional[Any] = None
|
| 11 |
answer_type: str
|
| 12 |
followups: Optional[List[str]] = []
|
| 13 |
|
app/utils/model_loader.py
CHANGED
|
@@ -9,6 +9,7 @@ from langchain_google_genai import ChatGoogleGenerativeAI
|
|
| 9 |
from dotenv import load_dotenv
|
| 10 |
# from langchain_openai import OpenAIEmbeddings
|
| 11 |
from langchain_community.embeddings import OpenAIEmbeddings
|
|
|
|
| 12 |
class ConfigLoader:
|
| 13 |
def __init__(self):
|
| 14 |
print(f"Loading config....")
|
|
@@ -42,12 +43,16 @@ class ModelLoader(BaseModel):
|
|
| 42 |
elif self.model_provider =="gemini":
|
| 43 |
print("Loading model from gemini:")
|
| 44 |
load_dotenv()
|
| 45 |
-
gemini_api_key = os.getenv("GEMINI_API_KEY")
|
| 46 |
-
model_name = self.config["llm"]["gemini"]["model_name"]
|
| 47 |
-
llm = ChatGoogleGenerativeAI(
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
elif self.model_provider =="openai":
|
| 52 |
load_dotenv()
|
| 53 |
print("Loading model from openai:")
|
|
|
|
| 9 |
from dotenv import load_dotenv
|
| 10 |
# from langchain_openai import OpenAIEmbeddings
|
| 11 |
from langchain_community.embeddings import OpenAIEmbeddings
|
| 12 |
+
from langchain_openai import ChatOpenAI
|
| 13 |
class ConfigLoader:
|
| 14 |
def __init__(self):
|
| 15 |
print(f"Loading config....")
|
|
|
|
| 43 |
elif self.model_provider =="gemini":
|
| 44 |
print("Loading model from gemini:")
|
| 45 |
load_dotenv()
|
| 46 |
+
# gemini_api_key = os.getenv("GEMINI_API_KEY")
|
| 47 |
+
# model_name = self.config["llm"]["gemini"]["model_name"]
|
| 48 |
+
# llm = ChatGoogleGenerativeAI(
|
| 49 |
+
# model=model_name,
|
| 50 |
+
# google_api_key= gemini_api_key
|
| 51 |
+
# )
|
| 52 |
+
openai_api_key = os.getenv("OPENAI_API_KEY")
|
| 53 |
+
llm = ChatOpenAI(model="gpt-4o-mini",api_key=openai_api_key )
|
| 54 |
+
|
| 55 |
+
|
| 56 |
elif self.model_provider =="openai":
|
| 57 |
load_dotenv()
|
| 58 |
print("Loading model from openai:")
|