FrAnKu34t23 committed on
Commit
23f27a5
·
verified ·
1 Parent(s): 74806e4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +112 -235
app.py CHANGED
@@ -2,28 +2,28 @@ import gradio as gr
2
  import torch
3
  import re
4
  import traceback
5
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
6
- from peft import PeftModel
7
- import ast
8
  import json
9
  import warnings
10
  warnings.filterwarnings("ignore")
11
  import os
12
- os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
 
 
13
 
14
  # Configuration
 
 
 
15
  BASE_MODEL_ID = "distilgpt2"
16
- LORA_MODEL_PATH = "FrAnKu34t23/Construction_Risk_Prediction_Model_v2"
17
 
18
- model = None
19
- tokenizer = None
20
 
21
- # Load once at startup
22
  injury_classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
23
 
24
  def classify_injury_zero_shot(description):
25
  candidate_labels = [
26
- "Low severity injury (minor discomfort or bruise) or unrelevan cases",
27
  "Medium severity injury (sprain, strain, moderate pain)",
28
  "High severity injury (fracture, major trauma, amputation, fatal)"
29
  ]
@@ -32,28 +32,24 @@ def classify_injury_zero_shot(description):
32
  candidate_labels[1]: "Medium",
33
  candidate_labels[2]: "High"
34
  }
35
-
36
  result = injury_classifier(description, candidate_labels)
37
- top_label = result["labels"][0]
38
-
39
- for label, score in zip(result['labels'], result['scores']):
40
- print(f"{label}: {score:.2f}")
41
-
42
- return label_mapping[top_label]
43
-
44
- def load_model():
45
- global model, tokenizer
46
  try:
47
- print("πŸ”„ Loading base model and tokenizer...")
48
- tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
49
- tokenizer.pad_token = tokenizer.eos_token
50
- base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID)
51
- model = PeftModel.from_pretrained(base_model, LORA_MODEL_PATH)
52
- model.eval()
53
- print("βœ… Model loaded successfully!")
 
 
54
  return True
55
  except Exception as e:
56
- print(f"❌ Error loading model: {e}")
57
  return False
58
 
59
  def format_input(scenario_text):
@@ -62,113 +58,17 @@ def format_input(scenario_text):
62
  scenario = ", " + scenario.lstrip(", ")
63
  return f"Based on the situation, predict potential hazards and injuries. {scenario}<|endoftext|>"
64
 
