# studies / instruction.py
# Roland Ding
# miscellaneous update
# d1682ed
from article import Assessment, Region, Section
from pydantic import BaseModel, Field
from typing import Optional, List, Dict
from enum import Enum
from langchain.prompts.chat import ChatPromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.schema.runnable.base import RunnableSequence
# Lookup from lower-case text labels to the corresponding members of the
# Assessment, Region and Section types imported from `article`.
value_map = dict(
    [
        # Assessment categories.
        ("overview", Assessment.overview),
        ("clinical", Assessment.clinical),
        ("radiologic", Assessment.radiologic),
        ("safety", Assessment.safety),
        ("other", Assessment.other),
        # Regions.
        ("spine", Region.spine),
        ("extremity", Region.extremity),
        ("all", Region.all),
        # Article sections. The label "material and methods" maps to
        # Section.methods.
        ("abstract", Section.abstract),
        ("introduction", Section.introduction),
        ("material and methods", Section.methods),
        ("results", Section.results),
        ("discussion", Section.discussion),
        ("conclusion", Section.conclusion),
        ("references", Section.references),
    ]
)
class Parser(BaseModel):
    """Substitute a placeholder pattern in text with a replacement term.

    Fields:
        term: pattern searched for in the content (defaults to "{term}").
        region: region associated with this parser.
        assessment: assessment associated with this parser, if any.
        replacement: text substituted for every occurrence of ``term``.
    """

    term: str = Field("{term}", description="the pattern to be replaced with the output_term.")
    # NOTE(review): "cervical" is not among the Region members referenced in
    # value_map (spine, extremity, all); confirm Region actually defines it.
    region: Region = Field("cervical", alias="region")
    assessment: Optional[Assessment] = Field(None, alias="assessment")
    replacement: Optional[str] = Field(None, description="the term from last the input to be replaced with.")

    def parse(self, content):
        """Return ``content`` with every occurrence of ``term`` replaced.

        Args:
            content: text to transform.

        Returns:
            A new string with ``term`` replaced by ``replacement``. If no
            replacement has been set (``None``), ``content`` is returned
            unchanged.
        """
        if self.replacement is None:
            # str.replace rejects None; without a replacement there is nothing
            # to substitute.
            return content
        # str is immutable: replace() returns a new string, so that result is
        # what must be returned.
        return content.replace(self.term, self.replacement)
# class Path(BaseModel): # Maybe too early to generalize this; let's walk through a concrete instruction classifier first.
# name: str
# inputs: List[str] | str = Field([""], alias="inputs")
# variables: Dict[str,str] = Field({"term":""}, alias="variables")
# assessment: Assessment = Field(None, description="The clinical trial assessment steps")
# chain: RunnableSequence = Field([""], description="The nodes in the path to be executed.")
# def run(self,article):
# content = " ".join([article[s] for s in self.inputs])
# self.variables.update(content=content)
# article[self.name] = self.chain.invoke(self.variables)
# async def arun(self,article):
# pass
class ChainClassifier(BaseModel):
    """Match articles against search terms and run a chain on them.

    Fields:
        terms: substrings that must all appear in the selected section text.
        region: region an article must belong to; ``Region.all`` accepts any.
        sections: keys used to pull section text out of an article.
        path: the automation path to be executed (required).
        chain: runnable invoked by ``run``.
    """

    terms: List[str] = Field([""], alias="terms")
    region: Region = Field(Region.spine, alias="region")
    # NOTE(review): annotated as a single Section but iterated in classify();
    # presumably a collection of Section keys -- confirm against callers.
    sections: Section = Field(None, alias="sections")
    path: List[object] = Field(..., description="the automation path to be executed.")
    chain: RunnableSequence = Field(None, description="the nodes in the path to be executed.")

    def classify(self, article):
        """Return this classifier's instruction if ``article`` matches, else None.

        An article matches when its region passes ``validate_region`` and every
        entry of ``terms`` occurs in the text of the configured ``sections``.

        Args:
            article: object indexable by section key (``article[s]``) that also
                exposes a ``region`` attribute.
        """
        # Reject on the cheap region check before concatenating section text.
        if not self.validate_region(article):
            return None
        # Join with a space so a term cannot match across the boundary between
        # two adjacent sections. A missing `sections` means no text.
        content = " ".join(article[s] for s in (self.sections or []))
        if all(t in content for t in self.terms):
            # TODO(review): no `instruction` field is declared on this model, so
            # this access fails unless a subclass declares it.
            return self.instruction
        return None

    def validate_region(self, article):
        """Return True if ``article.region`` is acceptable for this classifier.

        ``Region.all`` accepts every article without reading ``article.region``.
        """
        return self.region == Region.all or self.region == article.region

    def parse_obj_mapped(self, obj, key_map):
        """Rename keys of ``obj`` per ``key_map`` and parse the result.

        Args:
            obj: mapping of raw field values. It is copied, not modified.
            key_map: mapping of ``{source_key: field_key}``. Renames are applied
                in ``key_map`` order and only for keys present in ``obj``.

        Returns:
            A new model instance built by ``parse_obj`` from the renamed mapping.
        """
        mapped = dict(obj)  # work on a copy so the caller's dict is untouched
        for old_key, new_key in key_map.items():
            if old_key in mapped:
                mapped[new_key] = mapped.pop(old_key)
        # parse_obj constructs and returns a new instance; hand it to the caller.
        return self.parse_obj(mapped)

    def run(self, article):
        """Invoke ``chain`` with ``article`` and return the chain's output."""
        return self.chain.invoke(article)