FrAnKu34t23 committed on
Commit
823b0ef
·
verified ·
1 Parent(s): 65e487e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +246 -71
app.py CHANGED
@@ -63,118 +63,228 @@ def format_input(scenario_text):
63
 
64
  return formatted_prompt
65
 
66
- return formatted_prompt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
  def parse_json_response(response_text):
69
- """Extract and parse JSON from model response"""
70
  try:
 
 
 
71
  # First, try to parse the entire response as JSON
72
- if response_text.strip().startswith('{') and response_text.strip().endswith('}'):
73
- return json.loads(response_text.strip())
 
 
 
 
 
 
 
 
74
 
75
- # If that fails, look for JSON pattern in the text
76
- json_pattern = r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}'
77
- matches = re.findall(json_pattern, response_text, re.DOTALL)
 
 
78
 
79
- for match in matches:
80
- try:
81
- return json.loads(match)
82
- except:
83
- continue
 
 
 
84
 
85
- # If no valid JSON found, return structured error
86
- return {
87
- "Hazards": ["Unable to parse response"],
88
- "Cause of Accident": "Model output parsing failed",
89
- "Degree of Injury": "Unknown",
90
- "raw_response": response_text
91
- }
92
 
93
  except Exception as e:
94
-
95
  return {
96
- "Hazards": [f"Parsing error: {str(e)}"],
97
- "Cause of Accident": "JSON parsing failed",
98
  "Degree of Injury": "Unknown",
99
- "raw_response": response_text
100
-
101
  }
102
 
103
  def generate_prediction(scenario_text, max_length=300, temperature=0.7):
 
 
 
104
  if model is None or tokenizer is None:
105
  return "❌ Model not loaded. Please wait for initialization.", "", "", "", ""
106
 
 
 
 
107
  try:
108
  # Format the input
109
  formatted_prompt = format_input(scenario_text)
110
- full_prompt = f"{formatted_prompt}{tokenizer.eos_token}"
111
 
112
  # Tokenize