65
- def clean_raw_json_string(raw_text):
66
- """Clean malformed quotes and characters from model output before parsing."""
67
- # Normalize bad quotes
68
- cleaned = raw_text.replace("β€˜", "'").replace("’", "'")
69
- cleaned = cleaned.replace("β€œ", '"').replace("”", '"')
70
- cleaned = cleaned.replace("''", '"').replace("``", '"').replace("†", "")
71
-
72
- # Fix common errors: smart quotes, double single quotes, etc.
73
- cleaned = re.sub(r'([{\[,])\s*"', r'\1 "', cleaned)
74
- cleaned = re.sub(r'"\s*([}\],])', r'" \1', cleaned)
75
-
76
- return cleaned
77
-
78
- def extract_json_object(text):
79
- """Extract and parse the first valid JSON object from text, including malformed hazard list recovery."""
80
- pattern = r'\{(?:[^{}]|"[^"]*")*\}'
81
- matches = re.findall(pattern, text, re.DOTALL)
82
-
83
- for match in matches:
84
- try:
85
- cleaned = clean_raw_json_string(match)
86
-
87
- # Detect and collect any ["..."] list fragments (typically malformed hazards)
88
- hazard_items = re.findall(r'\["([^"]+)"\]', cleaned)
89
-
90
- # Remove malformed hazard list fragments like: ["Hazards"], ["Chemicals"]
91
- cleaned = re.sub(r'(\["[^"]+"\]\s*,?\s*)+', '', cleaned)
92
-
93
- # If Hazards key is missing and we collected items, add it
94
- if hazard_items and "Hazards" not in cleaned:
95
- cleaned = cleaned.rstrip('} \n\t,')
96
- cleaned += ', "Hazards": ' + json.dumps(hazard_items) + '}'
97
-
98
- # Attempt to parse
99
- parsed = json.loads(cleaned)
100
- if isinstance(parsed, dict):
101
- return parsed
102
- except Exception as e:
103
- print(f"⚠️ extract_json_object failed: {e}")
104
- continue
105
- return None
106
-
107
- def extract_fields(text):
108
- def clean_text(t):
109
- t = t.replace("β€˜", "'").replace("’", "'").replace("β€œ", '"').replace("”", '"')
110
- t = t.replace("''", '"').replace("``", '"').replace("†", "").replace("Β΄", "")
111
- t = re.sub(r"[^\x00-\x7F]+", "", t)
112
- return t
113
-
114
- cleaned = clean_text(text)
115
-
116
- cause = "Unknown"
117
- injury = "Unknown"
118
- hazards = []
119
-
120
- # Extract cause
121
- match = re.search(r'"?Cause of Accident"?\s*:\s*"([^"]+)"', cleaned, re.IGNORECASE)
122
- if match:
123
- cause = match.group(1).strip()
124
-
125
- # Use zero-shot classifier always for injury
126
- try:
127
- injury = classify_injury_zero_shot(cleaned)
128
- except:
129
- injury = "Unknown"
130
-
131
- # Extract Hazards
132
- match = re.search(r'"?Hazards"?\s*:\s*(\[[^\]]+\])', cleaned, re.IGNORECASE)
133
- if match:
134
- try:
135
- hazards_raw = clean_text(match.group(1))
136
- if not hazards_raw.strip().startswith("["):
137
- raise ValueError("Not a list")
138
- hazards = ast.literal_eval(hazards_raw)
139
- hazards = [str(h).strip().strip('"').strip("'") for h in hazards]
140
- except Exception as e:
141
- print("⚠️ Hazard parsing failed:", e)
142
- hazards = []
143
-
144
- structured = {
145
- "Hazards": hazards,
146
- "Cause of Accident": cause,
147
- "Degree of Injury": injury
148
- }
149
-
150
- return hazards, cause, injury, json.dumps(structured, indent=2)
151
-
152
- def generate_prediction(scenario_text, max_length=300, temperature=0.7):
153
- global model, tokenizer
154
- if model is None or tokenizer is None:
155
- return "❌ Model not loaded.", "", "", "", ""
156
-
157
- if not scenario_text.strip():
158
- return "❌ Please enter a scenario.", "", "", "", ""
159
-
160
- try:
161
- prompt = format_input(scenario_text)
162
  inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
163
- device = next(model.parameters()).device
164
- inputs = {k: v.to(device) for k, v in inputs.items()}
165
 
166
  with torch.no_grad():
167
  output = model.generate(
168
  **inputs,
169
  max_length=inputs["input_ids"].shape[1] + max_length,
170
  temperature=temperature,
171
- do_sample=True,
172
  top_p=0.9,
173
  top_k=50,
174
  repetition_penalty=1.1,
@@ -176,130 +76,107 @@ def generate_prediction(scenario_text, max_length=300, temperature=0.7):
176
  eos_token_id=tokenizer.eos_token_id
177
  )
178
 
179
- full_output = tokenizer.decode(output[0], skip_special_tokens=True)
180
- index = full_output.rfind("Based on the situation")
181
- generated = full_output[index:].strip() if index != -1 else full_output.strip()
182
-
183
- json_obj = extract_json_object(generated)
184
- if json_obj:
185
- cause = json_obj.get("Cause of Accident", "Unknown")
186
- injury = json_obj.get("Degree of Injury", "Unknown")
187
- hazards = json_obj.get("Hazards", [])
188
- structured_json = json.dumps(json_obj, indent=2)
189
- else:
190
- hazards, cause, injury, structured_json = extract_fields(generated)
191
-
192
- hazards_display = ", ".join(hazards) if isinstance(hazards, list) else str(hazards)
193
- return hazards_display, cause, injury, structured_json, f"=== RAW RESPONSE START ===\n{generated}\n=== RAW RESPONSE END ==="
194
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  except Exception as e:
196
- return f"❌ Error during prediction: {str(e)}", "", "", "", traceback.format_exc()
197
 
