FrAnKu34t23 commited on
Commit
f9294e1
·
verified ·
1 Parent(s): 8dc3d14

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +334 -67
app.py CHANGED
@@ -1,83 +1,350 @@
1
  import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
4
  import json
5
  import re
 
 
 
 
6
 
7
- # Models dictionary
8
- model_paths = {
9
- "Model A (OSHA LLaMA)": "FrAnKu34t23/Construction_Mistral_Risk_Prediction_Model_v1",
10
- "Model B (OSHA LLaMA2)": "FrAnKu34t23/Construction_Mistral_Risk_Prediction_Model_v2"
11
- }
12
-
13
- def build_prompt(desc):
14
- return f"""Incident Description:
15
- {desc}
16
-
17
- Please extract the following in JSON format:
18
- - Hazards
19
- - Cause of Accident
20
- - Degree of Injury (High, Medium, Low)
21
- - Occupation
22
 
23
- Output must be valid JSON using double quotes.
24
- """
 
25
 
26
- def parse_response(output):
 
 
 
27
  try:
28
- match = re.search(r'\{[\s\S]*?\}', output)
29
- if match:
30
- data = json.loads(match.group().replace("'", '"'))
31
- return json.dumps(data, indent=2)
32
- else:
33
- return "❌ Could not parse JSON from output."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  except Exception as e:
35
- return f"❌ Error parsing output: {e}"
36
-
37
- def generate(incident, model_choice):
38
- print(incident, model_choice)
39
- model_id = model_paths[model_choice]
40
-
41
- # Load model/tokenizer on demand
42
- tokenizer = AutoTokenizer.from_pretrained(model_id)
43
- tokenizer.pad_token = tokenizer.eos_token
44
- model = AutoModelForCausalLM.from_pretrained(
45
- model_id,
46
- device_map=None, # no device map
47
- torch_dtype=torch.float32, # force CPU-friendly dtype
48
- low_cpu_mem_usage=True
49
- ).to("cpu")
50
 
51
- model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
- prompt = build_prompt(incident)
54
- inputs = tokenizer(prompt, return_tensors='pt').to("cpu")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
- with torch.no_grad():
57
- outputs = model.generate(
58
- **inputs,
59
- max_new_tokens=512,
60
- temperature=0.9,
61
- do_sample=True
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
- decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
65
- parsed = parse_response(decoded)
66
- return parsed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
- demo = gr.Interface(
69
- fn=generate,
70
- inputs=[
71
- gr.Textbox(label="Incident Description", lines=6, placeholder="Describe the incident..."),
72
- gr.Dropdown(choices=list(model_paths.keys()),
73
- label="Choose Model",
74
- value=list(model_paths.keys())[0] # βœ… Set default value
75
- )
76
- ],
77
- outputs=gr.Textbox(label="Extracted JSON"),
78
- title="OSHA Risk Analyzer (CPU)",
79
- description="Runs one OSHA model on CPU and extracts hazards, cause, injury level, and occupation."
80
- )
81
 
82
- if __name__ == "__main__":
83
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
 
2
  import torch
3
  import json
4
  import re
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM
6
+ from peft import PeftModel
7
+ import warnings
8
+ warnings.filterwarnings("ignore")
9
 
10
+ # Configuration
11
+ BASE_MODEL_ID = "distilgpt2"
12
+ LORA_MODEL_ID = "FrAnKu34t23/Construction_Mistral_Risk_Prediction_Model_v3"
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
+ # Global variables for model and tokenizer
15
+ model = None
16
+ tokenizer = None
17
 
18
def load_model():
    """Load the base causal-LM and attach the fine-tuned LoRA adapter.

    Populates the module-level ``model`` and ``tokenizer`` globals.

    Returns:
        bool: True when both pieces loaded, False on any failure
        (the error is printed rather than raised so the app can fall
        back to an error screen).
    """
    global model, tokenizer

    try:
        print("🔍 Loading base model and tokenizer...")

        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
        # GPT-2-family tokenizers ship without a pad token; reuse EOS.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        backbone = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL_ID,
            torch_dtype=torch.float32,
            device_map="auto",
            trust_remote_code=True,
        )

        print("📝 Loading LoRA adapter...")
        # Wrap the frozen backbone with the adapter weights from the Hub.
        model = PeftModel.from_pretrained(backbone, LORA_MODEL_ID)
        model.eval()

        print("✅ Model loaded successfully!")
        return True

    except Exception as e:
        print(f"❌ Error loading model: {e}")
        return False
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
def format_input(scenario_text):
    """Normalise a scenario so it matches the training-data prompt shape.

    Training rows all begin with ", ", so the scenario is coerced to that
    prefix before the fixed instruction sentence is prepended.

    Args:
        scenario_text: Free-text incident description from the user.

    Returns:
        str: "<instruction> , <scenario>" ready for tokenisation.
    """
    text = scenario_text.strip()
    if not text.startswith(", "):
        # Reuse an existing bare comma, otherwise add the full ", " prefix.
        text = ", " + (text[1:].strip() if text.startswith(",") else text)

    instruction = "Based on the situation, predict potential hazards and injuries."
    return f"{instruction} {text}"