113
  inputs = tokenizer(
114
- full_prompt,
115
  return_tensors="pt",
116
  truncation=True,
117
  max_length=512,
 
 
 
 
118
  device = next(model.parameters()).device
119
  inputs = {k: v.to(device) for k, v in inputs.items()}
120
 
121
- # Generate response
122
  with torch.no_grad():
123
  outputs = model.generate(
124
  **inputs,
125
  max_length=len(inputs['input_ids'][0]) + max_length,
126
- temperature=temperature,
127
  do_sample=True,
128
- top_p=0.9,
129
- top_k=50,
130
  pad_token_id=tokenizer.pad_token_id,
131
  eos_token_id=tokenizer.eos_token_id,
132
  num_return_sequences=1,
133
- repetition_penalty=1.1,
134
  early_stopping=True
135
  )
136
 
137
  # Decode response
138
- full_response = tokenizer.decode(outputs[0], skip_special_tokens=False)
139
 
140
  # Extract generated part
141
- input_text = tokenizer.decode(inputs['input_ids'][0], skip_special_tokens=False)
142
 
143
  if full_response.startswith(input_text):
144
  generated_part = full_response[len(input_text):].strip()
145
  else:
146
  generated_part = full_response.strip()
147
 
148
- # Clean up response
149
- if generated_part.startswith(tokenizer.eos_token):
150
- generated_part = generated_part[len(tokenizer.eos_token):].strip()
151
-
152
- if generated_part.endswith(tokenizer.eos_token):
153
- generated_part = generated_part[:-len(tokenizer.eos_token)].strip()
154
 
155
  # Parse the JSON response
156
  parsed_response = parse_json_response(generated_part)
157
 
158
- # Extract individual components
159
  hazards = parsed_response.get("Hazards", [])
160
- cause = parsed_response.get("Cause of Accident", "Not specified")
161
- degree = parsed_response.get("Degree of Injury", "Not specified")
 
 
 
162
 
163
  # Format hazards for display
164
- hazards_display = ", ".join(hazards) if isinstance(hazards, list) else str(hazards)
165
-
 
 
 
166
  # Create formatted output
167
- formatted_output = json.dumps(parsed_response, indent=2, ensure_ascii=False)
 
 
 
 
 
 
 
 
 
 
168
 
169
  return hazards_display, cause, degree, formatted_output, generated_part
170
 
171
  except Exception as e:
172
  error_msg = f"❌ Error generating prediction: {str(e)}"
173
- return error_msg, "", "", "", ""
 
174
 
175
  def create_interface():
176
  """Create the Gradio interface"""
177
- # Custom CSS for better styling
 
178
  css = """
179
  .gradio-container {
180
  font-family: 'Arial', sans-serif;
@@ -209,103 +319,168 @@ def create_interface():
209
  """
210
 
211
  with gr.Blocks(css=css, title="Workplace Safety Risk Predictor") as interface:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
  temperature = gr.Slider(
213
  minimum=0.1,
214
  maximum=1.0,
215
- value=0.7,
216
  step=0.1,
217
  label="Creativity (Temperature)",
218
- info="Higher values = more creative responses"
219
  )
220
 
221
  with gr.Column():
222
  max_length = gr.Slider(
223
  minimum=100,
224
  maximum=500,
225
- value=300,
226
  step=50,
227
  label="Max Response Length",
228
  info="Maximum length of generated response"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
  with gr.Column():
230
  hazards_output = gr.Textbox(
231
  label="🚨 Identified Hazards",
232
- info="Potential hazards identified in the scenario"
233
-
234
  )
235
 
236
  cause_output = gr.Textbox(
237
  label="πŸ” Cause of Accident",
238
- info="Primary cause classification"
239
-
240
  )
241
 
242
  degree_output = gr.Textbox(
243
  label="πŸ“ˆ Degree of Injury",
244
- info="Severity assessment"
245
-
246
  )
247
 
248
  with gr.Accordion("πŸ“‹ Detailed JSON Output", open=False):
 
 
 
 
 
249
  with gr.Accordion("πŸ” Raw Model Output", open=False):
250
  raw_output = gr.Textbox(
251
  label="Raw Response",
252
- lines=3,
253
- info="Unprocessed model output"
254
  )
255
 
256
  # Example scenarios
257
  gr.HTML("<h3>πŸ’‘ Example Scenarios</h3>")
258
 
259
  with gr.Row():
260
- example1 = gr.Button("Power Press Accident")
261
- example2 = gr.Button("Fall from Ladder")
262
- example3 = gr.Button("Chemical Exposure")
263
- example4 = gr.Button("Lifting Injury")
264
 
265
  # Event handlers
266
  predict_btn.click(
 
 
267
  outputs=[hazards_output, cause_output, degree_output, json_output, raw_output]
268
  )
269
 
270
- # Example scenarios
271
  example1.click(
272
- lambda: "an employee was operating a 400 ton mechanical power press. The press was actuated while the employee's right hand was in the point of operation. The employee's fingers were amputated.",
273
  outputs=scenario_input
274
  )
275
 
276
  example2.click(
277
- lambda: "an employee was using a ladder to access high shelves. The ladder was not properly secured and the employee fell from a height of 8 feet, resulting in head injuries.",
278
  outputs=scenario_input
279
  )
280
 
281
  example3.click(
282
- lambda: "an employee was working with chemical solvents without proper ventilation. The employee inhaled toxic fumes and experienced respiratory problems.",
283
  outputs=scenario_input
284
  )
285
 
286
  example4.click(
287
- lambda: "an employee was manually lifting heavy boxes weighing over 50 pounds without proper lifting technique or mechanical aids. The employee strained their back.",
288
  outputs=scenario_input
289
  )
290
 
291
  gr.HTML("""
292
  <div style="text-align: center; margin-top: 30px; color: #666;">
293
  <p>Built with ❀️ using Hugging Face Transformers and Gradio</p>
294
- <p>Model: <a href="https://huggingface.co/FrAnKu34t23/Construction_Mistral_Risk_Prediction_Model_v3">Construction_Mistral_Risk_Prediction_Model_v3</a></p>
295
  </div>
296
  """)
297
 
 
 
 
 
 
 
 
 
 
 
 
 
298
  app.launch(
299
  server_name="0.0.0.0",
300
  server_port=7860,
301
- share=True
302
-
303
  )
304
  else:
305
  print("❌ Failed to load model. App cannot start.")
306
  # Create a simple error interface
307
  with gr.Blocks() as error_app:
308
- gr.HTML("<h1>❌ Model Loading Failed</h1><p>Unable to load the safety prediction model.</p>")
 
 
 
 
 
 
 
 
 
 
309
 
310
  if __name__ == "__main__":
311
  error_app.launch()
 
63
 
64
  return formatted_prompt
65
 
66
# Bare JSON scalars that must NOT be wrapped in quotes when repairing arrays.
_JSON_LITERAL_RE = re.compile(
    r'-?(?:0|[1-9]\d*)(?:\.\d+)?(?:[eE][+-]?\d+)?|true|false|null'
)


def _split_array_items(content):
    """Split array content on commas that lie outside double-quoted strings."""
    items = []
    current = []
    in_string = False
    escaped = False
    for ch in content:
        if in_string:
            current.append(ch)
            if escaped:
                escaped = False
            elif ch == '\\':
                escaped = True
            elif ch == '"':
                in_string = False
        elif ch == '"':
            in_string = True
            current.append(ch)
        elif ch == ',':
            items.append(''.join(current))
            current = []
        else:
            current.append(ch)
    items.append(''.join(current))
    return items


def _rebuild_array(match):
    """Return the matched flat array with every bare item turned into a JSON string."""
    content = match.group(1)
    if '{' in content or '}' in content:
        # Arrays of objects cannot be repaired item-by-item; leave untouched.
        return match.group(0)

    fixed = []
    for raw_item in _split_array_items(content):
        item = raw_item.strip()
        if not item:
            continue
        already_quoted = len(item) >= 2 and item[0] == '"' and item[-1] == '"'
        if already_quoted or _JSON_LITERAL_RE.fullmatch(item):
            fixed.append(item)
            continue
        if len(item) >= 2 and item[0] == "'" and item[-1] == "'":
            # Single-quoted item: drop the apostrophes instead of embedding them.
            item = item[1:-1]
        item = item.strip('"')
        escaped_item = item.replace('\\', '\\\\').replace('"', '\\"')
        fixed.append('"' + escaped_item + '"')
    return '[' + ', '.join(fixed) + ']'


def clean_json_string(text):
    """Repair common formatting problems in JSON-like model output.

    The repairs are applied in order: single-quoted keys and string values
    become double-quoted, bare identifier keys get quoted, trailing commas
    before ``}`` or ``]`` are dropped, and items of flat arrays are quoted.

    Unlike a plain ``split(',')``, the array repair respects commas inside
    double-quoted strings, keeps numbers / ``true`` / ``false`` / ``null``
    unquoted, unwraps single-quoted items, skips empty items, and leaves
    arrays that contain objects alone.

    Args:
        text: Raw JSON-like string.

    Returns:
        The repaired string. It is not guaranteed to be valid JSON; callers
        should still guard ``json.loads``.
    """
    text = text.strip()

    # Single quotes -> double quotes for keys and simple string values.
    text = re.sub(r"'([^']*)':", r'"\1":', text)
    text = re.sub(r":\s*'([^']*)'", r': "\1"', text)

    # Quote bare identifier keys such as {Hazards: ...}.
    text = re.sub(r'([{,]\s*)([a-zA-Z_][a-zA-Z0-9_]*)\s*:', r'\1"\2":', text)

    # Drop trailing commas before a closing brace/bracket.
    text = re.sub(r',(\s*[}\]])', r'\1', text)

    # Normalise the items of flat (non-nested) arrays.
    return re.sub(r'\[\s*([^[\]]+)\s*\]', _rebuild_array, text)
86
+
87
# Patterns that capture an explicit "Hazards: ..." field (group 1 = raw value).
_HAZARD_FIELD_PATTERNS = (
    r"Hazards[\"']?\s*:\s*\[([^\]]+)\]",
    r"Hazards[\"']?\s*:\s*([^,}]+)",
)
# Known hazard labels searched for when no explicit field is present.
_HAZARD_KEYWORD_PATTERN = (
    r"\b(?:MATERIAL HANDLING|FALL PROTECTION|VALVE EXPLOSION|NOMA)\b"
)
_CAUSE_PATTERNS = (
    r"Cause of Accident[\"']?\s*:\s*[\"']([^\"']+)[\"']",
    r"Other caused by ([^,}]+)",
    r"Cause[\"']?\s*:\s*[\"']([^\"']+)[\"']",
)
_DEGREE_PATTERNS = (
    r"Degree of Injury[\"']?\s*:\s*[\"']([^\"']+)[\"']",
    r"\b(?:High|Medium|Low|Severe|Minor|Fatal)\b",
    r"Injury[\"']?\s*:\s*[\"']([^\"']+)[\"']",
)


def _first_match_value(patterns, text):
    """Return the captured (or whole) text of the first pattern that matches, else None."""
    for pattern in patterns:
        match = re.search(pattern, text, re.IGNORECASE)
        if match:
            # Keyword patterns have no capture group; fall back to the full match
            # instead of raising IndexError on group(1).
            value = match.group(1) if match.re.groups else match.group(0)
            return value.strip()
    return None


def _extract_hazards(text):
    """Return hazard labels from an explicit field, or from known keywords."""
    for pattern in _HAZARD_FIELD_PATTERNS:
        match = re.search(pattern, text, re.IGNORECASE)
        if match:
            labels = (
                part.strip().strip("[]\"' ")
                for part in match.group(1).split(',')
            )
            # "NOMA" is filtered out exactly as before, but case-insensitively.
            return [h for h in labels if h and h.upper() != 'NOMA']

    found = []
    for hit in re.findall(_HAZARD_KEYWORD_PATTERN, text, re.IGNORECASE):
        label = hit.upper()
        if label != 'NOMA' and label not in found:
            found.append(label)
    return found


def extract_structured_info(text):
    """Extract hazards, cause and injury degree from free text via regexes.

    Used as a fallback when the model output cannot be parsed as JSON.

    Args:
        text: Raw model output. Empty or ``None`` yields the defaults.

    Returns:
        A dict with the keys ``"Hazards"`` (list of str),
        ``"Cause of Accident"`` (str) and ``"Degree of Injury"`` (str).
        Fields that cannot be found keep ``[]`` / ``"Not specified"``.
        Hazards found by keyword (rather than in a ``Hazards:`` field) are
        reported upper-cased and de-duplicated.
    """
    result = {
        "Hazards": [],
        "Cause of Accident": "Not specified",
        "Degree of Injury": "Not specified"
    }
    if not text:
        return result

    result["Hazards"] = _extract_hazards(text)

    cause = _first_match_value(_CAUSE_PATTERNS, text)
    if cause is not None:
        result["Cause of Accident"] = cause

    degree = _first_match_value(_DEGREE_PATTERNS, text)
    if degree is not None:
        result["Degree of Injury"] = degree

    return result
138
 
def _load_json_object(candidate):
    """Parse *candidate* as a JSON object, retrying once after cleanup.

    Returns the parsed dict, or ``None`` when neither the raw nor the cleaned
    text is a valid JSON object.
    """
    try:
        parsed = json.loads(candidate)
    except json.JSONDecodeError:
        try:
            parsed = json.loads(clean_json_string(candidate))
        except json.JSONDecodeError:
            return None
    return parsed if isinstance(parsed, dict) else None


def parse_json_response(response_text):
    """Extract and parse the JSON object contained in a model response.

    Strategy, in order:
      1. If the whole (stripped) response looks like ``{...}``, parse it.
      2. Otherwise scan for ``{...}`` spans and parse the first one that works.
      3. Fall back to regex-based field extraction.

    Every candidate is parsed as-is first and only passed through
    ``clean_json_string`` if that fails, so already-valid JSON is never
    altered by the cleanup heuristics.

    Args:
        response_text: Raw text produced by the model.

    Returns:
        A dict. On the regex fallback it additionally carries
        ``"raw_response"`` and ``"parsing_method"``; on an unexpected error it
        carries ``"raw_response"`` and ``"error"``.
    """
    try:
        cleaned_text = response_text.strip()

        # First, try to parse the entire response as JSON.
        if cleaned_text.startswith('{') and cleaned_text.endswith('}'):
            parsed = _load_json_object(cleaned_text)
            if parsed is not None:
                return parsed

        # Look for JSON-like spans: one level of nesting first, then the
        # shortest brace-delimited span as a last attempt.
        json_patterns = [
            r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}',
            r'\{.*?\}',
        ]
        for pattern in json_patterns:
            for candidate in re.findall(pattern, response_text, re.DOTALL):
                parsed = _load_json_object(candidate)
                if parsed is not None:
                    return parsed

        # If JSON parsing completely fails, extract structured info.
        structured_info = extract_structured_info(response_text)
        structured_info["raw_response"] = response_text
        structured_info["parsing_method"] = "regex_extraction"
        return structured_info

    except Exception as e:
        # Last resort: return a basic structure with the error attached so the
        # UI can still render something.
        return {
            "Hazards": ["Parsing failed - check raw output"],
            "Cause of Accident": f"Error: {str(e)[:100]}...",
            "Degree of Injury": "Unknown",
            "raw_response": response_text,
            "error": str(e)
        }
188
 
def generate_prediction(scenario_text, max_length=300, temperature=0.7):
    """Generate a workplace safety prediction for a scenario description.

    Args:
        scenario_text: Free-text description of the workplace incident.
        max_length: Maximum number of tokens to generate beyond the prompt.
        temperature: Sampling temperature; values below 0.3 are raised to 0.3.

    Returns:
        A 5-tuple of strings
        ``(hazards_display, cause, degree, formatted_json, raw_output)``.
        On failure the first element holds an error message and the others
        are empty, except that a generation error puts the exception text in
        the last element.
    """
    if model is None or tokenizer is None:
        return "❌ Model not loaded. Please wait for initialization.", "", "", "", ""

    # Guard against None as well as blank input.
    if not scenario_text or not scenario_text.strip():
        return "❌ Please enter a workplace scenario to analyze.", "", "", "", ""

    try:
        # Format the input
        formatted_prompt = format_input(scenario_text)

        # Tokenize
        inputs = tokenizer(
            formatted_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=False
        )

        # Move to same device as model
        device = next(model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}
        prompt_ids = inputs['input_ids'][0]
        prompt_length = prompt_ids.shape[0]

        # Fall back to EOS when the tokenizer defines no pad token.
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = tokenizer.eos_token_id

        # Generate response with conservative sampling for better JSON.
        # max_new_tokens expresses the same budget as prompt length + max_length;
        # early_stopping is omitted because it is a beam-search option.
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=int(max_length),
                temperature=max(0.3, float(temperature)),
                do_sample=True,
                top_p=0.8,
                top_k=40,
                pad_token_id=pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                num_return_sequences=1,
                repetition_penalty=1.2
            )

        # Keep only the newly generated tokens. Comparing token ids is more
        # reliable than comparing decoded strings, which need not round-trip.
        sequence = outputs[0]
        echoes_prompt = (
            sequence.shape[0] >= prompt_length
            and torch.equal(sequence[:prompt_length], prompt_ids)
        )
        new_tokens = sequence[prompt_length:] if echoes_prompt else sequence
        generated_part = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

        # Clean up common artifacts
        generated_part = re.sub(r'^[,\s]*', '', generated_part)  # leading commas/spaces
        generated_part = re.sub(r'<[^>]*>', '', generated_part)  # HTML-like tags

        # Parse the JSON response
        parsed_response = parse_json_response(generated_part)

        # Extract individual components with display defaults
        hazards = parsed_response.get("Hazards", [])
        if not hazards:
            hazards = ["No specific hazards identified"]

        cause = parsed_response.get("Cause of Accident", "Analysis incomplete")
        degree = parsed_response.get("Degree of Injury", "Assessment needed")

        # Format hazards for display
        if isinstance(hazards, list):
            hazards_display = ", ".join(str(h) for h in hazards if h)
        else:
            hazards_display = str(hazards)

        # Create formatted output
        display_response = {
            "Hazards": hazards,
            "Cause of Accident": cause,
            "Degree of Injury": degree
        }

        # Add metadata if available
        if "parsing_method" in parsed_response:
            display_response["Parsing Method"] = parsed_response["parsing_method"]

        formatted_output = json.dumps(display_response, indent=2, ensure_ascii=False)

        return hazards_display, cause, degree, formatted_output, generated_part

    except Exception as e:
        error_msg = f"❌ Error generating prediction: {str(e)}"
        print(f"Generation error: {e}")  # For debugging
        return error_msg, "", "", "", str(e)
