FrAnKu34t23 commited on
Commit
54dac0e
Β·
verified Β·
1 Parent(s): 744970d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +248 -165
app.py CHANGED
@@ -9,58 +9,73 @@ import os
9
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
10
  from peft import PeftModel
11
 
12
- # Configuration - Using better base models
13
  MODEL_PATHS = [
14
  "FrAnKu34t23/Construction_Risk_Prediction_Model_v3"
15
  ]
16
- # Better base model options - choose one based on your needs
17
- BASE_MODEL_ID = "microsoft/DialoGPT-medium" # Better conversational model
18
- # Alternative options:
19
- # BASE_MODEL_ID = "gpt2-medium" # Larger GPT-2
20
- # BASE_MODEL_ID = "microsoft/DialoGPT-large" # Even better but slower
21
 
22
  models = []
23
  tokenizers = []
 
24
 
25
- # Initialize better models for analysis
26
- injury_classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
27
-
28
- # Use a more capable model for text analysis and reasoning
29
- analysis_model = pipeline(
30
- "text-generation",
31
- model="microsoft/DialoGPT-large", # Better reasoning capabilities
32
- device=0 if torch.cuda.is_available() else -1
33
  )
34
 
35
  def classify_injury_zero_shot(description):
36
  candidate_labels = [
37
  "Low severity injury (minor discomfort or bruise) or unrelevant cases",
38
- "Medium severity injury (sprain, strain, moderate pain)",
39
  "High severity injury (fracture, major trauma, amputation, fatal)"
40
  ]
41
  label_mapping = {
42
  candidate_labels[0]: "Low",
43
- candidate_labels[1]: "Medium",
44
  candidate_labels[2]: "High"
45
  }
46
  result = injury_classifier(description, candidate_labels)
47
  return label_mapping[result['labels'][0]]
48
 
49
  def load_models():
50
- global models, tokenizers
51
  try:
52
- for path in MODEL_PATHS:
 
 
53
  tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
54
  tokenizer.pad_token = tokenizer.eos_token
55
- base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID)
 
 
 
 
 
 
 
 
 
 
56
  model = PeftModel.from_pretrained(base_model, path)
57
  model.eval()
 
 
58
  models.append(model)
59
  tokenizers.append(tokenizer)
60
- print("βœ… All models loaded.")
 
 
61
  return True
 
62
  except Exception as e:
63
  print(f"❌ Model loading failed: {e}")
 
64
  return False
65
 
66
  def format_input(scenario_text):
@@ -69,11 +84,11 @@ def format_input(scenario_text):
69
  scenario = ", " + scenario.lstrip(", ")
70
  return f"Based on the situation, predict potential hazards and injuries. {scenario}<|endoftext|>"
71
 
72
- def generate_all_model_outputs(prompt, max_length=300, temperature=0.7):
73
- outputs = []
74
- for model, tokenizer in zip(models, tokenizers):
75
- inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
76
- inputs = {k: v.to(next(model.parameters()).device) for k, v in inputs.items()}
77
 
78
  with torch.no_grad():
79
  output = model.generate(
@@ -84,207 +99,275 @@ def generate_all_model_outputs(prompt, max_length=300, temperature=0.7):
84
  top_k=50,
85
  repetition_penalty=1.1,
86
  pad_token_id=tokenizer.pad_token_id,
87
- eos_token_id=tokenizer.eos_token_id
 
88
  )
89
 
90
  decoded = tokenizer.decode(output[0], skip_special_tokens=True)
91
- outputs.append(f"=== RAW RESPONSE START ===\n{decoded}\n=== RAW RESPONSE END ===")
92
- return outputs
 
 
 
93
 
94
- def extract_scenario_from_prompt(prompt):
95
- try:
96
- return re.sub(r"^.*predict potential hazards and injuries\.\s*", "", prompt)
97
- except:
98
- return prompt
 
 
 
 
 
 
 
99
 
100
- def parse_json_from_raw_output(raw_output):
101
- """Extract JSON from raw model output"""
 
 
102
  try:
103
- # Look for JSON pattern in the raw output
104
- json_match = re.search(r'\{.*?\}', raw_output, re.DOTALL)
105
- if json_match:
106
- json_str = json_match.group(0)
107
- return json.loads(json_str)
108
- return None
 
109
  except:
