File size: 3,911 Bytes
bb32635
7564bac
bb32635
 
004db8a
bb32635
004db8a
7564bac
bb32635
004db8a
bb32635
004db8a
 
 
 
bb32635
004db8a
bb32635
 
004db8a
 
 
 
 
 
 
 
 
bb32635
 
004db8a
bb32635
004db8a
 
 
 
bb32635
004db8a
 
bb32635
004db8a
bb32635
 
 
 
 
004db8a
bb32635
 
 
7564bac
 
 
 
bb32635
 
 
004db8a
 
 
bb32635
004db8a
bb32635
 
 
 
 
 
 
 
 
 
 
 
 
 
364c11a
 
 
bb32635
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
004db8a
bb32635
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import asyncio
import openai

from langchain.chat_models import ChatOpenAI
from langchain.chat_models.openai import _create_retry_decorator
from langchain.prompts.chat import ChatPromptTemplate
# from langchain.schema import BaseOutputParser
from application import *

from utility import read_pdf,aterminal_print

# Module-level LLM client shared by every chain built in this file.
# temperature=0.0 for deterministic extraction; 16k-context model because
# whole article sections are inlined into the prompt.
llm = ChatOpenAI(
    temperature=0.0,
    model_name="gpt-3.5-turbo-16k",
    openai_api_key=openai.api_key)

# Tenacity retry decorator preconfigured for this llm's error types
# (rate limits, timeouts). NOTE(review): _create_retry_decorator is a
# private langchain helper — may break on upgrade; confirm before pinning.
retry_decorator = _create_retry_decorator(llm)


# class Replacement(BaseOutputParser):
#     """Parse the output of an LLM call to a comma-separated list."""
#     def parse(self, text: str, **kwargs):
#         """Parse the output of an LLM call."""
#         if kwargs:
#             print(kwargs)
#         return text.strip().split(", ")

@aterminal_print  # need to review this.
async def async_generate(article, name, chain, replacement_term=None):
    """Invoke *chain* asynchronously and store the reply text in article[name].

    The chain is called with a single "term" input: *replacement_term* when
    it is truthy, otherwise an empty string. Mutates *article* in place.
    """
    term = replacement_term if replacement_term else ""
    res = await chain.ainvoke({"term": term})
    print("completed", name)
    article[name] = res.content

@aterminal_print # need to review this.
# NOTE(review): retry_decorator comes from langchain's LLM-call helper;
# wrapping the whole orchestrator means a late failure re-runs every prompt.
# Confirm that is intended rather than retrying per chain invocation.
@retry_decorator
async def execute_concurrent(article, prompts):
    """Build one LLM chain per prompt in *prompts* and run them all concurrently.

    For each prompt, any inputs listed in its "input_list" that are missing
    from *article* are first produced by recursing on the corresponding entry
    of app_data["prompts"] (awaited, i.e. resolved before the dependent chain
    is scheduled). Results are written into *article* by async_generate.

    Parameters:
        article: mutable mapping holding the article text plus generated
            sections; article["logic"] selects which prompt variant to use.
        prompts: mapping of section name -> prompt spec dict with keys
            "input_list", "assessment_step", the prompt text under the
            article's logic key, and optional "reformat_inst" / "term".
    """
    tasks = []

    prompt_type = article["logic"]

    # Iterate the mapping directly: the original front-popped a key list
    # (O(n) per pop) and tracked an unused index `i`; nothing is ever added
    # to the worklist, so a plain loop is equivalent.
    for name, p in prompts.items():
        # Resolve missing dependencies before building this prompt's chain.
        missing_inputs = [s for s in p["input_list"] if s not in article]
        for x in missing_inputs:
            await execute_concurrent(article, {x: app_data["prompts"][x]})

        print("executing", p["assessment_step"], name)
        input_text = "".join(article[s] for s in p["input_list"])

        # NOTE(review): "clinical trail" looks like a typo for "clinical trial";
        # left byte-identical because changing the prompt changes model output.
        chat_prompt = ChatPromptTemplate.from_messages([
            ("system","You are a helpful AI that can answer questions about clinical trail and operation studies."),
            ("human", input_text),
            ("system", p[prompt_type]),
        ])

        if "reformat_inst" in p:
            chat_prompt.append(
                ("system", p["reformat_inst"])
            )

        # Fresh (currently empty) mapping per chain; bound as a default so the
        # lambda does not late-bind to a later iteration's dict.
        post_prompt_maping = {}
        post_replace_term = lambda res, map=post_prompt_maping: replace_term(res, map=map)

        chain = chat_prompt | llm | post_replace_term
        if "term" in p:
            # Only the first declared term's prompt is used.
            first_term = next(iter(p["term"]))
            replacement_term = p["term"][first_term]["term_prompt"]
            tasks.append(async_generate(article,name,chain,replacement_term=replacement_term)) # in here the name shall be the term_prompt from the terms triggered
        else:
            tasks.append(async_generate(article,name,chain)) # in here the name shall be the term_prompt from the terms triggered

    await asyncio.gather(*tasks)

def replace_term(res, **kwargs):
    """Apply the substitutions in kwargs["map"] to res.content, in place.

    When a "map" keyword is supplied, each (old, new) pair is applied with
    str.replace, reassigning res.content after every substitution. Returns
    *res* so the function can sit at the tail of a chain.
    """
    if "map" in kwargs:
        for old, new in kwargs["map"].items():
            res.content = res.content.replace(old, new)
    return res

if __name__ == "__main__":
    # Ad-hoc experiment: run one prompt against a single sample PDF.
    # lets try the Blood Loss, Operation Time, and Need for ICU in other folder
    sample_artice = ".samples/Ha SK, 2008.pdf"
    sample_content,_ = read_pdf(sample_artice)

    # Rebinds the module-level llm (same settings; api key picked up from env).
    llm = ChatOpenAI(temperature=0.0,model_name="gpt-3.5-turbo-16k")
    with open(".prompts/other/Operation Time.txt") as f:
        prompt = f.read()
        name = "Operation Time"
    
    post_prompt_maping = {}
    post_replace_term = lambda res,map=post_prompt_maping:replace_term(res,map=map)

    # NOTE(review): the human message is sample_artice — the file *path* —
    # while sample_content (the extracted text) is never used. Looks like a
    # bug; confirm whether the PDF text was meant to go here.
    chain_prompt = ChatPromptTemplate.from_messages([
        ("human",sample_artice),
        ("system",prompt),
    ])

    # experiment with cascading the chain
    chain = chain_prompt | llm 
    chain2 = chain | post_replace_term

    # lets try remove from chain
    # NOTE(review): with_retry is a Runnable *method*; assigning True here
    # shadows it on the instance and almost certainly does not enable
    # retries — the intended call is chain2.last.with_retry(...). Verify.
    chain2.last.with_retry = True
    res = chain2.invoke({"term":name})
    print(res.content)