File size: 6,867 Bytes
67285fb
 
 
 
 
 
 
 
 
 
 
c2a8c39
 
67285fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c2a8c39
 
67285fb
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
import os
from typing import Optional, List, Dict
from contextlib import asynccontextmanager
import re
import json

from fastapi import FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import gradio as gr

class MedicalReport(BaseModel):
    """Request schema: raw free-text patient report to be assessed."""
    # Full discharge summary / medical report text supplied by the caller.
    text: str

class ReportResponse(BaseModel):
    """Response schema: the final (ENV4) risk-assessment text.

    NOTE(review): not referenced anywhere in the visible code — presumably
    intended for a FastAPI endpoint that is defined elsewhere or not yet
    written; confirm before removing.
    """
    # Comprehensive assessment text produced by the final pipeline stage.
    assessment: str

class MedicalAssessmentModel:
    """Four-stage LLM pipeline turning a patient report into a risk assessment.

    Each stage is an independent chat exchange with the same causal LM:
      1. ``run_env1`` – identify disease risks / assessment needs
      2. ``run_env2`` – extract clinical parameters
      3. ``run_env3`` – interpret the identified risks
      4. ``run_env4`` – produce the final comprehensive assessment
    ``process_report`` chains all four and returns the ENV4 output.
    """

    def __init__(self):
        # Loads a ~7B chat model; device_map="auto" places weights on the
        # available GPU(s). Weights must be locally cached or downloadable.
        model_name = "meta-llama/Llama-2-7b-chat-hf"  # or any other model you prefer
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
        )

    def generate_response(self, messages: List[Dict]) -> str:
        """Generate a completion for ``messages`` and return only the new text.

        Args:
            messages: list of ``{"role": ..., "content": ...}`` dicts.

        Returns:
            The model's generated continuation, stripped of the prompt.
        """
        # NOTE(review): plain "role: content" concatenation rather than the
        # model's native chat template; tokenizer.apply_chat_template would
        # match Llama-2's expected format — kept as-is to preserve prompts.
        prompt = ""
        for msg in messages:
            role = msg['role']
            content = msg['content']
            prompt += f"{role}: {content}\n"

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=512,
                temperature=0.7,
                do_sample=True,
                top_p=0.9,
                num_return_sequences=1,
                # Llama-2 defines no pad token; padding with EOS avoids the
                # "Setting pad_token_id" warning and undefined padding.
                pad_token_id=self.tokenizer.eos_token_id,
            )

        # FIX: strip the prompt at the *token* level. The previous version
        # sliced the decoded string by len(decode(prompt_ids)), which breaks
        # when decode() normalizes whitespace/special tokens and the decoded
        # prompt length no longer matches its prefix in the full decode.
        prompt_len = inputs["input_ids"].shape[1]
        generated_ids = outputs[0][prompt_len:]
        return self.tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

    def run_env1(self, patient_text: str) -> str:
        """Tool Selection Stage: identify risks and assessment needs."""
        messages = [
            {
                "role": "system",
                "content": "You are a medical professional expert in selecting appropriate clinical risk assessment tools."
            },
            {
                "role": "user",
                "content": f"""Based on the patient's discharge summary, identify potential disease risks and assessment needs.

Patient Information:
{patient_text}

Please analyze:
1. Primary health concerns
2. Risk factors identified
3. Potential complications
4. Areas requiring risk assessment"""
            }
        ]
        return self.generate_response(messages)

    def run_env2(self, patient_text: str, env1_output: str) -> str:
        """Parameter Extraction Stage: pull clinical parameters from the record."""
        messages = [
            {
                "role": "system",
                "content": "You are a medical professional expert in extracting clinical parameters from patient records."
            },
            {
                "role": "user",
                "content": f"""Extract relevant clinical parameters from the patient's information.

Patient Information:
{patient_text}

Previous Analysis:
{env1_output}

Please provide:
1. Key vital signs
2. Relevant lab values
3. Clinical findings
4. Risk factors identified"""
            }
        ]
        return self.generate_response(messages)

    def run_env3(self, patient_text: str, env1_output: str, env2_output: str) -> str:
        """Risk Interpretation Stage: grade and relate the identified risks."""
        messages = [
            {
                "role": "system",
                "content": "You are a medical expert specialized in clinical risk assessment and interpretation."
            },
            {
                "role": "user",
                "content": f"""Interpret the identified risks and clinical parameters.

Patient Information:
{patient_text}

Risk Analysis:
{env1_output}

Clinical Parameters:
{env2_output}

Please provide:
1. Risk level assessment for each identified condition
2. Clinical significance of findings
3. Interaction between different risk factors
4. Severity assessment"""
            }
        ]
        return self.generate_response(messages)

    def run_env4(self, patient_text: str, env1_output: str, env2_output: str, env3_output: str) -> str:
        """Final Assessment Stage: synthesize all prior stages into one report."""
        messages = [
            {
                "role": "system",
                "content": "You are a medical expert specialized in comprehensive risk assessment and patient care planning."
            },
            {
                "role": "user",
                "content": f"""Based on all previous analyses, provide a comprehensive assessment of the patient's disease risks.

Patient Information:
{patient_text}

Previous Analyses:
Risk Identification: {env1_output}
Parameter Analysis: {env2_output}
Risk Interpretation: {env3_output}

Please provide:
1. Summary of significant disease risks identified
2. Overall risk assessment
3. Key areas of concern
4. Recommended monitoring or preventive measures
5. Suggestions for risk mitigation

Format the response in clear sections with headers."""
            }
        ]
        return self.generate_response(messages)

    def process_report(self, patient_text: str) -> str:
        """Run all four stages sequentially and return the ENV4 output.

        Errors from any stage are converted to a human-readable string so the
        UI always receives displayable text rather than an exception.
        """
        try:
            # Each stage consumes the outputs of the stages before it.
            env1_output = self.run_env1(patient_text)
            env2_output = self.run_env2(patient_text, env1_output)
            env3_output = self.run_env3(patient_text, env1_output, env2_output)
            env4_output = self.run_env4(patient_text, env1_output, env2_output, env3_output)

            return env4_output
        except Exception as e:
            # Deliberate catch-all: surface the failure in the output box
            # instead of crashing the Gradio callback.
            return f"Error in processing: {str(e)}"

def create_gradio_interface():
    """Build and return the Gradio UI wrapping the assessment pipeline."""
    assessment_model = MedicalAssessmentModel()

    def analyze_text(text):
        # Thin adapter: Gradio passes the textbox string straight through.
        return assessment_model.process_report(text)

    report_input = gr.Textbox(
        lines=10,
        placeholder="Enter patient medical report here...",
        label="Medical Report",
    )
    assessment_output = gr.Textbox(
        lines=15,
        label="Risk Assessment Report",
    )
    example_reports = [
        ["Patient was admitted with chest pain and shortness of breath. History of hypertension and diabetes. BP 160/95, HR 98. Recent smoker with 30 pack-year history."],
        ["83-year-old female presents with confusion and fever. Recent fall at home. History of osteoporosis and mild cognitive impairment. Lives alone. Temperature 38.5C, BP 135/85."],
    ]

    return gr.Interface(
        fn=analyze_text,
        inputs=report_input,
        outputs=assessment_output,
        title="Medical Report Risk Assessment",
        description="Enter a medical report to get a comprehensive risk assessment. The system will analyze the report through multiple stages and provide a final assessment.",
        examples=example_reports,
    )

if __name__ == "__main__":
    # Serve the UI on all interfaces at the default Gradio port.
    create_gradio_interface().launch(server_name="0.0.0.0", server_port=7860)