liumaolin committed
Commit 941bf07 · Parent: e0f42b2

Refactor LlamaCpp initialization to simplify parameter handling and remove unused callback manager

src/voice_dialogue/config/llm_config.py CHANGED
@@ -19,6 +19,8 @@ def get_llm_model_params() -> Dict[str, Any]:
     # Basic model parameters
     model_params = {
         'streaming': True,
+        'n_gpu_layers': -1,
+        'n_batch': 1024,
         'temperature': 0.7,
         'top_p': 0.9,
         'top_k': 20,
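
For context, a minimal sketch of `get_llm_model_params()` as it reads after this hunk. Only the lines shown in the diff are confirmed; the import and the trailing return are assumptions:

from typing import Any, Dict


def get_llm_model_params() -> Dict[str, Any]:
    # Basic model parameters. In llama.cpp, n_gpu_layers=-1 offloads all
    # layers to the GPU, and n_batch is the prompt-processing batch size.
    # Every key must match a ChatLlamaCpp constructor field, because the
    # dict is now unpacked straight into the constructor (see processor.py
    # below).
    model_params = {
        'streaming': True,
        'n_gpu_layers': -1,
        'n_batch': 1024,
        'temperature': 0.7,
        'top_p': 0.9,
        'top_k': 20,
    }
    return model_params  # assumed; the hunk ends before the function does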
src/voice_dialogue/services/text/processor.py CHANGED
@@ -2,7 +2,6 @@ import pathlib
 import typing
 
 from langchain_community.chat_models.llamacpp import ChatLlamaCpp
-from langchain_core.callbacks import StreamingStdOutCallbackHandler, CallbackManager
 from langchain_core.messages import SystemMessage
 from langchain_core.prompts import (
     ChatPromptTemplate, MessagesPlaceholder, HumanMessagePromptTemplate
@@ -22,20 +21,9 @@ def create_langchain_chat_llamacpp_instance(
     logger.info(">>>>>>> Initializing LlamaCpp Langchain instance...")
 
     model_path = pathlib.Path(local_model_path)
-    callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
     llamacpp_langchain_instance = ChatLlamaCpp(
         model_path=str(model_path),
-        streaming=model_params.get('streaming', True),
-        n_gpu_layers=model_params.get('n_gpu_layers', -1),
-        n_batch=model_params.get('n_batch', 512),
-        n_ctx=model_params.get('n_ctx', 2048),
-        f16_kv=model_params.get('f16_kv', True),
-        temperature=model_params.get('temperature', 0.8),
-        top_k=model_params.get('top_k', 40),
-        top_p=model_params.get('top_p', 0.95),
-        max_tokens=model_params.get('n_predict', 256),
-        # callback_manager=callback_manager,
-        verbose=False
+        **model_params
     )
 
     return llamacpp_langchain_instance
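
With both deletions applied, the function reduces to a passthrough. A minimal sketch of the result; the parameter names local_model_path and model_params are visible in the diff, but their annotations and the module's logger import are not, so those are assumptions here:

import logging
import pathlib
import typing

from langchain_community.chat_models.llamacpp import ChatLlamaCpp

# Stand-in for the module's own logger, whose import the diff does not show.
logger = logging.getLogger(__name__)


def create_langchain_chat_llamacpp_instance(
        local_model_path: str,                       # annotation assumed
        model_params: typing.Dict[str, typing.Any],  # annotation assumed
) -> ChatLlamaCpp:
    logger.info(">>>>>>> Initializing LlamaCpp Langchain instance...")

    model_path = pathlib.Path(local_model_path)
    # All tuning parameters now flow in from the config dict; the explicit
    # keyword arguments and the unused CallbackManager are gone.
    llamacpp_langchain_instance = ChatLlamaCpp(
        model_path=str(model_path),
        **model_params
    )

    return llamacpp_langchain_instance

One behavioral note on this design: parameters that previously had inline fallbacks via .get() (n_ctx, f16_kv, max_tokens from 'n_predict', and the rest) now fall back to ChatLlamaCpp's own defaults unless llm_config.py supplies them, which is presumably why 'n_gpu_layers' and 'n_batch' were promoted into the config dict in the first file.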