198
  def create_interface():
199
- css = """
200
- .gradio-container {
201
- font-family: 'Arial', sans-serif;
202
- }
203
- .header {
204
- text-align: center;
205
- margin-bottom: 30px;
206
- }
207
- .warning-box {
208
- background-color: #fff3cd;
209
- border: 1px solid #ffeaa7;
210
- border-radius: 5px;
211
- padding: 15px;
212
- margin: 10px 0;
213
- }
214
- .error-box {
215
- background-color: #f8d7da;
216
- border: 1px solid #f5c6cb;
217
- border-radius: 5px;
218
- padding: 15px;
219
- margin: 10px 0;
220
- color: #721c24;
221
- }
222
- """
223
-
224
- with gr.Blocks(css=css, title="Workplace Safety Risk Predictor") as interface:
225
  gr.HTML("""
226
- <div class="header">
227
- <h1>🚧 Workplace Safety Risk Prediction Model</h1>
228
- <p>Analyze workplace scenarios to identify potential hazards, causes, and injury severity</p>
229
- </div>
230
  """)
231
 
232
  with gr.Row():
233
- with gr.Column(scale=2):
234
- scenario_input = gr.Textbox(
235
- lines=5,
236
- placeholder="e.g. During welding, flammable gas ignited, causing explosion...",
237
- label="Workplace Incident Description"
238
- )
239
- temperature = gr.Slider(0.1, 1.0, value=0.7, step=0.1, label="Creativity (Temperature)")
240
- max_length = gr.Slider(100, 500, value=300, step=50, label="Max Response Length")
241
- predict_btn = gr.Button("πŸ” Analyze Scenario", variant="primary")
242
-
243
- gr.HTML("""
244
- <div class="warning-box">
245
- <strong>⚠️ Note:</strong> This tool is experimental. Consult safety experts for actual workplace assessments.
246
- </div>
247
- """)
248
-
249
- with gr.Column(scale=2):
250
- hazards_output = gr.Textbox(label="🚨 Identified Hazards")
251
- cause_output = gr.Textbox(label="πŸ” Cause of Accident")
252
  degree_output = gr.Textbox(label="πŸ“ˆ Degree of Injury")
253
-
254
- with gr.Accordion("πŸ“‹ Structured Output", open=False):
255
- json_output = gr.Code(label="Extracted Info", language="json")
256
-
257
- with gr.Accordion("πŸ” Raw Model Output", open=False):
258
- raw_output = gr.Textbox(label="Raw Text", lines=5)
259
-
260
- # Example Buttons
261
- gr.HTML("<h3>πŸ’‘ Example Scenarios</h3>")
262
- with gr.Row():
263
- example1 = gr.Button("Power Press Accident")
264
- example2 = gr.Button("Fall from Ladder")
265
- example3 = gr.Button("Chemical Exposure")
266
- example4 = gr.Button("Lifting Injury")
267
 
268
  predict_btn.click(
269
- fn=generate_prediction,
270
- inputs=[scenario_input, max_length, temperature],
271
- outputs=[hazards_output, cause_output, degree_output, json_output, raw_output]
272
  )
273
 
274
- example1.click(
275
- lambda: "An employee was operating a 400 ton mechanical power press. The press was actuated while the employee's right hand was in the point of operation. The employee's fingers were amputated.",
276
- outputs=scenario_input
277
- )
278
- example2.click(
279
- lambda: "An employee was using a ladder to access high shelves. The ladder was not properly secured and the employee fell from a height of 8 feet, resulting in head injuries.",
280
- outputs=scenario_input
281
- )
282
- example3.click(
283
- lambda: "An employee was working with chemical solvents without proper ventilation. The employee inhaled toxic fumes and experienced respiratory problems.",
284
- outputs=scenario_input
285
- )
286
- example4.click(
287
- lambda: "An employee was manually lifting heavy boxes weighing over 50 pounds without proper lifting technique or mechanical aids. The employee strained their back.",
288
- outputs=scenario_input
289
- )
290
-
291
- gr.HTML("""
292
- <div style="text-align: center; margin-top: 30px; color: #666;">
293
- <p>Built with ❀️ using Hugging Face Transformers and Gradio</p>
294
- </div>
295
- """)
296
 
