# studies/chains.py
# Author: Roland Ding
# Latest content updated at commit: 498a219
import asyncio
import openai
from langchain.chat_models import ChatOpenAI
from langchain.chat_models.openai import _create_retry_decorator
# from langchain.llms.base import create_base_retry_decorator
from langchain.prompts.chat import ChatPromptTemplate
from langchain.schema import BaseOutputParser
from application import *
from utility import read_pdf,aterminal_print
# Shared chat model used by every chain built in this module.
llm = ChatOpenAI(
    temperature=0.0,  # deterministic output for extraction-style prompts
    model_name="gpt-3.5-turbo-16k",
    openai_api_key=openai.api_key,  # assumes openai.api_key was set by an earlier import — TODO confirm
    max_retries=2,
    timeout=30 # this one may help me to solve all the old problems...
)
# retry_decorator = _create_retry_decorator(llm)
class Replacement(BaseOutputParser):
    """Output parser that splits an LLM response into a comma-separated list."""

    def parse(self, text: str, **kwargs):
        """Trim *text* and split it on ", "; echo any extra kwargs for debugging."""
        if kwargs:
            print(kwargs)
        cleaned = text.strip()
        return cleaned.split(", ")
# @retry_decorator
@aterminal_print
async def async_generate(article, name, chain, input_variables=None):
    """Run *chain* asynchronously and store its output text in article[name].

    Any exception is caught and recorded as an "API Error: ..." string under
    the same key, so one failed prompt does not abort a surrounding
    asyncio.gather() of sibling tasks.

    Args:
        article: dict mutated in place; the result is stored under *name*.
        name: key under which to record the chain output.
        chain: a runnable exposing an async .ainvoke() method.
        input_variables: mapping passed to chain.ainvoke(); defaults to {}.
            (None sentinel avoids the shared-mutable-default pitfall.)
    """
    if input_variables is None:
        input_variables = {}
    try:
        res = await chain.ainvoke(input_variables)
    except Exception as e:
        # Best-effort: record the failure instead of raising.
        print("API Error", str(e))
        article[name] = f"API Error: {str(e)}"
        return
    print("completed", name)
    article[name] = res.content
@aterminal_print
async def execute_concurrent(article, prompts):
    """Resolve each prompt's missing inputs recursively, then run all prompts concurrently.

    *article* is the shared results dict; *prompts* maps prompt names to prompt
    definitions. Prerequisite prompts (inputs not yet in *article*) are executed
    first via recursion, then every prompt in *prompts* is gathered at once.
    """
    pending = []
    for name, p in prompts.items():
        # recurse on the missing tables until all input conditions are met
        for missing in [s for s in p["inputs"] if s not in article]:
            await execute_concurrent(article, {missing: app_data["prompts"][missing]})
        print("executing", p["assessment"], name)
        pending.append(gen_task(p, article))
    await asyncio.gather(*pending)
def replace_term(res, **kwargs):
    """Apply the optional kwargs["map"] substitutions to res.content in place.

    kwargs["map"], when present, is an {old: new} mapping applied with
    str.replace. *res* is mutated and then returned so the function can sit
    at the end of a runnable chain.
    """
    if "map" in kwargs:
        substitutions = kwargs["map"]
        for old, new in substitutions.items():
            res.content = res.content.replace(old, new)
    return res
# Shared substitution table applied to every chain's output; expected to be
# populated elsewhere before chains run. (NOTE(review): "maping" — typo for
# "mapping"; renaming would require updating all references.)
post_prompt_maping = {}
# Bind the mapping as a default argument so the runnable takes a single input.
# NOTE(review): the default binds the dict *object* at definition time, so
# later rebinding post_prompt_maping (as done in __main__) does not affect
# this lambda — only in-place mutation of the original dict would.
post_replace_term = lambda res,map=post_prompt_maping:replace_term(res,map=map)
def gen_task(prompt, article):
    """Build the prompt → llm → post-processing pipeline and return its coroutine.

    The returned awaitable (from async_generate) stores its result in
    article[prompt["name"]] when executed.
    """
    template = gen_chat_prompt(prompt, article)
    pipeline = template | llm | post_replace_term
    variables = gen_input_variables(template, prompt)
    return async_generate(
        article=article,
        name=prompt["name"],
        chain=pipeline,
        input_variables=variables,
    )
def gen_chat_prompt(prompt, article):
    """Assemble a ChatPromptTemplate from the article's inputs and the prompt chain.

    The message list starts with a fixed system instruction plus the
    concatenation of the article sections named in prompt["inputs"]. For
    multi-step prompts the steps selected by article["logic"]["chain id"] are
    appended; single-step prompts append their sole chain entry.
    """
    input_text = "".join([article[s] for s in prompt["inputs"]])
    messages = [
        # Bug fix: "clinical trail" -> "clinical trial" in the system prompt.
        ("system", "You are a helpful AI that can answer questions about clinical trial and operation studies."),
        ("human", input_text),
    ]
    if len(prompt["chain"]) > 1:
        # Multi-step prompt: pick the steps named by the article's logic table.
        for i in article["logic"]["chain id"]:
            messages.append(("human", prompt["chain"][i]))
    else:
        messages.append(("human", prompt["chain"][0]))
    # messages.append(("system",reformat_inst))
    return ChatPromptTemplate.from_messages(messages=messages)
def gen_input_variables(chat_prompt, prompt):
    """Collect the template variables that *chat_prompt* expects.

    Side effect: when *prompt* defines "term", the shared
    app_data["current"]["term"] is refreshed first so that the "term"
    variable (if requested by the template) reflects this prompt.
    """
    if "term" in prompt:
        app_data["current"]["term"] = prompt["term"][0]["prompting_term"]
    variables = {}
    needed = chat_prompt.input_variables
    if "n_col" in needed:
        variables["n_col"] = len(prompt["chain"])
    if "term" in needed:
        variables["term"] = app_data["current"]["term"]
    return variables
if __name__ == "__main__":
    # lets try the Blood Loss, Operation Time, and Need for ICU in other folder
    # NOTE(review): "sample_artice" — typo for "sample_article" (local only).
    sample_artice = ".samples/Ha SK, 2008.pdf"
    # sample_content is read but never used below — presumably a leftover; verify.
    sample_content,_ = read_pdf(sample_artice)
    with open(".prompts/other/Operation Time.txt") as f:
        prompt = f.read()
    name = "Operation Time"
    # Rebinding these module-level names here does NOT change the
    # post_replace_term defined above (its default bound the earlier dict);
    # this shadows them with fresh objects for the experiment below.
    post_prompt_maping = {}
    post_replace_term = lambda res,map=post_prompt_maping:replace_term(res,map=map)
    chat_prompt = ChatPromptTemplate.from_messages([
        ("human",sample_artice),
        ("system",prompt),
    ])
    # experiment with cascading the chain
    chain = chat_prompt | llm
    chain2 = chain | post_replace_term
    # lets try remove from chain
    # NOTE(review): with_retry is a Runnable *method* in langchain; assigning
    # True here shadows it on the last chain element and almost certainly does
    # not enable retries — confirm intent (chain2.with_retry() was likely meant).
    chain2.last.with_retry = True
    res = chain2.invoke({"term":name})
    print(res.content)