# chains.py — summary generation chains
# Author: Roland Ding
# 10.9.26.75 updated ui, data, features, and backend (commit 004db8a)
import asyncio
import openai
from langchain.chat_models import ChatOpenAI
from langchain.chat_models.openai import _create_retry_decorator
from langchain.prompts.chat import ChatPromptTemplate
# from langchain.schema import BaseOutputParser
from application import *
from utility import read_pdf,aterminal_print
# Shared chat model used by every chain in this module.
# temperature=0.0 for deterministic extraction output; the key is read from
# the already-imported openai module — NOTE(review): confirm openai.api_key
# is configured upstream (e.g. in `application`) before this module loads.
llm = ChatOpenAI(
    temperature=0.0,
    model_name="gpt-3.5-turbo-16k",
    openai_api_key=openai.api_key)
# Retry decorator for transient OpenAI errors, built from langchain's
# _create_retry_decorator. NOTE(review): this is a private langchain helper
# and may break across langchain versions — consider Runnable.with_retry().
retry_decorator = _create_retry_decorator(llm)
# class Replacement(BaseOutputParser):
# """Parse the output of an LLM call to a comma-separated list."""
# def parse(self, text: str, **kwargs):
# """Parse the output of an LLM call."""
# if kwargs:
# print(kwargs)
# return text.strip().split(", ")
@aterminal_print  # need to review this.
async def async_generate(article, name, chain, replacement_term=None):
    """Invoke *chain* asynchronously and store the reply text in article[name].

    The chain receives a single "term" input: replacement_term when truthy,
    otherwise an empty string. Mutates *article* in place; returns None.
    """
    term = replacement_term if replacement_term else ""
    res = await chain.ainvoke({"term": term})
    print("completed", name)
    article[name] = res.content
@aterminal_print  # need to review this.
@retry_decorator
async def execute_concurrent(article, prompts):
    """Build one LLM chain per prompt spec and run them all concurrently.

    article: mutable dict holding the article sections plus generated results,
        keyed by prompt name; article["logic"] selects which prompt variant
        (key inside each spec) to use.
    prompts: mapping of prompt name -> spec dict with keys "input_list",
        "assessment_step", the variant keyed by article["logic"], and
        optionally "reformat_inst" and "term".

    Any input a prompt needs that is not yet present in *article* is produced
    first by a sequential recursive call; the prompts at this level are then
    executed concurrently via asyncio.gather. Results land in article[name].
    """
    tasks = []
    prompt_type = article["logic"]
    # Iterate the mapping directly instead of popping from a list copy
    # (the old while/pop(0) loop was O(n^2) and left an unused counter).
    for name, p in prompts.items():
        # Resolve dependencies: generate any missing inputs before this prompt.
        missing_inputs = [s for s in p["input_list"] if s not in article]
        for x in missing_inputs:
            await execute_concurrent(article, {x: app_data["prompts"][x]})
        print("executing", p["assessment_step"], name)
        input_text = "".join(article[s] for s in p["input_list"])
        chat_prompt = ChatPromptTemplate.from_messages([
            # Fixed typo: "clinical trail" -> "clinical trial".
            ("system", "You are a helpful AI that can answer questions about clinical trial and operation studies."),
            ("human", input_text),
            ("system", p[prompt_type]),
        ])
        if "reformat_inst" in p:
            chat_prompt.append(("system", p["reformat_inst"]))
        # NOTE(review): this map is never populated here, so the post-step is
        # currently a no-op. Bound as a default argument so each chain keeps
        # its own dict rather than sharing a late-bound closure variable.
        post_prompt_maping = {}
        post_replace_term = lambda res, map=post_prompt_maping: replace_term(res, map=map)
        chain = chat_prompt | llm | post_replace_term
        if "term" in p:
            # Use the first configured term's prompt as the replacement term.
            first_term = next(iter(p["term"]))
            replacement_term = p["term"][first_term]["term_prompt"]
            tasks.append(async_generate(article, name, chain, replacement_term=replacement_term))
        else:
            tasks.append(async_generate(article, name, chain))
    await asyncio.gather(*tasks)
def replace_term(res, **kwargs):
    """Apply the substitutions in kwargs["map"] to res.content.

    Each key in the mapping is replaced with its value, mutating res.content
    in place. If no "map" keyword is supplied the response passes through
    untouched. Returns *res* in either case.
    """
    if "map" in kwargs:
        for old, new in kwargs["map"].items():
            res.content = res.content.replace(old, new)
    return res
if __name__ == "__main__":
    # lets try the Blood Loss, Operation Time, and Need for ICU in other folder
    sample_article = ".samples/Ha SK, 2008.pdf"
    sample_content, _ = read_pdf(sample_article)
    llm = ChatOpenAI(temperature=0.0, model_name="gpt-3.5-turbo-16k")
    with open(".prompts/other/Operation Time.txt") as f:
        prompt = f.read()
    name = "Operation Time"
    post_prompt_maping = {}
    post_replace_term = lambda res, map=post_prompt_maping: replace_term(res, map=map)
    chain_prompt = ChatPromptTemplate.from_messages([
        # BUG FIX: pass the extracted PDF text as the human message; the
        # original passed the file *path*, so the model never saw the article.
        ("human", sample_content),
        ("system", prompt),
    ])
    # experiment with cascading the chain
    chain = chain_prompt | llm
    chain2 = chain | post_replace_term
    # BUG FIX: with_retry is a method that returns a new retrying runnable;
    # `chain2.last.with_retry = True` only clobbered that method with a bool
    # and enabled nothing.
    chain2 = chain2.with_retry()
    res = chain2.invoke({"term": name})
    print(res.content)