# cjber's picture
# add citations to model generation
# 9099bff
from typing import List
from dotenv import load_dotenv
from langchain.output_parsers.openai_tools import JsonOutputKeyToolsParser
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_openai import ChatOpenAI
from src.common.settings import cfg
# Load variables from a .env file (if one is found) into the environment.
# This runs before the chat client is constructed; keep that order.
load_dotenv()

# Module-level chat model used by the chain defined further down in this file.
# The model name comes from project settings; temperature is pinned to 0.
llm = ChatOpenAI(
    temperature=0,
    model=cfg.model.llm,
)
# Human-turn template for the summary request. It exposes two placeholders:
#   {query}   - the user's catalogue query
#   {context} - the formatted dataset snippets
# The citation instruction previously read "a the relevant citation", which
# dropped the verb; it now reads "add the relevant citation".
human = """
A user has queried a data catalogue, which has returned a relevant dataset.
Summarise the relevance of this dataset to the query in under three sentences, using the source snippets provided. Do not say it is unrelated; find a relevant connection. For each sentence, add the relevant citation right after. Repeats are allowed. Use '[SOURCE_NUMBER]' for the citation (e.g. 'The Space Needle is in Seattle [1][2]'). You MUST use ALL citations.
Query: "{query}"
Dataset snippets:
{context}
"""

# Single-message chat prompt built from the template above.
prompt = ChatPromptTemplate.from_messages([("human", human)])
class CitedAnswer(BaseModel):
    """
    Answer the user question based only on the given sources, and cite the sources used.
    """

    # NOTE(review): this class is bound to the LLM as a tool elsewhere in this
    # file, so the docstring above and the `description` strings below are
    # presumably sent to the model as part of the tool schema. Treat them as
    # prompt text, not as developer documentation.

    # Free-text summary with inline '[N]' citation markers.
    # Fixed the garbled instruction "a the relevant citation" -> "add the
    # relevant citation" so the model is told what to do with each sentence.
    generation: str = Field(
        ...,
        description="A dataset summary linking a users query with a dataset. For each sentence, add the relevant citation right after. Repeats are allowed. Use '[SOURCE_NUMBER]' for the citation (e.g. 'The Space Needle is in Seattle [1][2]'). You MUST use ALL citations.",
    )

    # Integer IDs of the sources that back the summary (required field).
    citations: List[int] = Field(
        ...,
        description="The integer IDs of the SPECIFIC sources which justify the summary.",
    )
def format_docs_with_id(docs: List[Document]) -> str:
    """Render documents as numbered source blocks for the prompt context.

    Each document becomes a three-line block holding its position in
    ``docs`` (zero-based), its ``metadata["title"]`` and its
    ``page_content``. Blocks are separated by a blank line, and the
    result always starts with a blank line, even when ``docs`` is empty.

    Args:
        docs: Documents to render; each must carry a ``"title"`` key in
            its metadata.

    Returns:
        The concatenated source blocks as a single string.

    Raises:
        KeyError: If a document's metadata has no ``"title"`` entry.
    """
    separator = "\n\n"
    blocks = []
    for source_id, doc in enumerate(docs):
        title = doc.metadata["title"]
        snippet = doc.page_content
        blocks.append(
            f"Source ID: {source_id}\nArticle Title: {title}\nArticle Snippet: {snippet}"
        )
    return separator + separator.join(blocks)
# Bind the CitedAnswer schema to the model as a tool and set tool_choice to
# that tool by name.
llm_with_tool = llm.bind_tools(
    [CitedAnswer],
    tool_choice="CitedAnswer",
)

# Parser for the "CitedAnswer" tool call; first_tool_only=True restricts it
# to the first matching call.
output_parser = JsonOutputKeyToolsParser(
    key_name="CitedAnswer",
    first_tool_only=True,
)

# Full chain: prompt -> tool-bound model -> parsed tool arguments.
answer_citations = (
    prompt
    | llm_with_tool
    | output_parser
)