283
 
284
  def create_interface():
285
  """Create the Gradio interface"""
286
+
287
+ # Custom CSS for better styling
288
  css = """
289
  .gradio-container {
290
  font-family: 'Arial', sans-serif;
 
319
  """
320
 
321
  with gr.Blocks(css=css, title="Workplace Safety Risk Predictor") as interface:
322
+
323
+ gr.HTML("""
324
+ <div class="header">
325
+ <h1>🚧 Workplace Safety Risk Prediction Model</h1>
326
+ <p>Analyze workplace scenarios to identify potential hazards, causes, and injury severity</p>
327
+ </div>
328
+ """)
329
+
330
+ with gr.Row():
331
+ with gr.Column(scale=2):
332
+ gr.HTML("<h3>πŸ“ Enter Workplace Scenario</h3>")
333
+
334
+ scenario_input = gr.Textbox(
335
+ lines=5,
336
+ placeholder="Example: an employee was operating a 400 ton mechanical power press. The press was actuated while the employee's right hand was in the point of operation...",
337
+ label="Workplace Incident Description",
338
+ info="Describe the workplace scenario you want to analyze"
339
+ )
340
+
341
+ with gr.Row():
342
+ with gr.Column():
343
  temperature = gr.Slider(
344
  minimum=0.1,
345
  maximum=1.0,
346
+ value=0.5, # Lower default for more consistent output
347
  step=0.1,
348
  label="Creativity (Temperature)",
349
+ info="Lower values = more consistent responses"
350
  )