65
 
66
def parse_json_response(response_text):
    """Extract and parse a JSON object from the model's raw response.

    Tries, in order: the whole (stripped) response as JSON, then each
    brace-delimited candidate found in the text (the pattern tolerates one
    level of nesting).  If nothing parses, a structured fallback dict is
    returned so callers always receive the same keys.

    Args:
        response_text: Raw text generated by the model.

    Returns:
        dict: The parsed JSON object, or a fallback dict with keys
        "Hazards", "Cause of Accident", "Degree of Injury" and
        "raw_response" when parsing fails.
    """
    try:
        stripped = response_text.strip()

        # First, try to parse the entire response as JSON.  On failure we
        # fall through to the pattern search below (previously the
        # JSONDecodeError escaped to the outer handler, so an embedded
        # valid object was never even looked for).
        if stripped.startswith('{') and stripped.endswith('}'):
            try:
                return json.loads(stripped)
            except json.JSONDecodeError:
                pass

        # Look for JSON-shaped spans embedded in the text.
        json_pattern = r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}'
        for candidate in re.findall(json_pattern, response_text, re.DOTALL):
            try:
                return json.loads(candidate)
            except json.JSONDecodeError:  # narrowed from a bare except
                continue

        # No valid JSON found: return a structured error.
        return {
            "Hazards": ["Unable to parse response"],
            "Cause of Accident": "Model output parsing failed",
            "Degree of Injury": "Unknown",
            "raw_response": response_text
        }

    except Exception as e:
        # Defensive catch-all so the UI never sees a raw traceback.
        return {
            "Hazards": [f"Parsing error: {str(e)}"],
            "Cause of Accident": "JSON parsing failed",
            "Degree of Injury": "Unknown",
            "raw_response": response_text
        }
98
 
99
def generate_prediction(scenario_text, max_length=300, temperature=0.7):
    """Run the loaded model on a scenario and return display-ready fields.

    Args:
        scenario_text: Incident description typed by the user.
        max_length: Cap on newly generated tokens (added to the prompt length).
        temperature: Sampling temperature passed straight to ``generate``.

    Returns:
        tuple[str, str, str, str, str]: (hazards, cause, degree,
        pretty-printed JSON, raw completion).  On error the first element
        carries the message and the rest are empty strings.
    """
    global model, tokenizer

    if model is None or tokenizer is None:
        return "❌ Model not loaded. Please wait for initialization.", "", "", "", ""

    try:
        # Build the training-style prompt and terminate it with EOS.
        prompt = format_input(scenario_text) + tokenizer.eos_token

        encoded = tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=False,
        )

        # Keep tensors on whatever device the model actually lives on.
        target_device = next(model.parameters()).device
        encoded = {name: tensor.to(target_device) for name, tensor in encoded.items()}

        prompt_len = len(encoded['input_ids'][0])
        with torch.no_grad():
            generated = model.generate(
                **encoded,
                max_length=prompt_len + max_length,
                temperature=temperature,
                do_sample=True,
                top_p=0.9,
                top_k=50,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                num_return_sequences=1,
                repetition_penalty=1.1,
                early_stopping=True,
            )

        decoded = tokenizer.decode(generated[0], skip_special_tokens=False)
        prompt_text = tokenizer.decode(encoded['input_ids'][0], skip_special_tokens=False)

        # Keep only the completion: drop the echoed prompt, then stray EOS.
        if decoded.startswith(prompt_text):
            completion = decoded[len(prompt_text):].strip()
        else:
            completion = decoded.strip()

        eos = tokenizer.eos_token
        if completion.startswith(eos):
            completion = completion[len(eos):].strip()
        if completion.endswith(eos):
            completion = completion[:-len(eos)].strip()

        # Structure the completion for the individual output widgets.
        parsed = parse_json_response(completion)
        hazards = parsed.get("Hazards", [])
        cause = parsed.get("Cause of Accident", "Not specified")
        degree = parsed.get("Degree of Injury", "Not specified")
        hazards_display = ", ".join(hazards) if isinstance(hazards, list) else str(hazards)
        pretty_json = json.dumps(parsed, indent=2, ensure_ascii=False)

        return hazards_display, cause, degree, pretty_json, completion

    except Exception as e:
        error_msg = f"❌ Error generating prediction: {str(e)}"
        return error_msg, "", "", "", ""
177
 