110
- return None
111
-
112
- def extract_structured_data_from_outputs(raw_outputs):
113
- """Extract and combine structured JSON data from all model outputs"""
114
- all_json_data = []
115
-
116
- for output in raw_outputs:
117
- json_data = parse_json_from_raw_output(output)
118
- if json_data:
119
- all_json_data.append(json_data)
120
-
121
- return all_json_data
122
-
123
- def analyze_with_advanced_hf_model(raw_outputs, zero_shot_injury, structured_data):
124
- """Replace Gemini Pro functionality with advanced HF model analysis"""
125
 
126
- # Prepare the analysis prompt similar to original Gemini prompt
127
- structured_info = ""
128
- if structured_data:
129
- structured_info = "\n\nStructured data extracted from models:\n"
130
- for i, data in enumerate(structured_data, 1):
131
- structured_info += f"Model {i}: {json.dumps(data, indent=2)}\n"
132
 
133
- prompt = f"""You are a workplace safety analyst. Below are raw text outputs from construction safety prediction models.
134
-
135
- Your tasks:
136
- - Compare and merge the model outputs
137
- - Summarize the most plausible cause of accident in natural language
138
- - Infer the degree of injury by considering all outputs and classifier suggestion
139
-
140
- Classifier prediction for Degree of Injury: {zero_shot_injury}
141
 
142
- Model Outputs:
143
- {raw_outputs[0]}
144
 
145
- {raw_outputs[1] if len(raw_outputs) > 1 else ""}
146
 
147
- {raw_outputs[2] if len(raw_outputs) > 2 else ""}
148
-
149
- {structured_info}
150
-
151
- Based on this analysis, provide a concise response in this format:
152
- Cause of Accident: [single clear sentence]
153
- Degree of Injury: [Low/Medium/High]
154
-
155
- Analysis:"""
156
 
157
  try:
158
- # Use the analysis model to generate response
159
- response = analysis_model(
160
- prompt,
161
- max_length=len(prompt.split()) + 100,
162
- temperature=0.3, # Lower temperature for more consistent analysis
163
  do_sample=True,
164
- pad_token_id=analysis_model.tokenizer.eos_token_id
 
165
  )
166
 
167
  generated_text = response[0]['generated_text']
168
- # Extract only the generated part after the prompt
169
  analysis_result = generated_text.replace(prompt, "").strip()
170
 
171
- # If the analysis doesn't contain the required format, create it
172
- if "Cause of Accident:" not in analysis_result:
173
- # Fallback analysis based on structured data
174
- cause = "Multiple safety protocol violations identified"
175
- if structured_data:
176
- causes = []
177
- for data in structured_data:
178
- if isinstance(data, dict) and "Cause of Accident" in data:
179
- causes.append(data["Cause of Accident"])
180
- if causes:
181
- cause = causes[0] # Take the first cause found
182
-
183
- analysis_result = f"Cause of Accident: {cause}\nDegree of Injury: {zero_shot_injury}"
184
 
185
- return analysis_result
 
 
 
 
186
 
187
  except Exception as e:
188
- print("❌ Advanced HF model analysis failed:", e)
189
- # Fallback using structured data if available
190
- if structured_data and len(structured_data) > 0:
191
- first_data = structured_data[0]
192
- cause = first_data.get("Cause of Accident", "Safety protocol violation")
193
- injury = first_data.get("Degree of Injury", zero_shot_injury)
194
- return f"Cause of Accident: {cause}\nDegree of Injury: {injury}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
 
196
- return f"Cause of Accident: Unable to analyze due to technical error\nDegree of Injury: {zero_shot_injury}"
 
 
 
 
 
 
 
 
 
 
 
 
197
 
198
  def generate_prediction_ensemble(scenario_text, max_length=300, temperature=0.7):
 
 
199
  if not scenario_text.strip():
200
- return "❌ Please enter a scenario.", "", "", ""
201
 
202
  try:
 
203
  prompt = format_input(scenario_text)
 
 
 
204
  raw_outputs = generate_all_model_outputs(prompt, max_length, temperature)
205
-
 
206
  scenario_only = extract_scenario_from_prompt(prompt)
207
  injury_guess = classify_injury_zero_shot(scenario_only)