351
 
352
  with gr.Column():
353
  max_length = gr.Slider(
354
  minimum=100,
355
  maximum=500,
356
+ value=250, # Slightly lower default
357
  step=50,
358
  label="Max Response Length",
359
  info="Maximum length of generated response"
360
+ )
361
+
362
+ predict_btn = gr.Button("πŸ” Analyze Scenario", variant="primary", size="lg")
363
+
364
+ gr.HTML("""
365
+ <div class="warning-box">
366
+ <strong>⚠️ Note:</strong> This is an AI model for educational purposes.
367
+ Always consult safety professionals for real workplace safety assessments.
368
+ </div>
369
+ """)
370
+
371
+ with gr.Column(scale=2):
372
+ gr.HTML("<h3>πŸ“Š Analysis Results</h3>")
373
+
374
+ with gr.Row():
375
  with gr.Column():
376
  hazards_output = gr.Textbox(
377
  label="🚨 Identified Hazards",
378
+ info="Potential hazards identified in the scenario",
379
+ interactive=False
380
  )
381
 
382
  cause_output = gr.Textbox(
383
  label="πŸ” Cause of Accident",
384
+ info="Primary cause classification",
385
+ interactive=False
386
  )
387
 
388
  degree_output = gr.Textbox(
389
  label="πŸ“ˆ Degree of Injury",
390
+ info="Severity assessment",
391
+ interactive=False
392
  )