297
  return interface
298
 
299
- print("πŸš€ Launching App...")
300
- if load_model():
301
  app = create_interface()
302
  if __name__ == "__main__":
303
  app.launch(server_name="0.0.0.0", server_port=7860, share=True)
304
  else:
305
- print("❌ Could not load model.")
 
2
  import torch
3
  import re
4
  import traceback
 
 
 
5
  import json
6
  import warnings
7
  warnings.filterwarnings("ignore")
8
  import os
9
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
10
+ from peft import PeftModel
11
+ from google import genai
12
 
13
  # Configuration
14
+ MODEL_PATHS = [
15
+ "FrAnKu34t23/Construction_Risk_Prediction_Model_v3"
16
+ ]
17
  BASE_MODEL_ID = "distilgpt2"
 
18
 
19
+ models = []
20
+ tokenizers = []
21
 
 
22
  injury_classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
23
 
24
  def classify_injury_zero_shot(description):
25
  candidate_labels = [
26
+ "Low severity injury (minor discomfort or bruise) or unrelevant cases",
27
  "Medium severity injury (sprain, strain, moderate pain)",
28
  "High severity injury (fracture, major trauma, amputation, fatal)"
29
  ]
 
32
  candidate_labels[1]: "Medium",
33
  candidate_labels[2]: "High"
34
  }
 
35
  result = injury_classifier(description, candidate_labels)
36
+ return label_mapping[result['labels'][0]]
37
+
38
+ def load_models():
39
+ global models, tokenizers
 
 
 
 
 
40
  try:
41
+ for path in MODEL_PATHS:
42
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
43
+ tokenizer.pad_token = tokenizer.eos_token
44
+ base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID)
45
+ model = PeftModel.from_pretrained(base_model, path)
46
+ model.eval()
47
+ models.append(model)
48
+ tokenizers.append(tokenizer)
49
+ print("βœ… All models loaded.")
50
  return True
51
  except Exception as e:
52
+ print(f"❌ Model loading failed: {e}")
53
  return False
54
 
55
  def format_input(scenario_text):
 
58
  scenario = ", " + scenario.lstrip(", ")
59
  return f"Based on the situation, predict potential hazards and injuries. {scenario}<|endoftext|>"
60
 
61
+ def generate_all_model_outputs(prompt, max_length=300, temperature=0.7):
62
+ outputs = []
63
+ for model, tokenizer in zip(models, tokenizers):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
65
+ inputs = {k: v.to(next(model.parameters()).device) for k, v in inputs.items()}
 
66
 
67
  with torch.no_grad():
68
  output = model.generate(
69
  **inputs,
70
  max_length=inputs["input_ids"].shape[1] + max_length,
71
  temperature=temperature,
 
72
  top_p=0.9,
73
  top_k=50,
74
  repetition_penalty=1.1,
 
76
  eos_token_id=tokenizer.eos_token_id
77
  )
78
 