208
-
209
- # Extract structured JSON data from raw outputs
210
- structured_data = extract_structured_data_from_outputs(raw_outputs)
211
-
212
- # Use advanced HF model analysis (replacing Gemini)
213
- hf_analysis = analyze_with_advanced_hf_model(raw_outputs, injury_guess, structured_data)
214
-
215
- # Parse the analysis results
216
- match_cause = re.search(r"Cause of Accident\s*:\s*(.+)", hf_analysis)
217
- match_injury = re.search(r"Degree of Injury\s*:\s*(Low|Medium|High)", hf_analysis, re.IGNORECASE)
218
-
219
- cause = match_cause.group(1).strip() if match_cause else "Unable to determine cause"
220
  injury = match_injury.group(1).strip().capitalize() if match_injury else injury_guess
221
-
222
  combined_raw = "\n\n".join(raw_outputs)
223
 
224
- # Format structured data for display
225
- structured_display = json.dumps(structured_data, indent=2) if structured_data else "No structured data found"
226
 
227
- return cause, injury, combined_raw, structured_display
228
-
229
  except Exception as e:
230
- return "❌ Prediction failed.", "", traceback.format_exc(), ""
 
 
 
231
 
232
  def create_interface():
233
- with gr.Blocks(title="Workplace Safety Risk Predictor") as interface:
234
- gr.HTML("""
235
- <h1>🚧 Workplace Safety Risk Prediction Model (Enhanced Ensemble)</h1>
236
- <p>Enter a construction scenario to analyze possible risks. Uses advanced language models for better analysis.</p>
237
- <p><strong>Expected JSON Output Format:</strong></p>
238
- <pre>{"Cause of Accident": "...", "Degree of Injury": "High/Medium/Low", "Hazards": ["...", "..."]}</pre>
239
- <p><strong>Examples:</strong></p>
240
  <ul>
241
- <li>An employee was working with chemical solvents without proper ventilation. The employee inhaled toxic fumes and experienced respiratory problems.</li>
242
- <li>A worker fell from scaffolding due to lack of fall protection measures in place.</li>
243
- <li>While operating a crane, the load became unstable and struck a nearby worker.</li>
244
- <li>During welding, flammable vapors ignited due to poor fire safety practices.</li>
 
245
  </ul>
 
 
 
246
  """)
247
 
248
  with gr.Row():
249
  with gr.Column():
250
- scenario_input = gr.Textbox(lines=5, label="Scenario Description")
 
 
 
 
 
251
  gr.Markdown("**Quick Examples:**")
252
  with gr.Row():
253
- ex1 = gr.Button("Solvent Exposure")
254
- ex2 = gr.Button("Fall from Scaffolding")
255
- ex3 = gr.Button("Crane Load Accident")
256
- ex4 = gr.Button("Welding Fire Hazard")
257
- temperature = gr.Slider(0.1, 1.0, 0.7, 0.1, label="Creativity (Temperature)")
258
- max_len = gr.Slider(100, 500, 300, 50, label="Max Response Length")
259
- predict_btn = gr.Button("πŸ” Analyze")
 
 
 
260
 
261
  with gr.Column():
262
- cause_output = gr.Textbox(label="πŸ“ Cause of Accident")
263
- degree_output = gr.Textbox(label="πŸ“ˆ Degree of Injury")
264
- with gr.Accordion("πŸ“Š Extracted Structured Data", open=False):
265
- structured_output = gr.Textbox(label="JSON Data from Models", lines=8)
266
- with gr.Accordion("πŸ“„ Raw Model Outputs", open=False):
267
- raw_output = gr.Textbox(label="Raw Responses", lines=12)
268
-
 
 
 
 
 
 
 
269
  predict_btn.click(
270
  fn=generate_prediction_ensemble,
271
  inputs=[scenario_input, max_len, temperature],
272
- outputs=[cause_output, degree_output, raw_output, structured_output]
273
  )
274
 
275
- ex1.click(fn=lambda: "An employee was working with chemical solvents without proper ventilation. The employee inhaled toxic fumes and experienced respiratory problems.", outputs=scenario_input)
276
- ex2.click(fn=lambda: "A worker fell from scaffolding due to lack of fall protection measures in place.", outputs=scenario_input)
277
- ex3.click(fn=lambda: "While operating a crane, the load became unstable and struck a nearby worker.", outputs=scenario_input)
278
- ex4.click(fn=lambda: "During welding, flammable vapors ignited due to poor fire safety practices.", outputs=scenario_input)
 
 
 
 
279
 
