VcRlAgent committed
Commit a01079e · Parent: c19815e

Starter LLM Inference Call

Files changed (1): app/hybrid_rag.py (+25 −1)
app/hybrid_rag.py CHANGED
@@ -56,12 +56,36 @@ class HybridJiraRAG:
             allow_dangerous_deserialization=True
         )
 
+        # Create prompt
+        prompt = PromptTemplate(
+            template="Context: {context}\n\nQuestion: {question}\n\nAnswer:",
+            input_variables=["context", "question"]
+        )
+
+        # Format docs function
+        def format_docs(docs):
+            return "\n\n".join([doc.page_content for doc in docs])
+
+        # LCEL chain
+        retriever = self.vector_store.as_retriever(search_kwargs={"k": 5})
+
+        self.rag_chain = (
+            {
+                "context": retriever | format_docs,
+                "question": RunnablePassthrough()
+            }
+            | prompt
+            | self.llm
+            | StrOutputParser()
+        )
+
         # RAG chain
+        '''
         self.rag_chain = RetrievalQA.from_chain_type(
             llm=self.llm,
             retriever=self.vector_store.as_retriever(search_kwargs={"k": 5}),
             return_source_documents=True
-        )
+        )'''
 
     def _load_local_llm(self, model_name: str):
         """Load LLM locally to use GPU"""