from citekit.cite_modules.LLM import LLM
from citekit.cite_modules.augment_model import (
    Retriever,
    CitationSimplyfier,
    Verifier,
    Ranker,
)
from citekit.pipeline.pipeline import Pipeline, PIPELINE_OUTPUT, PIPELINE_DOC_CACHE
from citekit.prompt.prompt import Prompt, ALCEDocPrompt, DocPrompt, NewALCEVanillaPrompt
from citekit.Dataset.Dataset import PromptDataset
from citekit.evaluator.evaluator import (
    DefaultEvaluator,
    compute_autoais,
    test_compute_autoais,
)
from citekit.utils.utils import (
    sentence,
    one_paragraph,
    each_make_as,
    make_as,
    remove_citations,
    compute_str_em,
)
import json
import argparse
from parser import *


def segment(i, text):
    return [make_as("docs")(doc) for doc in text.split("\n") if doc]


def segment_query(text):
    return [make_as("query")(doc) for doc in text.split("\n") if doc]
if __name__ == "__main__":

    # SETTING ARGS
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save_path", type=str, default="resu.json", help="path to save the results"
    )
    parser.add_argument(
        "--model", type=str, default="gpt-3.5-turbo", help="model name or path"
    )
    parser.add_argument("--shots", type=int, default=2, help="number of shots")
    parser.add_argument("--ndoc", type=int, default=5, help="number of docs")
    parser.add_argument("--pr", action="store_true", help="use cite PR")
    parser.add_argument("--rouge", action="store_true", help="use rouge")
    parser.add_argument("--temp", type=float, default=0.5, help="temperature")
    parser.add_argument("--qa", action="store_true", help="eval qa")
    parser.add_argument("--mauve", action="store_true", help="eval mauve")
    parser.add_argument("--length", type=bool, default=True, help="eval length")
    parser.add_argument("--claims", action="store_true", help="eval claims")
    parser.add_argument("--qampari", action="store_true", help="eval qampari")
    parser.add_argument(
        "--dataset", type=str, default="data/asqa_eval_gtr_top100.json", help="dataset"
    )
    parser.add_argument(
        "--demo", type=str, default="prompts/asqa_default.json", help="demo file"
    )
    parser.add_argument("--doctype", type=str, default="text", help="document type")
    parser.add_argument("--data_num", type=int, default=1000, help="num of data")
    parser.add_argument(
        "--mode",
        type=str,
        default="text",
        help="mode-granularity: text, extraction or summary",
    )
    parser.add_argument("--k", type=float, default=1.5, help="coefficient of em")
    parser.add_argument("--topk", type=int, default=2, help="topk")
    args = parser.parse_args()
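
    # Example invocation (illustrative only; the script name is an assumption and
    # the paths simply echo the defaults above):
    #   python run_asqa_pipeline.py --model gpt-3.5-turbo --ndoc 5 --topk 2 \
    #       --dataset data/asqa_eval_gtr_top100.json --demo prompts/asqa_default.json \
    #       --pr --rouge --qa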
    def score(data):
        """Ranking criterion: citation precision + citation recall + k * string EM."""
        pr = compute_autoais(data)
        p = pr["citation_prec"]
        r = pr["citation_rec"]
        em = compute_str_em(data)
        return p + r + args.k * em
    # DATA LOADING
    file_path = args.dataset
    demo_path = args.demo

    with open(file_path, "r", encoding="utf-8") as file:
        dataset = json.load(file)
    with open(demo_path, "r", encoding="utf-8") as file:
        demo = json.load(file)

    data_num = min(args.data_num, len(dataset))
    llm_instruction = demo["one_sentence_instruction"]
    query_inst = demo["query_instruction"]

    # Build the few-shot demonstration block from two of the demo examples, keeping
    # only the first sentence of each answer with its citations removed.
    shots = "\n\n".join(
        NewALCEVanillaPrompt().load_data(
            [demo["demos"][1], demo["demos"][3]],
            "question",
            answer=lambda data: remove_citations(
                sentence("first")(data["answer"])["first"]
            ),
            INST=lambda _: llm_instruction,
            docs=lambda data: "".join(
                ALCEDocPrompt().default_load_data(data["docs"][1:2])
            ),
        )
    )
    documents = [
        DocPrompt().load_data(
            list(enumerate(data["docs"])),
            Title=lambda data: data[1]["title"],
            Passage=lambda data: data[1][args.mode],
        )
        for data in dataset
    ]

    dataset = PromptDataset(
        dataset,
        "question",
        "answer",
        "answers",
        "qa_pairs",
        "claims",
        docs=lambda data: ALCEDocPrompt().default_load_data(data["docs"][: args.ndoc]),
    )[:data_num]
    prompt = Prompt(
        template="<shots><INST><question><ans><docs>\nAnswer:",
        components={
            "INST": "{INST}\n\n",
            "shots": "{shots}\n",
            "question": "Question:{question}\n\n",
            "ans": "Prefix:{ans}\n\n",
            "docs": "{docs}\n",
        },
    )

    queryprompt = Prompt(
        template="<INST><question><prev><ans>Please generate at most three queries regarding possible subquestions of the given question. Your queries should be diverse and informative in natural language, split by new lines.\n",
        components={
            "question": "Given the original question: {question}\n",
            "ans": "The context is: {ans}\n",
            "prev": "\nPrevious queries:\n{prev}\n\n",
            "INST": "{INST}\n\n",
        },
    )
    retriever_prompt = Prompt(template="<query>", components={"query": "{query}"})

    query_generator = LLM(
        model=args.model, prompt_maker=queryprompt, self_prompt={"INST": query_inst}
    )

    eval = DefaultEvaluator(args)

    ranker = Ranker(max_turn=6, iterative=True, fixed_turn=2)
    # ranker.set_eval('length', output = 'answer')
    # ranker.new_eval('score', score , output = 'answer', docs = 'doc_cache', qa_pairs = 'qa_pairs')
    ranker.new_eval("score", score, output="answer", docs="doc_cache")
    # PIPELINE CONSTRUCTING
    llm = LLM(
        model=args.model,
        prompt_maker=prompt,
        self_prompt={"INST": llm_instruction, "shots": shots},
        max_turn=30,
        auto_cite=True,
        share_model_with=query_generator,
        parallel=True,
    )

    pipeline = Pipeline(
        save_path=args.save_path,
        llm=llm,
        module=[ranker, query_generator],
        head_prompt_maker=prompt,
        evaluator=eval,
        dataset=dataset,
    )
    retriever = Retriever(
        prompt_maker=retriever_prompt,
        pipeline=pipeline,
        retrieve_by="bm25",
        documents=documents,
        topk=args.topk,
    )

    # Wire the modules: query generation -> retrieval -> generation -> ranking
    query_generator.set_target(retriever, post_processing=segment_query)
    query_generator.add_to_head("prev", sub=False)
    retriever.set_target(llm, post_processing=segment)
    llm.set_target(ranker, post_processing=make_as("answer"))
    # ranker.set_output(post_processing=lambda x: x["answer"], end=False)
    ranker.add_to_head(
        "ans", sub=True, process=lambda text: one_paragraph(text["answer"])
    )
    # ranker.set_target(query_generator, post_processing=lambda x: {"ans": x["answer"]})

    pipeline.set_initial_module(query_generator)
    pipeline.set_data_keys(["question"])
    simplifier = CitationSimplyfier()
    ranker.set_target(simplifier)
    simplifier.set_output(end=False)

    # graph = PipelineGraph(pipeline=pipeline)
    # html = graph.generate_html_embed(results='old/res_attr.json')
    # graph.visualize()
    # print(html)
    # with open('pipeline_.html','w') as file:
    #     file.write(html)

    # RUNNING PIPELINE
    pipeline.run_on_dataset(datakeys=["question"], initial_module=query_generator)