File size: 10,704 Bytes
1ae5f20
 
67bff56
eb80595
 
938db47
67bff56
 
 
938db47
67bff56
23f27a5
e8d4bd9
67bff56
23f27a5
d13567e
23f27a5
54dac0e
67bff56
 
d13567e
67bff56
 
1ae5f20
67bff56
 
 
 
 
 
e8d4bd9
67bff56
 
 
16b32ee
eb80595
 
 
 
938db47
eb80595
 
938db47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67bff56
54dac0e
eb80595
 
 
 
 
 
 
 
 
 
 
 
 
41496a3
67bff56
938db47
08f8055
e5c6abe
08f8055
 
 
 
 
67bff56
08f8055
 
 
 
 
67bff56
23f27a5
67bff56
 
 
 
08f8055
67bff56
 
 
 
23f27a5
67bff56
744970d
67bff56
 
744970d
67bff56
938db47
744970d
67bff56
744970d
67bff56
 
e9a167b
 
 
 
 
67bff56
e9a167b
 
67bff56
23f27a5
938db47
67bff56
 
49bb7ca
67bff56
f9294e1
2ea0e0e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
823b0ef
67bff56
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
import torch
import re
import gradio as gr
import json
import traceback
import ast
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM
)
from peft import PeftModel

# === CONFIGURATION ===
# Hugging Face Hub IDs of the fine-tuned PEFT adapter checkpoints to ensemble.
MODEL_PATHS = [
    "FrAnKu34t23/Construction_Risk_Prediction_Model_v3",
]

# Causal-LM backbone the adapters above were trained on top of.
BASE_MODEL_ID = "distilgpt2"
# Instruction-tuned seq2seq model used to integrate the ensemble's outputs.
ANALYSIS_MODEL_ID = "google/flan-t5-base"

# === LOAD MODELS ===
models, tokenizers = [], []

# Each entry in MODEL_PATHS is a PEFT adapter applied on top of the same
# distilgpt2 base model; every adapted model is pinned to CPU in eval mode.
# NOTE(review): from_pretrained downloads weights at import time — this block
# has network side effects.
for path in MODEL_PATHS:
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
    base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID)
    model = PeftModel.from_pretrained(base_model, path).to("cpu").eval()
    models.append(model)
    tokenizers.append(tokenizer)

# === FLAN-T5 for CPU ANALYSIS ===
# Seq2seq model used downstream (analyze_with_cpu_model) to merge and
# interpret the JSON outputs produced by the ensemble above.
flan_tokenizer = AutoTokenizer.from_pretrained(ANALYSIS_MODEL_ID)
flan_model = AutoModelForSeq2SeqLM.from_pretrained(ANALYSIS_MODEL_ID).to("cpu").eval()

# === FORMAT INPUT ===
def format_input(scenario_text):
    """Build the generation prompt, ensuring the scenario carries a ', ' prefix.

    The fine-tuned models were presumably trained with this exact prompt shape,
    so the ", " separator is normalized rather than left to the caller.
    """
    text = scenario_text.strip()
    if not text.startswith(", "):
        # Drop any stray leading commas/spaces, then add the canonical prefix.
        text = ", " + text.lstrip(", ")
    return f"Based on the situation, predict potential hazards and injuries. {text}<|endoftext|>"

# === TEXT CLEANING + JSON EXTRACTION (FALLBACK USE) ===
def clean_raw_json_string(raw_text):
    """Normalize quote-like garbage characters and spacing in a JSON-ish string.

    Applies a fixed sequence of literal substitutions (order matters: "''" must
    collapse before anything else touches the quotes), then pads quotes that
    directly follow an opening delimiter or precede a closing one.
    """
    substitutions = (
        ("β€˜", "'"),
        ("’", "'"),
        ("β€œ", '"'),
        ("”", '"'),
        ("''", '"'),
        ("`", '"'),
        ("†", ""),
    )
    cleaned = raw_text
    for old, new in substitutions:
        cleaned = cleaned.replace(old, new)
    # Insert a space between an opening brace/bracket/comma and the next quote,
    # and between a quote and a following closing brace/bracket/comma.
    cleaned = re.sub(r'([{{\[,])\s*"', r'\1 "', cleaned)
    cleaned = re.sub(r'"\s*([}}\],])', r'" \1', cleaned)
    return cleaned

