File size: 2,085 Bytes
9099bff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from typing import List

from dotenv import load_dotenv
from langchain.output_parsers.openai_tools import JsonOutputKeyToolsParser
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_openai import ChatOpenAI

from src.common.settings import cfg

# Runs at import time, before the chat client below is constructed.
# NOTE(review): presumably this loads a local .env file so the OpenAI
# credentials are present in the environment — confirm.
load_dotenv()

# Module-level chat model shared by the chain defined in this file. The model
# name is read from project settings and the temperature is pinned to 0.
llm = ChatOpenAI(model=cfg.model.llm, temperature=0)
# Text of the single human-turn message sent to the model. ``{query}`` and
# ``{context}`` are template placeholders that must be supplied when the
# prompt is invoked; ``{context}`` is expected to hold the formatted dataset
# snippets. The citation instruction previously read "a the relevant
# citation" (missing verb); it now reads "add the relevant citation".
human = """
A user has queried a data catalogue, which has returned a relevant dataset.

Summarise the relevance of this dataset to the query in under three sentences, using the source snippets provided. Do not say it is unrelated; find a relevant connection. For each sentence, add the relevant citation right after. Repeats are allowed. Use '[SOURCE_NUMBER]' for the citation (e.g. 'The Space Needle is in Seattle [1][2]'). You MUST use ALL citations.

Query: "{query}"

Dataset snippets:

{context}
"""

# Chat prompt made of that one human message (no system message is defined).
prompt = ChatPromptTemplate.from_messages([("human", human)])


# Structured-output schema for the summary chain: the model is asked to return
# a short summary plus the integer IDs of the sources it relied on.
# NOTE(review): the class docstring and the Field descriptions are presumably
# forwarded to the model as part of the tool schema, so they are prompt text
# rather than plain documentation — edit them with care.
class CitedAnswer(BaseModel):
    """
    Answer the user question based only on the given sources, and cite the sources used.
    """

    # Summary text carrying inline '[SOURCE_NUMBER]' citation markers. The
    # description previously read "a the relevant citation" (missing verb).
    generation: str = Field(
        ...,
        description="A dataset summary linking a users query with a dataset. For each sentence, add the relevant citation right after. Repeats are allowed. Use '[SOURCE_NUMBER]' for the citation (e.g. 'The Space Needle is in Seattle [1][2]'). You MUST use ALL citations.",
    )
    # Integer source IDs that back the summary; required, like `generation`.
    citations: List[int] = Field(
        ...,
        description="The integer IDs of the SPECIFIC sources which justify the summary.",
    )


def format_docs_with_id(docs: List[Document]) -> str:
    """Render documents as numbered source blocks for use as prompt context.

    Each document becomes a three-line block giving its zero-based position
    in ``docs`` as the source ID, the ``title`` entry of its metadata, and
    its page content. Blocks are separated by a blank line and the result
    also starts with a blank-line separator.

    Args:
        docs: Documents to render; each must have a ``"title"`` metadata key.

    Returns:
        The concatenated source blocks, or just ``"\\n\\n"`` for an empty list.

    Raises:
        KeyError: If a document's metadata has no ``"title"`` entry.
    """
    separator = "\n\n"
    blocks = []
    for source_id, document in enumerate(docs):
        title = document.metadata["title"]
        snippet = document.page_content
        blocks.append(
            f"Source ID: {source_id}\nArticle Title: {title}\nArticle Snippet: {snippet}"
        )
    return separator + separator.join(blocks)


# Attach the CitedAnswer schema to the model as a tool and name it as the
# tool choice.
# NOTE(review): presumably naming the tool here forces the model to call it on
# every request — confirm against the langchain_openai documentation.
llm_with_tool = llm.bind_tools([CitedAnswer], tool_choice="CitedAnswer")

# Parser keyed on the "CitedAnswer" tool call, configured with
# first_tool_only=True.
# NOTE(review): presumably yields the first matching call's arguments as a
# dict with "generation" and "citations" keys — confirm.
output_parser = JsonOutputKeyToolsParser(key_name="CitedAnswer", first_tool_only=True)
# Full chain: prompt -> tool-bound model -> tool-call parser. The prompt
# template expects "query" and "context" inputs.
answer_citations = prompt | llm_with_tool | output_parser