280
- gr.HTML("<p style='text-align:center;'>Built with Advanced Transformers + Enhanced Analysis + Gradio</p>")
 
 
 
 
 
281
 
282
  return interface
283
 
284
- print("πŸš€ Starting app...")
 
 
 
285
  if load_models():
 
286
  app = create_interface()
287
  if __name__ == "__main__":
288
  app.launch(server_name="0.0.0.0", server_port=7860, share=True)
289
  else:
290
- print("❌ Failed to load models.")
 
9
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
10
  from peft import PeftModel
11
 
12
+ # Configuration - Multiple CPU-optimized models
13
  MODEL_PATHS = [
14
  "FrAnKu34t23/Construction_Risk_Prediction_Model_v3"
15
  ]
16
+
17
+ # CPU-friendly base model options
18
+ BASE_MODEL_ID = "microsoft/phi-2" # Best balance of efficiency and capability
19
+ # Alternative: "google/gemma-2b" for even better quality if your CPU can handle it
 
20
 
21
  models = []
22
  tokenizers = []
23
+ model_names = []
24
 
25
+ # CPU-friendly classifier
26
+ injury_classifier = pipeline(
27
+ "zero-shot-classification",
28
+ model="typeform/distilbert-base-uncased-mnli",
29
+ device=-1
 
 
 
30
  )
31
 
32
  def classify_injury_zero_shot(description):
33
  candidate_labels = [
34
  "Low severity injury (minor discomfort or bruise) or unrelevant cases",
35
+ "Medium severity injury (sprain, strain, moderate pain)",
36
  "High severity injury (fracture, major trauma, amputation, fatal)"
37
  ]
38
  label_mapping = {
39
  candidate_labels[0]: "Low",
40
+ candidate_labels[1]: "Medium",
41
  candidate_labels[2]: "High"
42
  }
43
  result = injury_classifier(description, candidate_labels)
44
  return label_mapping[result['labels'][0]]
45
 
46
  def load_models():
47
+ global models, tokenizers, model_names
48
  try:
49
+ for i, path in enumerate(MODEL_PATHS):
50
+ print(f"Loading model {i+1}/{len(MODEL_PATHS)}: {path}")
51
+
52
  tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
53
  tokenizer.pad_token = tokenizer.eos_token
54
+
55
+ # Load base model with CPU optimization
56
+ base_model = AutoModelForCausalLM.from_pretrained(
57
+ BASE_MODEL_ID,
58
+ torch_dtype=torch.float32,
59
+ device_map=None,
60
+ trust_remote_code=True,
61
+ low_cpu_mem_usage=True # Reduce memory usage
62
+ )
63
+
64
+ # Load PEFT model
65
  model = PeftModel.from_pretrained(base_model, path)
66
  model.eval()
67
+ model = model.to('cpu')
68
+
69
  models.append(model)
70
  tokenizers.append(tokenizer)
71
+ model_names.append(f"Model_{i+1}")
72
+
73
+ print(f"βœ… All {len(models)} models loaded successfully on CPU.")
74
  return True
75
+
76
  except Exception as e:
77
  print(f"❌ Model loading failed: {e}")
78
+ traceback.print_exc()
79
  return False
80
 
81
  def format_input(scenario_text):
 
84
  scenario = ", " + scenario.lstrip(", ")
85
  return f"Based on the situation, predict potential hazards and injuries. {scenario}<|endoftext|>"
86
 
87
+ def generate_single_model_output(model, tokenizer, prompt, max_length=300, temperature=0.7):
88
+ """Generate output from a single model"""
89
+ try:
90
+ inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
91
+ inputs = {k: v.to('cpu') for k, v in inputs.items()}
92
 
93
  with torch.no_grad():
94
  output = model.generate(
 
99
  top_k=50,
100
  repetition_penalty=1.1,
101
  pad_token_id=tokenizer.pad_token_id,
102
+ eos_token_id=tokenizer.eos_token_id,
103
+ do_sample=True
104
  )
105
 
106
  decoded = tokenizer.decode(output[0], skip_special_tokens=True)
107
+ return decoded
108
+
109
+ except Exception as e:
110
+ print(f"Error generating from model: {e}")
111
+ return f"Error: Unable to generate response - {str(e)}"
112
 