def extract_json_object(text):
    """Return the first brace-delimited span of *text* that parses to a dict.

    Stray bracketed single-string lists (e.g. ["slip"]) are pulled out of the
    candidate and, when no "Hazards" key survived, re-attached as a proper
    "Hazards" array before parsing. Returns None when nothing parses.
    """
    candidates = re.findall(r'\{(?:[^{{}}]|"[^"]*")*\}', text, re.DOTALL)
    for candidate in candidates:
        try:
            blob = clean_raw_json_string(candidate)
            # Collect loose ["item"] fragments, then strip them from the blob.
            loose_items = re.findall(r'\["([^"]+)"\]', blob)
            blob = re.sub(r'(\["[^"]+"\]\s*,?\s*)+', '', blob)
            if loose_items and "Hazards" not in blob:
                blob = blob.rstrip('} \n\t,')
                blob += ', "Hazards": ' + json.dumps(loose_items) + '}'
            parsed = json.loads(blob)
        except Exception as e:
            print(f"⚠️ extract_json_object failed: {e}")
            continue
        if isinstance(parsed, dict):
            return parsed
    return None

def extract_fields(text):
    """Pull hazards, accident cause, and injury degree out of raw model text.

    Returns a 4-tuple: (hazards_list, cause, injury, structured_json_string),
    defaulting to ([], "Unknown", "Unknown", ...) for anything not found.
    """
    def _ascii_clean(t):
        # Collapse quote-like garbage characters, then drop all non-ASCII.
        t = t.replace("β€˜", "'").replace("’", "'").replace("β€œ", '"').replace("”", '"')
        t = t.replace("''", '"').replace("`", '"').replace("†", "").replace("Β΄", "")
        return re.sub(r"[^\x00-\x7F]+", "", t)

    cleaned = _ascii_clean(text)

    cause_m = re.search(r'"?Cause of Accident"?\s*:\s*"([^"]+)",?', cleaned, re.IGNORECASE)
    cause = cause_m.group(1).strip() if cause_m else "Unknown"

    injury_m = re.search(r'"?Degree of Injury"?\s*:\s*"(Low|Medium|High)"', cleaned, re.IGNORECASE)
    injury = injury_m.group(1).capitalize() if injury_m else "Unknown"

    hazards = []
    hazards_m = re.search(r'"?Hazards"?\s*:\s*(\[[^\]]+\])', cleaned, re.IGNORECASE)
    if hazards_m:
        try:
            # literal_eval handles both single- and double-quoted list items.
            parsed_list = ast.literal_eval(_ascii_clean(hazards_m.group(1)))
            hazards = [str(h).strip().strip('"').strip("'") for h in parsed_list]
        except Exception as e:
            print("⚠️ Hazard parsing failed:", e)
            hazards = []

    structured = {
        "Hazards": hazards,
        "Cause of Accident": cause,
        "Degree of Injury": injury
    }

    return hazards, cause, injury, json.dumps(structured, indent=2)

def extract_json_only(text):
    """Return the first top-level {...} span found in *text*, or "" if none."""
    found = re.search(r'\{(?:[^{}]|"[^"]*")*\}', text, re.DOTALL)
    return found.group(0) if found else ""

