"""RAG chatbot: index BeaverTails prompts in Haystack and serve them via Chainlit."""

from pydantic import BaseModel

# HACK: Haystack 1.x node validation fails under pydantic without this
# global override — must run BEFORE any haystack import. Confirm it is
# still needed for the installed haystack/pydantic versions.
BaseModel.Config.arbitrary_types_allowed = True

import os

import chainlit as cl
from datasets import load_dataset
from dotenv import load_dotenv
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import AnswerParser, BM25Retriever, PromptNode, PromptTemplate
from haystack.pipelines import Pipeline
from haystack.utils import print_answers

# Fix: load_dotenv was imported but never invoked, so credentials in a
# local .env file (e.g. a Hugging Face API token for the PromptNode)
# were never loaded into os.environ.
load_dotenv()
| | |
from haystack import Document

# Load the BeaverTails training split and stand up an in-memory store
# with BM25 keyword search enabled.
dataset = load_dataset("PKU-Alignment/BeaverTails", split="330k_train")
document_store = InMemoryDocumentStore(use_bm25=True)

# Keep only prompts flagged under the "child_abuse" harm category.
# NOTE(review): assumes each entry has a 'category' dict of boolean
# flags keyed by harm type — confirm against the dataset card.
child_abuse = [entry['prompt'] for entry in dataset if entry['category']['child_abuse']]

# Log the selected prompts once. (The original script printed the same
# list twice via two identical loops — collapsed to a single pass.)
for prompt in child_abuse:
    print(prompt)

# Kept for backward compatibility with any external consumer; the
# original built this with a redundant element-by-element append loop.
prompts_list = list(child_abuse)

# Wrap each prompt as a Haystack Document and index it for retrieval.
documents = [Document(content=prompt) for prompt in prompts_list]
document_store.write_documents(documents)
| |
|
| | |
# --- Retrieval-augmented generation pipeline ---

# BM25 keyword retriever: returns the 3 best-matching indexed prompts.
retriever = BM25Retriever(document_store=document_store, top_k=3)

_RAG_PROMPT = """
Answer the question truthfully based solely on the given documents. If the documents do not contain the answer,
say that answering is not possible given the available information.
Documents: {join(documents)}
Question: {query}
Answer:
"""

# Template that restricts the LLM to the retrieved documents;
# AnswerParser converts the raw completion into Haystack Answer objects.
prompt_template = PromptTemplate(prompt=_RAG_PROMPT, output_parser=AnswerParser())

# Mistral-7B-Instruct generator using the template above by default.
prompt_node = PromptNode(
    model_name_or_path="mistralai/Mistral-7B-Instruct-v0.2",
    default_prompt_template=prompt_template,
)

# Wire the graph: Query -> retriever -> prompt_node.
generative_pipeline = Pipeline()
generative_pipeline.add_node(component=retriever, name="retriever", inputs=["Query"])
generative_pipeline.add_node(component=prompt_node, name="prompt_node", inputs=["retriever"])
| |
|
@cl.on_message
async def main(message: str):
    """Chainlit handler: answer the incoming message via the RAG pipeline.

    Bug fix: the original wrote ``await cl.make_async(generative_pipeline.run(message))``,
    which executed the blocking pipeline synchronously on the event loop and
    then awaited its *dict* result. ``make_async`` must wrap the callable,
    and the wrapped call is what gets awaited.
    """
    # NOTE(review): recent Chainlit versions pass a cl.Message object here,
    # not a str — if so, the query should be message.content; confirm the
    # installed Chainlit version before changing the annotation.
    response = await cl.make_async(generative_pipeline.run)(query=message)

    sentences = response["answers"][0].answer.split("\n")

    # Drop a trailing fragment the generator cut off mid-sentence.
    if sentences and not sentences[-1].strip().endswith(('.', '?', '!')):
        sentences.pop()

    # NOTE(review): sentences[1:] discards the first line of the answer —
    # presumably an echoed blank/header line; verify this is intentional.
    result = "\n".join(sentences[1:])
    await cl.Message(author="Bot", content=result).send()