113
+ def generate_all_model_outputs(prompt, max_length=300, temperature=0.7):
114
+ """Generate outputs from all loaded models"""
115
+ outputs = []
116
+
117
+ for i, (model, tokenizer) in enumerate(zip(models, tokenizers)):
118
+ print(f"Generating from {model_names[i]}...")
119
+
120
+ raw_output = generate_single_model_output(model, tokenizer, prompt, max_length, temperature)
121
+ formatted_output = f"=== {model_names[i].upper()} RAW RESPONSE START ===\n{raw_output}\n=== {model_names[i].upper()} RAW RESPONSE END ==="
122
+ outputs.append(formatted_output)
123
+
124
+ return outputs
125
 
126
+ def analyze_with_cpu_model(raw_outputs, zero_shot_injury):
127
+ """Use CPU-only model for analysis instead of Gemini Pro"""
128
+
129
+ # Initialize CPU analysis model (using a better model for reasoning)
130
  try:
131
+ analysis_pipeline = pipeline(
132
+ "text-generation",
133
+ model="microsoft/phi-2",
134
+ device=-1,
135
+ torch_dtype=torch.float32,
136
+ trust_remote_code=True
137
+ )
138
  except:
139
+ # Fallback to simpler model if Phi-2 fails
140
+ analysis_pipeline = pipeline(
141
+ "text-generation",
142
+ model="gpt2",
143
+ device=-1
144
+ )
 
 
 
 
 
 
 
 
 
145
 
146
+ # Prepare concise prompt for CPU model
147
+ models_summary = ""
148
+ for i, output in enumerate(raw_outputs):
149
+ # Extract key information from each model output
150
+ clean_output = output.replace("=== RAW RESPONSE START ===", "").replace("=== RAW RESPONSE END ===", "").strip()
151
+ models_summary += f"Model {i+1}: {clean_output[:200]}...\n"
152
 
153
+ prompt = f"""Analyze workplace safety incident based on model predictions:
 
 
 
 
 
 
 
154
 
155
+ {models_summary}
 
156
 
157
+ Classifier suggests: {zero_shot_injury} severity
158
 
159
+ Task: Integrate the model outputs to identify the main cause and injury level.
160
+ Response format:
161
+ Cause of Accident:"""
 
 
 
 
 
 
162
 
163
  try:
164
+ # Generate analysis
165
+ response = analysis_pipeline(
166
+ prompt,
167
+ max_length=len(prompt.split()) + 80,
168
+ temperature=0.3,
169
  do_sample=True,
170
+ pad_token_id=analysis_pipeline.tokenizer.eos_token_id,
171
+ truncation=True
172
  )
173
 
174
  generated_text = response[0]['generated_text']
 
175
  analysis_result = generated_text.replace(prompt, "").strip()
176
 
177
+ # Ensure proper format
178
+ if not analysis_result or len(analysis_result) < 10:
179
+ return perform_rule_based_analysis(raw_outputs, zero_shot_injury)
 
 
 
 
 
 
 
 
 
 
180
 
181
+ # Add degree of injury if not present
182
+ if "Degree of Injury:" not in analysis_result:
183
+ analysis_result += f"\nDegree of Injury: {zero_shot_injury}"
184
+
185
+ return f"Cause of Accident: {analysis_result}\nDegree of Injury: {zero_shot_injury}"
186
 
187
  except Exception as e:
188
+ print(f"❌ CPU model analysis failed: {e}")
189
+ return perform_rule_based_analysis(raw_outputs, zero_shot_injury)
190
+
191
+ def perform_rule_based_analysis(raw_outputs, zero_shot_injury):
192
+ """Fallback rule-based analysis when CPU model fails"""
193
+
194
+ # Combine all model outputs
195
+ all_text = " ".join(raw_outputs).lower()
196
+
197
+ # Define safety categories and their indicators
198
+ safety_categories = {
199
+ 'fall_protection': ['fall', 'height', 'scaffolding', 'ladder', 'roof', 'protection'],
200
+ 'chemical_exposure': ['chemical', 'solvent', 'toxic', 'fumes', 'vapor', 'exposure'],
201
+ 'equipment_failure': ['equipment', 'machinery', 'malfunction', 'failure', 'maintenance'],
202
+ 'fire_safety': ['fire', 'ignition', 'flammable', 'welding', 'spark', 'combustible'],
203
+ 'electrical': ['electrical', 'shock', 'current', 'wire', 'power'],
204
+ 'confined_space': ['confined', 'space', 'ventilation', 'oxygen', 'gas']
205
+ }
206
+
207
+ # Score each category
208
+ category_scores = {}
209
+ for category, keywords in safety_categories.items():
210
+ score = sum(1 for keyword in keywords if keyword in all_text)
211
+ if score > 0:
212
+ category_scores[category] = score
213
+
214
+ # Determine primary cause
215
+ if category_scores:
216
+ primary_category = max(category_scores, key=category_scores.get)
217
+
218
+ cause_descriptions = {
219
+ 'fall_protection': "Inadequate fall protection measures leading to worker falling from height",
220
+ 'chemical_exposure': "Unsafe chemical handling without proper protective equipment causing exposure",
221
+ 'equipment_failure': "Equipment malfunction due to inadequate maintenance or safety protocols",
222
+ 'fire_safety': "Fire safety protocol violations resulting in ignition of flammable materials",
223
+ 'electrical': "Electrical safety hazards due to improper handling or faulty equipment",
224
+ 'confined_space': "Confined space entry without proper safety procedures and ventilation"
225
+ }
226
+
227
+ primary_cause = cause_descriptions.get(primary_category, "Multiple safety protocol violations")
228
+
229
+ # Add secondary factors if present
230
+ secondary_factors = [cat for cat, score in category_scores.items()
231
+ if cat != primary_category and score > 0]
232
 
233
+ if secondary_factors:
234
+ primary_cause += f". Contributing factors include {', '.join(secondary_factors[:2])} safety issues"
235
+
236
+ else:
237
+ primary_cause = "Safety incident due to inadequate risk assessment and protocol violations"
238
+
239
+ return f"Cause of Accident: {primary_cause}.\nDegree of Injury: {zero_shot_injury}"
240
+
241
+ def extract_scenario_from_prompt(prompt):
242
+ try:
243
+ return re.sub(r"^.*predict potential hazards and injuries\.\s*", "", prompt)
244
+ except:
245
+ return prompt
246
 
247
  def generate_prediction_ensemble(scenario_text, max_length=300, temperature=0.7):
248
+ """Main prediction function using CPU-only models"""
249
+
250
  if not scenario_text.strip():
251
+ return "❌ Please enter a scenario.", "", ""
252
 
253
  try:
254
+ # Generate prompt
255
  prompt = format_input(scenario_text)
256
+
257
+ # Generate outputs from all models
258
+ print("Generating outputs from all models...")
259
  raw_outputs = generate_all_model_outputs(prompt, max_length, temperature)
260
+
261
+ # Get zero-shot classification
262
  scenario_only = extract_scenario_from_prompt(prompt)
263
  injury_guess = classify_injury_zero_shot(scenario_only)
264
+
265
+ # Use CPU model for analysis instead of Gemini
266
+ print("Analyzing with CPU model...")
267
+ cpu_analysis = analyze_with_cpu_model(raw_outputs, injury_guess)
268
+
269
+ # Parse CPU analysis response
270
+ match_cause = re.search(r"Cause of Accident\s*:\s*(.+?)(?=\nDegree of Injury|$)", cpu_analysis, re.DOTALL)
271
+ match_injury = re.search(r"Degree of Injury\s*:\s*(Low|Medium|High)", cpu_analysis, re.IGNORECASE)
272
+
273
+ cause = match_cause.group(1).strip() if match_cause else "Unable to determine cause from model outputs"
 
 
274
  injury = match_injury.group(1).strip().capitalize() if match_injury else injury_guess
275
+
276
  combined_raw = "\n\n".join(raw_outputs)
277
 
278
+ return cause, injury, combined_raw
 
279
 
 
 
280
  except Exception as e:
281
+ error_msg = f"❌ Prediction failed: {str(e)}"
282
+ print(error_msg)
283
+ traceback.print_exc()
284
+ return error_msg, "", ""
285
 
286
  def create_interface():
