File size: 4,229 Bytes
82b9d78
a6d0043
82b9d78
 
fdccd1a
8dfdf87
82b9d78
 
a6d0043
82b9d78
da2ce27
82b9d78
fdccd1a
 
 
8dfdf87
3a74227
8dfdf87
 
fdccd1a
3a74227
82b9d78
fdccd1a
 
82b9d78
 
 
 
 
 
5cef2da
8dfdf87
 
5cef2da
 
3a74227
 
 
5cef2da
 
fdccd1a
 
82b9d78
8dfdf87
 
82b9d78
 
 
 
 
 
 
 
a6d0043
8dfdf87
 
a6d0043
 
82b9d78
8dfdf87
 
 
82b9d78
 
 
 
 
 
 
 
8dfdf87
 
 
 
c769f48
 
 
 
 
 
 
8dfdf87
 
 
 
 
 
 
c769f48
 
 
 
 
498a219
c769f48
 
 
 
8dfdf87
 
 
 
 
 
 
c769f48
 
8dfdf87
82b9d78
 
 
 
 
 
 
 
fdccd1a
82b9d78
 
 
8dfdf87
82b9d78
 
 
 
 
8dfdf87
82b9d78
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import asyncio
import openai

from langchain.chat_models import ChatOpenAI
from langchain.chat_models.openai import _create_retry_decorator
# from langchain.llms.base import create_base_retry_decorator
from langchain.prompts.chat import ChatPromptTemplate
from langchain.schema import BaseOutputParser
from application import *

from utility import read_pdf,aterminal_print

# Shared chat-model client used by every chain in this module.
# NOTE(review): openai.api_key is presumably set as a side effect of
# `from application import *` — confirm, otherwise this is None.
llm = ChatOpenAI(
    temperature=0.0,  # deterministic completions for extraction tasks
    model_name="gpt-3.5-turbo-16k",
    openai_api_key=openai.api_key,
    max_retries=2,
    # Bug fix: this ChatOpenAI (langchain.chat_models) declares the field
    # `request_timeout`, not `timeout`. A bare `timeout=` kwarg is shunted
    # into model_kwargs and forwarded to the API, where it is not a valid
    # request parameter. 30s cap on each request.
    request_timeout=30,
)

# retry_decorator = _create_retry_decorator(llm)

class Replacement(BaseOutputParser):
    """Output parser that turns an LLM completion into a comma-separated list."""

    def parse(self, text: str, **kwargs):
        """Return *text* trimmed of surrounding whitespace, split on ", "."""
        if kwargs:
            print(kwargs)  # surface any unexpected keyword arguments
        cleaned = text.strip()
        return cleaned.split(", ")

@aterminal_print
async def async_generate(article, name, chain, input_variables=None):
    """Run *chain* asynchronously and store its output under article[name].

    On any exception the error text is stored in place of the result, so a
    single failed prompt does not abort the whole gather() batch.

    :param article: mutable mapping the result (or error string) is written to
    :param name: key in *article* to store the result under
    :param chain: runnable exposing an awaitable ``ainvoke``
    :param input_variables: template variables for the chain (default: empty)
    """
    # Bug fix: the original used a mutable default ``input_variables={}``,
    # which is shared across all calls; build a fresh dict per call instead.
    if input_variables is None:
        input_variables = {}

    try:
        res = await chain.ainvoke(input_variables)
    except Exception as e:
        # Deliberate best-effort: record the failure and keep going.
        print("API Error", str(e))
        article[name] = f"API Error: {str(e)}"
        return

    print("completed", name)
    article[name] = res.content


@aterminal_print
async def execute_concurrent(article, prompts):
    """Build one generation task per prompt and run them all concurrently.

    Before scheduling a prompt, any of its inputs missing from *article* are
    produced first by recursing on the corresponding prompt definitions.
    """
    pending = []

    for name in list(prompts):
        spec = prompts[name]

        # Recurse on the missing inputs until every precondition is met.
        missing = [s for s in spec["inputs"] if s not in article]
        for dep in missing:
            await execute_concurrent(article, {dep: app_data["prompts"][dep]})

        print("executing", spec["assessment"], name)
        pending.append(gen_task(spec, article))

    await asyncio.gather(*pending)