# === GENERATION FROM EACH MODEL ===
def generate_single_model_output(model, tokenizer, prompt, max_length=300, temperature=0.7):
    """Sample one completion from a causal LM for *prompt*.

    The prompt is truncated to 512 tokens; generation is sampled (top-k/top-p)
    with a total-length budget of prompt length + max_length tokens.
    Returns the decoded, stripped text (prompt included, special tokens removed).
    """
    encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to("cpu")
    total_budget = encoded["input_ids"].shape[1] + max_length
    with torch.no_grad():
        generated = model.generate(
            **encoded,
            max_length=total_budget,
            do_sample=True,
            temperature=temperature,
            top_k=50,
            top_p=0.9,
            repetition_penalty=1.1,
            # distilgpt2 has no pad token; reuse EOS so generate() can pad.
            pad_token_id=tokenizer.eos_token_id
        )
    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
    return decoded.strip()

# === ANALYSIS WITH FLAN-T5 ===
def analyze_with_cpu_model(raw_outputs):
    """Integrate the ensemble's JSON blobs with FLAN-T5 into (cause, injury).

    Extracts the first JSON object from each raw output, asks the analysis
    model for a combined verdict, and regex-parses the two fields back out.
    Falls back to "Cause not found." / "Unknown" when parsing fails.
    """
    blobs = []
    for idx, text in enumerate(raw_outputs, start=1):
        blob = extract_json_only(text)
        if blob:
            blobs.append(f"Model {idx} JSON:\n{blob}")

    summary = "\n\n".join(blobs)
    prompt = (
        "The following are JSON outputs from multiple hazard prediction models:\n\n"
        + summary
        + "\n\nPlease analyze all JSON outputs and return:\n"
        "Cause of Accident: <natural language summary of the most likely cause>\n"
        "Degree of Injury: <Low | Medium | High>"
    )

    encoded = flan_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to("cpu")
    with torch.no_grad():
        generated = flan_model.generate(
            **encoded,
            max_length=128,
            do_sample=True,
            temperature=0.5,
            top_p=0.9
        )

    answer = flan_tokenizer.decode(generated[0], skip_special_tokens=True).strip()

    cause_m = re.search(r"(?i)cause of accident:\s*(.+)", answer)
    injury_m = re.search(r"(?i)degree of injury:\s*(low|medium|high)", answer)

    cause = cause_m.group(1).strip() if cause_m else "Cause not found."
    injury = injury_m.group(1).capitalize() if injury_m else "Unknown"

    return cause, injury

# === MAIN GENERATION FUNCTION ===
def generate_prediction_ensemble(scenario_text, max_len, temperature):
    """Run every loaded model on the scenario, then integrate with FLAN-T5.

    Returns (cause, injury, raw_output_text). Blank/whitespace input short-
    circuits to a placeholder triple without touching any model.
    """
    if not scenario_text.strip():
        return "Please enter a scenario", "", ""

    prompt = format_input(scenario_text)

    raw_outputs = []
    for mdl, tok in zip(models, tokenizers):
        raw_outputs.append(
            generate_single_model_output(mdl, tok, prompt, max_length=max_len, temperature=temperature)
        )

    cause, injury = analyze_with_cpu_model(raw_outputs)
    combined = "\n\n".join(f"Model {idx + 1}:\n{resp}" for idx, resp in enumerate(raw_outputs))
    return cause, injury, combined

