from enum import Enum
from typing import Dict, List, Optional

from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import ChatPromptTemplate
from langchain.schema.runnable.base import RunnableSequence
from pydantic import BaseModel, Field

from article import Assessment, Region, Section

# Lookup from lowercase label strings to the enum members defined in the
# ``article`` module (Assessment, Region, Section).
# NOTE(review): keys presumably mirror labels/section headings produced when
# parsing articles; "material and methods" (vs. the more common "materials and
# methods") may be intentional — confirm against the article parser.
value_map = {
    "overview": Assessment.overview,
    "clinical": Assessment.clinical,
    "radiologic": Assessment.radiologic,
    "safety": Assessment.safety,
    "other": Assessment.other,
    "spine": Region.spine,
    "extremity": Region.extremity,
    "all": Region.all,
    "abstract": Section.abstract,
    "introduction": Section.introduction,
    "material and methods": Section.methods,
    "results": Section.results,
    "discussion": Section.discussion,
    "conclusion": Section.conclusion,
    "references": Section.references,
}


class Parser(BaseModel):
    """Replace a placeholder pattern in text with a configured replacement string."""

    # Pattern searched for in the content passed to ``parse``.
    term: str = Field("{term}", description="the pattern to be replaced with the output_term.")
    # NOTE(review): "cervical" is a plain-string default and is not one of the
    # Region members referenced in value_map; pydantic does not validate
    # defaults unless configured to — confirm Region actually defines it.
    region: Region = Field("cervical", alias="region")
    assessment: Optional[Assessment] = Field(None, alias="assessment")
    # Text substituted for ``term``; while it is None, ``parse`` is a no-op.
    replacement: Optional[str] = Field(None, description="the term from last the input to be replaced with.")

    def parse(self, content):
        """Return *content* with every occurrence of ``term`` replaced by ``replacement``.

        ``str.replace`` returns a new string, so the result must be returned
        (strings are immutable; the input is never modified in place).

        Args:
            content: The text to rewrite.

        Returns:
            The rewritten text. *content* is returned unchanged when
            ``replacement`` is None or ``term`` is empty.
        """
        if self.replacement is None:
            return content
        if not self.term:
            # str.replace("", x) would insert x between every character.
            return content
        return content.replace(self.term, self.replacement)


# class Path(BaseModel):
#     # Maybe too early to generalize this. Let's walk through a normal one for
#     # the instruction classifier first.
#     name: str
#     inputs: List[str] | str = Field([""], alias="inputs")
#     variables: Dict[str, str] = Field({"term": ""}, alias="variables")
#     assessment: Assessment = Field(None, description="The clinical trial assessment steps")
#     chain: RunnableSequence = Field([""], description="The nodes in the path to be executed.")
#
#     def run(self, article):
#         content = " ".join([article[s] for s in self.inputs])
#         self.variables.update(content=content)
#         article[self.name] = self.chain.invoke(self.variables)
#
#     async def arun(self, article):
#         pass


class ChainClassifier(BaseModel):
    """Match an article on region and search terms, and run a chain on it.

    ``classify`` checks the article's region and whether every entry of
    ``terms`` occurs in the text of the configured ``sections``; ``run``
    invokes ``chain`` with the article.
    """

    # Every entry must occur in the selected sections for classify() to match.
    # The default [""] matches any content ("" is a substring of every string).
    terms: List[str] = Field([""], alias="terms")
    # Region.all accepts every article; otherwise article.region must equal it.
    region: Region = Field(Region.spine, alias="region")
    # Keys used to index the article in classify().
    # NOTE(review): annotated as a single Section, yet classify() iterates it —
    # a list of sections (or an iterable Flag) appears intended; confirm.
    sections: Section = Field(None, alias="sections")
    path: List[object] = Field(..., description="the automation path to be executed.")
    chain: RunnableSequence = Field(None, description="the nodes in the path to be executed.")

    def classify(self, article):
        """Return this classifier's result when *article* matches, else None.

        Assumes *article* supports both ``article[section]`` (the section's
        text) and ``article.region`` — confirm against the article model.

        Returns:
            ``self.instruction`` when a subclass provides one, otherwise
            ``self.path``; None when the region or term check fails.
        """
        # Cheap region check first, so section text is only gathered when needed.
        if not self.validate_region(article):
            return None
        # Join with a newline so words at section boundaries are not fused
        # together ("...end" + "start..." -> "endstart"), which could create
        # or hide term matches.
        content = "\n".join(article[s] for s in (self.sections or ()))
        if all(term in content for term in self.terms):
            # The class defines no ``instruction`` field, so the previous
            # ``return self.instruction`` raised AttributeError on every match.
            # Fall back to the configured automation path.
            # NOTE(review): confirm ``path`` is the intended classification result.
            return getattr(self, "instruction", self.path)
        return None

    def validate_region(self, article):
        """Return True when this classifier applies to *article*'s region."""
        return self.region == Region.all or self.region == article.region

    def parse_obj_mapped(self, obj, key_map):
        """Build a new ChainClassifier from *obj* after renaming its keys.

        Keys of *obj* that appear in *key_map* are renamed to the mapped name;
        a renamed value overrides an existing key of the same name. All other
        keys are kept. *obj* itself is not modified.

        ``parse_obj`` is a classmethod that constructs a new model; it does
        not update ``self``, so the new instance is returned to the caller.

        Returns:
            The parsed ChainClassifier (or subclass) instance.
        """
        remapped = {key: value for key, value in obj.items() if key not in key_map}
        for old_key, new_key in key_map.items():
            if old_key in obj:
                remapped[new_key] = obj[old_key]
        return type(self).parse_obj(remapped)

    def run(self, article):
        """Invoke ``chain`` with *article* and return the chain's output.

        ``chain`` must be set; it defaults to None.
        """
        return self.chain.invoke(article)