def replace_term(res, **kwargs):
    """Apply the substitutions in ``kwargs["map"]`` to ``res.content``.

    Mutates *res* in place and returns it so it can sit at the end of a
    runnable pipeline. With no ``map`` keyword, *res* passes through
    unchanged.
    """
    if "map" in kwargs:
        updated = res.content
        for old, new in kwargs["map"].items():
            updated = updated.replace(old, new)
        res.content = updated
    return res

# Global key -> replacement map applied to every chain result.
# Empty by default; populated elsewhere at runtime.
post_prompt_maping = {}


# PEP 8 (E731): a named callable should be a `def`, not a lambda assignment.
# The default argument binds the module-level map at definition time, exactly
# as the original lambda's `map=post_prompt_maping` default did.
def post_replace_term(res, map=post_prompt_maping):
    """Post-process a chain result by applying the module-level term map."""
    return replace_term(res, map=map)

def gen_task(prompt, article):
    """Assemble prompt -> llm -> post-processing into one awaitable task."""
    chat_prompt = gen_chat_prompt(prompt, article)
    pipeline = chat_prompt | llm | post_replace_term
    variables = gen_input_variables(chat_prompt, prompt)

    return async_generate(
        article=article,
        name=prompt["name"],
        chain=pipeline,
        input_variables=variables,
    )

def gen_chat_prompt(prompt, article):
    """Build the chat prompt: system role, article inputs, then chain steps.

    When the prompt defines multiple chain steps, article["logic"]["chain id"]
    selects which steps to append; otherwise the single step is used.
    """
    # Concatenate every declared input section of the article in order.
    input_text = "".join(article[s] for s in prompt["inputs"])

    messages = [
        # Fixed typo in the system prompt: "clinical trail" -> "clinical trial".
        ("system", "You are a helpful AI that can answer questions about clinical trial and operation studies."),
        # NOTE(review): article text containing literal "{" or "}" will be
        # parsed as template variables by ChatPromptTemplate — confirm the
        # inputs are brace-free or escape them upstream.
        ("human", input_text),
    ]

    if len(prompt["chain"]) > 1:
        for i in article["logic"]["chain id"]:
            messages.append(("human", prompt["chain"][i]))
    else:
        messages.append(("human", prompt["chain"][0]))

    return ChatPromptTemplate.from_messages(messages=messages)

def gen_input_variables(chat_prompt, prompt):
    """Collect values for the template variables *chat_prompt* declares.

    Side effect: when the prompt carries a "term" entry, its first
    prompting_term is recorded in app_data["current"]["term"] for later use.
    """
    if "term" in prompt:
        app_data["current"]["term"] = prompt["term"][0]["prompting_term"]

    declared = chat_prompt.input_variables
    variables = {}
    if "n_col" in declared:
        variables["n_col"] = len(prompt["chain"])
    if "term" in declared:
        variables["term"] = app_data["current"]["term"]
    return variables

if __name__ == "__main__":
    # Smoke test: run the "Operation Time" prompt against one sample PDF.
    # lets try the Blood Loss, Operation Time, and Need for ICU in other folder
    sample_artice = ".samples/Ha SK, 2008.pdf"
    sample_content, _ = read_pdf(sample_artice)

    with open(".prompts/other/Operation Time.txt") as f:
        prompt = f.read()
        name = "Operation Time"

    # Re-bind the module-level post-processing hook with a fresh (empty) map.
    post_prompt_maping = {}

    def post_replace_term(res, map=post_prompt_maping):
        # PEP 8 (E731): `def` instead of a lambda assignment.
        return replace_term(res, map=map)

    # NOTE(review): sample_content is read above but the PATH string is what
    # gets sent to the model here — confirm whether the article text was
    # meant instead.
    chat_prompt = ChatPromptTemplate.from_messages([
        ("human", sample_artice),
        ("system", prompt),
    ])

    # experiment with cascading the chain
    chain = chat_prompt | llm
    chain2 = chain | post_replace_term

    # Bug fix: the original `chain2.last.with_retry = True` merely shadowed
    # the Runnable.with_retry *method* with a bool — no retry behavior was
    # ever attached. with_retry() returns a new runnable that retries the
    # step on failure.
    chain2 = chain2.with_retry()
    res = chain2.invoke({"term": name})
    print(res.content)