# === FULL GRADIO INTERFACE ===
def create_interface():
    """Build and return the Gradio Blocks UI wired to generate_prediction_ensemble.

    Layout: a header panel, a left column for scenario input / example buttons /
    generation sliders, and a right column for the integrated results plus the
    per-model raw outputs. Reads module globals MODEL_PATHS, models, BASE_MODEL_ID.
    """
    with gr.Blocks(title="Multi-Model Safety Risk Predictor") as interface:
        # Header: interpolates load-time globals into static HTML.
        gr.HTML(f"""
        <h1>🚧 Multi-Model Safety Risk Predictor (CPU-Only)</h1>
        <p><strong>System Overview:</strong></p>
        <ul>
            <li>Loads {len(MODEL_PATHS)} specialized safety prediction models</li>
            <li>Each model analyzes the scenario independently</li>
            <li>CPU-only analysis model integrates all results using advanced reasoning</li>
            <li>Handles conflicting predictions through pattern analysis and majority consensus</li>
            <li>Fully optimized for CPU-only Hugging Face Spaces</li>
        </ul>
        <p><strong>Models Loaded:</strong> {len(models)} / {len(MODEL_PATHS)}</p>
        <p><strong>Base Model:</strong> {BASE_MODEL_ID}</p>
        <p><strong>Analysis Method:</strong> CPU-Only (No external API calls)</p>
        """)

        with gr.Row():
            # Left column: scenario input, example shortcuts, sampling controls.
            with gr.Column():
                scenario_input = gr.Textbox(
                    lines=6, 
                    label="Construction Scenario Description",
                    placeholder="Describe the workplace safety incident or scenario..."
                )
                
                gr.Markdown("**Quick Examples:**")
                with gr.Row():
                    ex1 = gr.Button("Chemical Exposure", size="sm")
                    ex2 = gr.Button("Fall Hazard", size="sm") 
                    ex3 = gr.Button("Equipment Malfunction", size="sm")
                    ex4 = gr.Button("Fire Incident", size="sm")
                
                with gr.Row():
                    # Slider args: (min, max, default, step).
                    temperature = gr.Slider(0.1, 1.0, 0.7, 0.1, label="Model Creativity")
                    max_len = gr.Slider(100, 400, 300, 50, label="Response Length")
                
                predict_btn = gr.Button("πŸ” Analyze with Multi-Model Ensemble", variant="primary")

            # Right column: integrated analysis outputs and raw model responses.
            with gr.Column():
                cause_output = gr.Textbox(
                    label="πŸ“ Integrated Cause Analysis", 
                    lines=4,
                    info="CPU model's integrated analysis of all model outputs"
                )
                degree_output = gr.Textbox(
                    label="πŸ“ˆ Degree of Injury",
                    info="Based on zero-shot classification + model integration"
                )
                
                with gr.Accordion("πŸ“„ Individual Model Outputs", open=False):
                    raw_output = gr.Textbox(label="Raw Model Responses", lines=15)

        # Button action
        predict_btn.click(
            fn=generate_prediction_ensemble,
            inputs=[scenario_input, max_len, temperature],
            outputs=[cause_output, degree_output, raw_output]
        )

        # Quick examples
        # Each button overwrites the scenario textbox with a canned scenario.
        ex1.click(fn=lambda: "An employee was working with chemical solvents in a poorly ventilated area without proper respiratory protection. The worker began experiencing dizziness and respiratory distress after prolonged exposure.", outputs=scenario_input)
        ex2.click(fn=lambda: "A construction worker was installing roofing materials on a steep slope without proper fall protection equipment. The worker lost footing on wet materials and fell.", outputs=scenario_input)
        ex3.click(fn=lambda: "During routine maintenance, a hydraulic press malfunctioned due to worn seals. The operator's hand was caught when the press unexpectedly activated.", outputs=scenario_input)
        ex4.click(fn=lambda: "While welding in an area with flammable materials, proper fire safety protocols were not followed. Sparks ignited nearby combustible materials causing a flash fire.", outputs=scenario_input)

        # Footer status line.
        gr.HTML(f"""
        <div style='text-align:center; margin-top:20px;'>
            <p><strong>System Status:</strong> {len(models)} models loaded | CPU-optimized | No external APIs</p>
            <p><em>Built with Multi-Model Ensemble + CPU Analysis + Gradio</em></p>
        </div>
        """)

    return interface

# === RUN INTERFACE ===
if __name__ == "__main__":
    # Build the UI and serve it on all interfaces at port 7860 with a share link.
    app = create_interface()
    app.launch(server_name="0.0.0.0", server_port=7860, share=True)