393
 
394
  with gr.Accordion("πŸ“‹ Detailed JSON Output", open=False):
395
+ json_output = gr.Code(
396
+ label="Structured Response",
397
+ language="json"
398
+ )
399
+
400
  with gr.Accordion("πŸ” Raw Model Output", open=False):
401
  raw_output = gr.Textbox(
402
  label="Raw Response",
403
+ lines=5,
404
+ info="Unprocessed model output for debugging"
405
  )
406
 
407
  # Example scenarios
408
  gr.HTML("<h3>πŸ’‘ Example Scenarios</h3>")
409
 
410
  with gr.Row():
411
+ example1 = gr.Button("Power Press Accident", size="sm")
412
+ example2 = gr.Button("Fall from Ladder", size="sm")
413
+ example3 = gr.Button("Chemical Exposure", size="sm")
414
+ example4 = gr.Button("Lifting Injury", size="sm")
415
 
416
  # Event handlers
417
  predict_btn.click(
418
+ fn=generate_prediction,
419
+ inputs=[scenario_input, max_length, temperature],
420
  outputs=[hazards_output, cause_output, degree_output, json_output, raw_output]
421
  )
422
 
423
+ # Example scenarios with better formatting
424
  example1.click(
425
+ lambda: "An employee was operating a 400 ton mechanical power press. The press was actuated while the employee's right hand was in the point of operation. The employee's fingers were amputated.",
426
  outputs=scenario_input
427
  )