79
+ decoded = tokenizer.decode(output[0], skip_special_tokens=True)
80
+ outputs.append(f"=== RAW RESPONSE START ===\n{decoded}\n=== RAW RESPONSE END ===")
81
+ return outputs
82
+
83
+ def extract_scenario_from_prompt(prompt):
84
+ try:
85
+ return re.sub(r"^.*predict potential hazards and injuries\.\s*", "", prompt)
86
+ except:
87
+ return prompt
88
+
89
+ def call_gemini_pro(raw_outputs, zero_shot_injury):
90
+ client = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))
91
+ chat = client.chats.create(model="gemini-2.0-flash")
92
+
93
+ prompt = f"""
94
+ You are a workplace safety analyst. Below are raw text outputs from three different AI models analyzing the same construction scenario.
95
+
96
+ Your tasks:
97
+ - Compare and merge the model outputs.
98
+ - Summarize the most plausible cause of accident in natural language.
99
+ - Infer the degree of injury by considering all outputs and a classifier suggestion.
100
+
101
+ Classifier prediction for Degree of Injury: {zero_shot_injury}
102
+
103
+ Model Outputs:
104
+ {raw_outputs[0]}
105
+
106
+ {raw_outputs[1]}
107
+
108
+ {raw_outputs[2]}
109
+
110
+ Respond in this format:
111
+ Cause of Accident: <sentence>
112
+ Degree of Injury: <Low / Medium / High>
113
+ """
114
+ try:
115
+ response = chat.send_message(prompt)
116
+ return response.text.strip()
117
+ except Exception as e:
118
+ print("❌ Gemini Pro API call failed:", e)
119
+ return "Cause of Accident: Unknown\nDegree of Injury: Unknown"
120
+
121
+ def generate_prediction_ensemble(scenario_text, max_length=300, temperature=0.7):
122
+ if not scenario_text.strip():
123
+ return "❌ Please enter a scenario.", "", ""
124
+
125
+ try:
126
+ prompt = format_input(scenario_text)
127
+ raw_outputs = generate_all_model_outputs(prompt, max_length, temperature)
128
+
129
+ scenario_only = extract_scenario_from_prompt(prompt)
130
+ injury_guess = classify_injury_zero_shot(scenario_only)
131
+
132
+ gemini_response = call_gemini_pro(raw_outputs, injury_guess)
133
+
134
+ match_cause = re.search(r"Cause of Accident\s*:\s*(.+)", gemini_response)
135
+ match_injury = re.search(r"Degree of Injury\s*:\s*(Low|Medium|High)", gemini_response, re.IGNORECASE)
136
+
137
+ cause = match_cause.group(1).strip() if match_cause else "Unknown"
138
+ injury = match_injury.group(1).strip().capitalize() if match_injury else injury_guess
139
+
140
+ combined_raw = "\n\n".join(raw_outputs)
141
+ return cause, injury, combined_raw
142
+
143
  except Exception as e:
144
+ return "❌ Prediction failed.", "", traceback.format_exc()
145
 
146
  def create_interface():
147
+ with gr.Blocks(title="Workplace Safety Risk Predictor") as interface:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  gr.HTML("""
149
+ <h1>🚧 Workplace Safety Risk Prediction Model (Ensemble)</h1>
150
+ <p>Enter a construction scenario to analyze possible risks.</p>
 
 
151
  """)
152
 
153
  with gr.Row():
154
+ with gr.Column():
155
+ scenario_input = gr.Textbox(lines=5, label="Scenario Description")
156
+ temperature = gr.Slider(0.1, 1.0, 0.7, 0.1, label="Creativity (Temperature)")
157
+ max_len = gr.Slider(100, 500, 300, 50, label="Max Response Length")
158
+ predict_btn = gr.Button("πŸ” Analyze")
159
+
160
+ with gr.Column():
161
+ cause_output = gr.Textbox(label="πŸ“ Cause of Accident")
 
 
 
 
 
 
 
 
 
 
 
162
  degree_output = gr.Textbox(label="πŸ“ˆ Degree of Injury")
163
+ with gr.Accordion("πŸ“„ Raw Model Outputs", open=False):
164
+ raw_output = gr.Textbox(label="Raw Responses", lines=12)
 
 
 
 
 
 
 
 
 
 
 
 
165
 
166
  predict_btn.click(
167
+ fn=generate_prediction_ensemble,
168
+ inputs=[scenario_input, max_len, temperature],
169
+ outputs=[cause_output, degree_output, raw_output]
170
  )
171
 
172
+ gr.HTML("""<p style='text-align:center;'>Built with πŸ€– Transformers + Gemini Flash + Gradio</p>""")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
 
174
  return interface
175
 
176
+ print("πŸš€ Starting app...")
177
+ if load_models():
178
  app = create_interface()
179
  if __name__ == "__main__":
180
  app.launch(server_name="0.0.0.0", server_port=7860, share=True)
181
  else:
182
+ print("❌ Failed to load models.")