mechark committed on
Commit
7b00ff7
·
1 Parent(s): 2b08c2f

Use ChatHuggingFace wrapper for better model support

Browse files
Files changed (1) hide show
  1. src/rag/llm.py +6 -3
src/rag/llm.py CHANGED
@@ -1,4 +1,4 @@
1
- from langchain_huggingface import HuggingFaceEndpoint
2
  from langchain_core.prompts import ChatPromptTemplate
3
 
4
  from src.prompts import SYSTEM_PROMPT
@@ -8,12 +8,15 @@ from src.core.config import settings
8
  def get_chain():
9
  prompt = ChatPromptTemplate.from_template(SYSTEM_PROMPT)
10
 
 
11
  llm = HuggingFaceEndpoint(
12
  repo_id=settings.MODEL_NAME,
13
  huggingfacehub_api_token=settings.HUGGINGFACE_TOKEN,
14
  temperature=settings.MODEL_TEMPERATURE,
15
  max_new_tokens=settings.MODEL_MAX_TOKENS,
16
- task="conversational",
17
  )
 
 
 
18
 
19
- return prompt | llm
 
1
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
2
  from langchain_core.prompts import ChatPromptTemplate
3
 
4
  from src.prompts import SYSTEM_PROMPT
 
8
  def get_chain():
9
  prompt = ChatPromptTemplate.from_template(SYSTEM_PROMPT)
10
 
11
+ # Create base endpoint
12
  llm = HuggingFaceEndpoint(
13
  repo_id=settings.MODEL_NAME,
14
  huggingfacehub_api_token=settings.HUGGINGFACE_TOKEN,
15
  temperature=settings.MODEL_TEMPERATURE,
16
  max_new_tokens=settings.MODEL_MAX_TOKENS,
 
17
  )
18
+
19
+ # Wrap with ChatHuggingFace for better conversational support
20
+ chat_llm = ChatHuggingFace(llm=llm)
21
 
22
+ return prompt | chat_llm