428
 
429
  example2.click(
430
+ lambda: "An employee was using a ladder to access high shelves. The ladder was not properly secured and the employee fell from a height of 8 feet, resulting in head injuries.",
431
  outputs=scenario_input
432
  )
433
 
434
  example3.click(
435
+ lambda: "An employee was working with chemical solvents without proper ventilation. The employee inhaled toxic fumes and experienced respiratory problems.",
436
  outputs=scenario_input
437
  )
438
 
439
  example4.click(
440
+ lambda: "An employee was manually lifting heavy boxes weighing over 50 pounds without proper lifting technique or mechanical aids. The employee strained their back.",
441
  outputs=scenario_input
442
  )
443
 
444
  gr.HTML("""
445
  <div style="text-align: center; margin-top: 30px; color: #666;">
446
  <p>Built with ❀️ using Hugging Face Transformers and Gradio</p>
447
+ <p>Model: <a href="https://huggingface.co/FrAnKu34t23/Construction_Mistral_Risk_Prediction_Model_v3" target="_blank">Construction_Mistral_Risk_Prediction_Model_v3</a></p>
448
  </div>
449
  """)
450
 
451
+ return interface
452
+
453
# Initialize the model when the app starts
print("🚀 Initializing Workplace Safety Risk Prediction App...")
model_loaded = load_model()

if model_loaded:
    print("✅ App ready!")
    # Create and launch the interface
    app = create_interface()

    if __name__ == "__main__":
        app.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=True,
            show_error=True  # Better error display
        )
else:
    print("❌ Failed to load model. App cannot start.")
    # Create a simple error interface
    with gr.Blocks() as error_app:
        gr.HTML("""
        <div class="error-box">
        <h1>❌ Model Loading Failed</h1>
        <p>Unable to load the safety prediction model. Please check:</p>
        <ul>
        <li>Internet connection for model download</li>
        <li>Available system memory</li>
        <li>Model repository accessibility</li>
        </ul>
        </div>
        """)

    if __name__ == "__main__":
        # Serve the error page on the same host/port as the real app so it is
        # reachable wherever the app itself would have been.
        error_app.launch(
            server_name="0.0.0.0",
            server_port=7860
        )