# app/chains.py
import os

from langchain_core.runnables import RunnablePassthrough
from langchain_huggingface import HuggingFaceEndpoint

import schemas
from data_indexing import DataIndexer
from prompts import (
    raw_prompt,
    raw_prompt_formatted,
    history_prompt_formatted,
    standalone_prompt_formatted,
    rag_prompt_formatted,
    format_context,
    tokenizer,
)

model_id = "meta-llama/Llama-3.2-3B-Instruct"

# Vector-store client used to retrieve context documents for the RAG chain.
data_indexer = DataIndexer()
# Remote inference endpoint for the instruct model. The tokenizer's EOS token
# is passed as a stop sequence so generation halts at the end of a turn.
llm = HuggingFaceEndpoint(
    repo_id=model_id,
    huggingfacehub_api_token=os.environ['HF_TOKEN'],
    max_new_tokens=512,
    stop_sequences=[tokenizer.eos_token],
    streaming=True,
)
# Sends the user's question to the model with no chat-template formatting.
simple_chain = (raw_prompt | llm).with_types(input_type=schemas.UserQuestion)

# Same question, but wrapped in the model's chat template before generation.
formatted_chain = (
    raw_prompt_formatted
    | llm
).with_types(input_type=schemas.UserQuestion)
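# Usage sketch (assumes schemas.UserQuestion exposes a single `question`
# field; adjust the key to the actual schema definition):
#     simple_chain.invoke({"question": "What is LangChain?"})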
# Answers the latest question with the prior conversation folded into the prompt.
history_chain = (
    history_prompt_formatted
    | llm
).with_types(input_type=schemas.HistoryInput)
# Rewrites the latest question into a self-contained ("standalone") question,
# using the formatted prompt imported from prompts.py.
standalone_chain = standalone_prompt_formatted | llm

# Produces the final answer from the retrieved context and the standalone question.
generation_chain = rag_prompt_formatted | llm
# Full RAG pipeline: (1) rewrite the question so it stands alone, (2) retrieve
# and format matching context via the data indexer, (3) generate the answer.
rag_chain = (
    RunnablePassthrough.assign(new_question=standalone_chain)
    | {
        'context': lambda x: format_context(data_indexer.search(x['new_question'])),
        'standalone_question': lambda x: x['new_question'],
    }
    | generation_chain
).with_types(input_type=schemas.RagInput)
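# Usage sketch (assumes schemas.RagInput carries the current `question` plus a
# `chat_history` string, as the history-aware prompts suggest; the field names
# are an assumption, not confirmed by this file):
#     rag_chain.invoke({
#         "question": "How do I deploy the backend?",
#         "chat_history": "",
#     })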