178
def create_interface():
    """Build and return the Gradio Blocks UI for the safety predictor."""

    # Custom CSS for better styling
    css = """
    .gradio-container {
        font-family: 'Arial', sans-serif;
    }
    .header {
        text-align: center;
        margin-bottom: 30px;
    }
    .warning-box {
        background-color: #fff3cd;
        border: 1px solid #ffeaa7;
        border-radius: 5px;
        padding: 15px;
        margin: 10px 0;
    }
    """

    # Canned scenarios, keyed by the button caption that loads them.
    example_scenarios = {
        "Power Press Accident": "an employee was operating a 400 ton mechanical power press. The press was actuated while the employee's right hand was in the point of operation. The employee's fingers were amputated.",
        "Fall from Ladder": "an employee was using a ladder to access high shelves. The ladder was not properly secured and the employee fell from a height of 8 feet, resulting in head injuries.",
        "Chemical Exposure": "an employee was working with chemical solvents without proper ventilation. The employee inhaled toxic fumes and experienced respiratory problems.",
        "Lifting Injury": "an employee was manually lifting heavy boxes weighing over 50 pounds without proper lifting technique or mechanical aids. The employee strained their back.",
    }

    with gr.Blocks(css=css, title="Workplace Safety Risk Predictor") as interface:

        gr.HTML("""
        <div class="header">
            <h1>🚧 Workplace Safety Risk Prediction Model</h1>
            <p>Analyze workplace scenarios to identify potential hazards, causes, and injury severity</p>
        </div>
        """)

        with gr.Row():
            # Left column: scenario entry plus generation controls.
            with gr.Column(scale=2):
                gr.HTML("<h3>📝 Enter Workplace Scenario</h3>")

                scenario_input = gr.Textbox(
                    lines=5,
                    placeholder="Example: an employee was operating a 400 ton mechanical power press. The press was actuated while the employee's right hand was in the point of operation...",
                    label="Workplace Incident Description",
                    info="Describe the workplace scenario you want to analyze",
                )

                with gr.Row():
                    with gr.Column():
                        temperature = gr.Slider(
                            minimum=0.1,
                            maximum=1.0,
                            value=0.7,
                            step=0.1,
                            label="Creativity (Temperature)",
                            info="Higher values = more creative responses",
                        )
                    with gr.Column():
                        max_length = gr.Slider(
                            minimum=100,
                            maximum=500,
                            value=300,
                            step=50,
                            label="Max Response Length",
                            info="Maximum length of generated response",
                        )

                predict_btn = gr.Button("🔍 Analyze Scenario", variant="primary", size="lg")

                gr.HTML("""
                <div class="warning-box">
                    <strong>⚠️ Note:</strong> This is an AI model for educational purposes.
                    Always consult safety professionals for real workplace safety assessments.
                </div>
                """)

            # Right column: structured prediction results.
            with gr.Column(scale=2):
                gr.HTML("<h3>📊 Analysis Results</h3>")

                with gr.Row():
                    with gr.Column():
                        hazards_output = gr.Textbox(
                            label="🚨 Identified Hazards",
                            info="Potential hazards identified in the scenario",
                        )
                        cause_output = gr.Textbox(
                            label="🔍 Cause of Accident",
                            info="Primary cause classification",
                        )
                        degree_output = gr.Textbox(
                            label="📈 Degree of Injury",
                            info="Severity assessment",
                        )

                with gr.Accordion("📋 Detailed JSON Output", open=False):
                    json_output = gr.Code(
                        label="Structured Response",
                        language="json",
                    )

                with gr.Accordion("🔍 Raw Model Output", open=False):
                    raw_output = gr.Textbox(
                        label="Raw Response",
                        lines=3,
                        info="Unprocessed model output",
                    )

        # Example scenario buttons.
        gr.HTML("<h3>💡 Example Scenarios</h3>")
        with gr.Row():
            example_buttons = [gr.Button(caption) for caption in example_scenarios]

        # Event handlers.
        predict_btn.click(
            fn=generate_prediction,
            inputs=[scenario_input, max_length, temperature],
            outputs=[hazards_output, cause_output, degree_output, json_output, raw_output],
        )
        # Bind each text as a default argument so the lambdas don't all
        # capture the final loop value.
        for button, scenario in zip(example_buttons, example_scenarios.values()):
            button.click(lambda scenario=scenario: scenario, outputs=scenario_input)

        gr.HTML("""
        <div style="text-align: center; margin-top: 30px; color: #666;">
            <p>Built with ❤️ using Hugging Face Transformers and Gradio</p>
            <p>Model: <a href="https://huggingface.co/FrAnKu34t23/Construction_Mistral_Risk_Prediction_Model_v3">Construction_Mistral_Risk_Prediction_Model_v3</a></p>
        </div>
        """)

    return interface
327
 
328
# Initialize the model when the app starts (runs once, at import time).
print("🚀 Initializing Workplace Safety Risk Prediction App...")
model_loaded = load_model()

if model_loaded:
    print("✅ App ready!")
    # Build the full UI only when the model actually loaded.
    app = create_interface()
    if __name__ == "__main__":
        # 0.0.0.0:7860 is the conventional Gradio binding.
        app.launch(server_name="0.0.0.0", server_port=7860, share=True)
else:
    print("❌ Failed to load model. App cannot start.")
    # Minimal fallback UI so something still renders on failure.
    with gr.Blocks() as error_app:
        gr.HTML("<h1>❌ Model Loading Failed</h1><p>Unable to load the safety prediction model.</p>")
    if __name__ == "__main__":
        error_app.launch()