287
+ with gr.Blocks(title="Multi-Model Safety Risk Predictor") as interface:
288
+ gr.HTML(f"""
289
+ <h1>🚧 Multi-Model Safety Risk Predictor (CPU-Only)</h1>
290
+ <p><strong>System Overview:</strong></p>
 
 
 
291
  <ul>
292
+ <li>Loads {len(MODEL_PATHS)} specialized safety prediction models</li>
293
+ <li>Each model analyzes the scenario independently</li>
294
+ <li>CPU-only analysis model integrates all results using advanced reasoning</li>
295
+ <li>Handles conflicting predictions through pattern analysis and majority consensus</li>
296
+ <li>Fully optimized for CPU-only Hugging Face Spaces</li>
297
  </ul>
298
+ <p><strong>Models Loaded:</strong> {len(models)} / {len(MODEL_PATHS)}</p>
299
+ <p><strong>Base Model:</strong> {BASE_MODEL_ID}</p>
300
+ <p><strong>Analysis Method:</strong> CPU-Only (No external API calls)</p>
301
  """)
302
 
303
  with gr.Row():
304
  with gr.Column():
305
+ scenario_input = gr.Textbox(
306
+ lines=6,
307
+ label="Construction Scenario Description",
308
+ placeholder="Describe the workplace safety incident or scenario..."
309
+ )
310
+
311
  gr.Markdown("**Quick Examples:**")
312
  with gr.Row():
313
+ ex1 = gr.Button("Chemical Exposure", size="sm")
314
+ ex2 = gr.Button("Fall Hazard", size="sm")
315
+ ex3 = gr.Button("Equipment Malfunction", size="sm")
316
+ ex4 = gr.Button("Fire Incident", size="sm")
317
+
318
+ with gr.Row():
319
+ temperature = gr.Slider(0.1, 1.0, 0.7, 0.1, label="Model Creativity")
320
+ max_len = gr.Slider(100, 400, 300, 50, label="Response Length")
321
+
322
+ predict_btn = gr.Button("πŸ” Analyze with Multi-Model Ensemble", variant="primary")
323
 
324
  with gr.Column():
325
+ cause_output = gr.Textbox(
326
+ label="πŸ“ Integrated Cause Analysis",
327
+ lines=4,
328
+ info="CPU model's integrated analysis of all model outputs"
329
+ )
330
+ degree_output = gr.Textbox(
331
+ label="πŸ“ˆ Degree of Injury",
332
+ info="Based on zero-shot classification + model integration"
333
+ )
334
+
335
+ with gr.Accordion("πŸ“„ Individual Model Outputs", open=False):
336
+ raw_output = gr.Textbox(label="Raw Model Responses", lines=15)
337
+
338
+ # Event handlers
339
  predict_btn.click(
340
  fn=generate_prediction_ensemble,
341
  inputs=[scenario_input, max_len, temperature],
342
+ outputs=[cause_output, degree_output, raw_output]
343
  )
344
 
345
+ # Example scenarios
346
+ ex1.click(fn=lambda: "An employee was working with chemical solvents in a poorly ventilated area without proper respiratory protection. The worker began experiencing dizziness and respiratory distress after prolonged exposure.", outputs=scenario_input)
347
+
348
+ ex2.click(fn=lambda: "A construction worker was installing roofing materials on a steep slope without proper fall protection equipment. The worker lost footing on wet materials and fell.", outputs=scenario_input)
349
+
350
+ ex3.click(fn=lambda: "During routine maintenance, a hydraulic press malfunctioned due to worn seals. The operator's hand was caught when the press unexpectedly activated.", outputs=scenario_input)
351
+
352
+ ex4.click(fn=lambda: "While welding in an area with flammable materials, proper fire safety protocols were not followed. Sparks ignited nearby combustible materials causing a flash fire.", outputs=scenario_input)
353
 
354
+ gr.HTML(f"""
355
+ <div style='text-align:center; margin-top:20px;'>
356
+ <p><strong>System Status:</strong> {len(models)} models loaded | CPU-optimized | No external APIs</p>
357
+ <p><em>Built with Multi-Model Ensemble + CPU Analysis + Gradio</em></p>
358
+ </div>
359
+ """)
360
 
361
  return interface
362
 
363
+ # Initialize and launch
364
+ print("πŸš€ Starting Multi-Model Safety Predictor...")
365
+ print(f"Attempting to load {len(MODEL_PATHS)} models...")
366
+
367
  if load_models():
368
+ print(f"βœ… Successfully loaded {len(models)} models")
369
  app = create_interface()
370
  if __name__ == "__main__":
371
  app.launch(server_name="0.0.0.0", server_port=7860, share=True)
372
  else:
373
+ print("❌ Failed to load models. Please check model paths and system resources.")