omnibook-base / src /rag.py
REXPro's picture
Update src/rag.py
9d631c2 verified
Raw
History Blame Contribute Delete
2.59 kB
from functools import lru_cache

import torch
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_huggingface import HuggingFacePipeline
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
@lru_cache(maxsize=1)
def get_llm():
    """Load Qwen3-0.6B once and wrap it as a LangChain LLM.

    The result is cached: ``rewrite_query`` and ``answer_with_memory`` each
    call this for every question, and without the cache the tokenizer and
    model weights would be reloaded from disk on every call.

    Returns:
        HuggingFacePipeline: LangChain wrapper around a transformers
        "text-generation" pipeline that yields only the newly generated text.
    """
    model_id = "Qwen/Qwen3-0.6B"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # float16 is meant for GPU inference; on CPU half-precision kernels are
    # slow (or missing on older torch builds), so fall back to float32 there.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=dtype,
    )
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=512,
        # NOTE(review): temperature only applies when sampling is enabled
        # (e.g. by the model's own generation config) — confirm intent.
        temperature=0.1,
        # By default the pipeline echoes the prompt in "generated_text", and
        # HuggingFacePipeline passes that through; callers only want the
        # completion.
        return_full_text=False,
    )
    return HuggingFacePipeline(pipeline=pipe)
def rewrite_query(question, chat_history):
    """Turn a follow-up question into a standalone retrieval query.

    The last three turns of ``chat_history`` (assistant answers truncated to
    150 characters) are given to the LLM as context for the rewrite.

    Args:
        question: The user's latest question.
        chat_history: Sequence of dicts with ``"question"`` and ``"answer"``
            keys, oldest first. May be empty or None.

    Returns:
        str: The rewritten query. Returns ``question`` stripped of surrounding
        whitespace when there is no history to resolve it against, or when
        the model produced no usable text.
    """
    # With no prior turns the question is already standalone; skip a full
    # model call that could only paraphrase (or distort) it.
    if not chat_history:
        return question.strip()

    history_text = "".join(
        f"User: {turn['question']}\nAssistant: {turn['answer'][:150]}\n"
        for turn in chat_history[-3:]
    )
    rewrite_prompt = (
        "Given the conversation history and a follow-up\n"
        "question, rewrite the question into a standalone search query that will\n"
        "retrieve relevant information from a document.\n"
        "Conversation History:\n"
        "{history}\n"
        "Follow-up Question: {question}\n"
        "Standalone Search Query:"
    )
    prompt = ChatPromptTemplate.from_template(rewrite_prompt)
    chain = prompt | get_llm() | StrOutputParser()
    rewritten = chain.invoke({
        "history": history_text,
        "question": question,
    })
    # A small completion model often keeps generating after the query;
    # the query itself is the first non-empty line.
    for line in rewritten.splitlines():
        candidate = line.strip()
        if candidate:
            return candidate
    return question.strip()
def answer_with_memory(
    vectorstore, question, chat_history, k=6, max_history_turns=3
):
    """Answer ``question`` with retrieval-augmented generation and chat memory.

    Steps: rewrite the question into a standalone query via
    ``rewrite_query``, fetch the top-``k`` chunks with
    ``vectorstore.similarity_search``, then prompt the LLM with the recent
    conversation, the numbered chunks and the original question.

    Args:
        vectorstore: Object exposing ``similarity_search(query, k=...)`` that
            returns documents with a ``page_content`` attribute (presumably a
            LangChain vector store — confirm against callers).
        question: The user's latest question.
        chat_history: Sequence of dicts with ``"question"`` and ``"answer"``
            keys, oldest first.
        k: Number of chunks to retrieve.
        max_history_turns: How many of the most recent turns go into the
            answer prompt. ``None`` includes the whole history. The default
            bound keeps the prompt from growing without limit over a long
            conversation.

    Returns:
        tuple: ``(answer, retrieved_docs)``, where ``answer`` is the model's
        reply stripped of surrounding whitespace and ``retrieved_docs`` is the
        list returned by ``similarity_search``.
    """
    search_query = rewrite_query(question, chat_history)
    retrieved_docs = vectorstore.similarity_search(search_query, k=k)

    # Label each chunk so the model can tell where one source ends and the
    # next begins.
    context = "\n\n".join(
        f"[Source {i}]\n{doc.page_content}"
        for i, doc in enumerate(retrieved_docs, 1)
    )

    turns = list(chat_history or [])
    if max_history_turns is not None:
        # Guard 0 explicitly: turns[-0:] would return the whole list.
        turns = turns[-max_history_turns:] if max_history_turns > 0 else []
    history_text = "".join(
        f"User: {turn['question']}\nAssistant: {turn['answer']}\n"
        for turn in turns
    )

    template = (
        "You are a highly efficient and accurate AI assistant for a document Q&A system.\n"
        "You MUST answer the user's question directly based ONLY on the provided Context.\n"
        "If the Context does not contain the answer, say that you don't know.\n"
        "Conversation History:\n"
        "{history}\n"
        "Context:\n"
        "{context}\n"
        "Question: {question}\n"
        "Answer:"
    )
    prompt = ChatPromptTemplate.from_template(template)
    chain = prompt | get_llm() | StrOutputParser()
    answer = chain.invoke({
        # Explicit placeholders instead of blank sections, which a small
        # model tends to fill with invented content.
        "history": history_text or "(none)",
        "context": context or "(no relevant context found)",
        "question": question,
    })
    return answer.strip(), retrieved_docs