"""Multi-model construction-safety risk predictor (CPU-only).

Runs an ensemble of PEFT-adapted distilgpt2 models over a free-text scenario
description, then uses a CPU-hosted FLAN-T5 model to integrate the individual
JSON outputs into a single cause / injury-severity assessment, served through
a Gradio Blocks UI.
"""

import ast
import json
import re
import traceback  # kept for parity with the original file (currently unused)

import torch
import gradio as gr
from peft import PeftModel
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
)

# === CONFIGURATION ===
MODEL_PATHS = [
    "FrAnKu34t23/Construction_Risk_Prediction_Model_v3",
]
BASE_MODEL_ID = "distilgpt2"
ANALYSIS_MODEL_ID = "google/flan-t5-base"

# === LOAD MODELS ===
# One (tokenizer, model) pair per PEFT adapter; everything stays on CPU.
models, tokenizers = [], []
for path in MODEL_PATHS:
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
    base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID)
    model = PeftModel.from_pretrained(base_model, path).to("cpu").eval()
    models.append(model)
    tokenizers.append(tokenizer)

# === FLAN-T5 FOR CPU ANALYSIS ===
flan_tokenizer = AutoTokenizer.from_pretrained(ANALYSIS_MODEL_ID)
flan_model = AutoModelForSeq2SeqLM.from_pretrained(ANALYSIS_MODEL_ID).to("cpu").eval()


# === FORMAT INPUT ===
def format_input(scenario_text):
    """Build the prompt expected by the fine-tuned models.

    The training format apparently prefixes the scenario with ", " and
    terminates the prompt with the GPT-2 end-of-text token.
    """
    scenario = scenario_text.strip()
    if not scenario.startswith(", "):
        scenario = ", " + scenario.lstrip(", ")
    return (
        "Based on the situation, predict potential hazards and injuries.\n"
        f"{scenario}<|endoftext|>"
    )


# === TEXT CLEANING + JSON EXTRACTION (FALLBACK USE) ===
def clean_raw_json_string(raw_text):
    """Normalize smart quotes and stray characters so a blob parses as JSON."""
    cleaned = raw_text.replace("\u2018", "'").replace("\u2019", "'")
    cleaned = cleaned.replace("\u201c", '"').replace("\u201d", '"')
    cleaned = cleaned.replace("''", '"').replace("`", '"').replace("\u2020", "")
    # FIX: the original character classes contained doubled braces
    # ([{{\[,] / [}}\],]) — redundant duplicates inside [...]; simplified.
    cleaned = re.sub(r'([{\[,])\s*"', r'\1 "', cleaned)
    cleaned = re.sub(r'"\s*([}\],])', r'" \1', cleaned)
    return cleaned


def extract_json_object(text):
    """Return the first {...} blob in *text* that parses to a dict, else None.

    Bare ["..."] arrays floating inside the object are collected and
    re-attached under a "Hazards" key when that key is missing.
    """
    # FIX: pattern now matches extract_json_only (original had [^{{}}]).
    pattern = r'\{(?:[^{}]|"[^"]*")*\}'
    matches = re.findall(pattern, text, re.DOTALL)
    for match in matches:
        try:
            cleaned = clean_raw_json_string(match)
            # Pull out stray one-element string arrays, then strip them.
            hazard_items = re.findall(r'\["([^"]+)"\]', cleaned)
            cleaned = re.sub(r'(\["[^"]+"\]\s*,?\s*)+', '', cleaned)
            if hazard_items and "Hazards" not in cleaned:
                cleaned = cleaned.rstrip('} \n\t,')
                cleaned += ', "Hazards": ' + json.dumps(hazard_items) + '}'
            parsed = json.loads(cleaned)
            if isinstance(parsed, dict):
                return parsed
        except Exception as e:
            # Best-effort parsing: log and try the next candidate blob.
            print(f"⚠️ extract_json_object failed: {e}")
            continue
    return None


def extract_fields(text):
    """Regex-extract hazards, cause and injury degree from raw model text.

    Returns (hazards, cause, injury, structured_json_string); any field that
    cannot be found defaults to "Unknown" / [].
    """

    def clean_text(t):
        t = t.replace("\u2018", "'").replace("\u2019", "'")
        t = t.replace("\u201c", '"').replace("\u201d", '"')
        t = t.replace("''", '"').replace("`", '"').replace("\u2020", "").replace("\u00b4", "")
        # Drop every remaining non-ASCII character before regex matching.
        t = re.sub(r"[^\x00-\x7F]+", "", t)
        return t

    cleaned = clean_text(text)
    cause = "Unknown"
    injury = "Unknown"
    hazards = []

    match = re.search(r'"?Cause of Accident"?\s*:\s*"([^"]+)",?', cleaned, re.IGNORECASE)
    if match:
        cause = match.group(1).strip()

    match = re.search(r'"?Degree of Injury"?\s*:\s*"(Low|Medium|High)"', cleaned, re.IGNORECASE)
    if match:
        injury = match.group(1).capitalize()

    match = re.search(r'"?Hazards"?\s*:\s*(\[[^\]]+\])', cleaned, re.IGNORECASE)
    if match:
        try:
            hazards_raw = clean_text(match.group(1))
            hazards = ast.literal_eval(hazards_raw)
            hazards = [str(h).strip().strip('"').strip("'") for h in hazards]
        except Exception as e:
            print("⚠️ Hazard parsing failed:", e)
            hazards = []

    structured = {
        "Hazards": hazards,
        "Cause of Accident": cause,
        "Degree of Injury": injury,
    }
    return hazards, cause, injury, json.dumps(structured, indent=2)


def extract_json_only(text):
    """Return the first {...} blob found in *text* (raw string), or ""."""
    pattern = r'\{(?:[^{}]|"[^"]*")*\}'
    matches = re.findall(pattern, text, re.DOTALL)
    return matches[0] if matches else ""


# === GENERATION FROM EACH MODEL ===
def generate_single_model_output(model, tokenizer, prompt, max_length=300, temperature=0.7):
    """Sample one completion from a single ensemble member on CPU.

    *max_length* is the number of NEW tokens to generate beyond the prompt.
    """
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to("cpu")
    with torch.no_grad():
        output = model.generate(
            **inputs,
            # Equivalent to the original max_length = prompt_len + max_length,
            # expressed with the clearer max_new_tokens parameter.
            max_new_tokens=max_length,
            temperature=temperature,
            top_p=0.9,
            top_k=50,
            repetition_penalty=1.1,
            pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no pad token
            do_sample=True,
        )
    return tokenizer.decode(output[0], skip_special_tokens=True).strip()


# === ANALYSIS WITH FLAN-T5 ===
def analyze_with_cpu_model(raw_outputs):
    """Integrate the ensemble's JSON blobs into (cause, injury) via FLAN-T5."""
    json_blobs = []
    for i, text in enumerate(raw_outputs):
        json_part = extract_json_only(text)
        if json_part:
            json_blobs.append(f"Model {i+1} JSON:\n{json_part}")
    summary = "\n\n".join(json_blobs)

    prompt = (
        f"The following are JSON outputs from multiple hazard prediction models:\n\n"
        f"{summary}\n\n"
        f"Please analyze all JSON outputs and return:\n"
        f"Cause of Accident: \n"
        f"Degree of Injury: "
    )
    inputs = flan_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to("cpu")
    with torch.no_grad():
        output = flan_model.generate(
            **inputs,
            max_length=128,
            temperature=0.5,
            top_p=0.9,
            do_sample=True,
        )
    decoded = flan_tokenizer.decode(output[0], skip_special_tokens=True).strip()

    # Pull the two labelled fields back out of the free-text answer.
    match_cause = re.search(r"(?i)cause of accident:\s*(.+)", decoded)
    match_injury = re.search(r"(?i)degree of injury:\s*(low|medium|high)", decoded)
    cause = match_cause.group(1).strip() if match_cause else "Cause not found."
    injury = match_injury.group(1).capitalize() if match_injury else "Unknown"
    return cause, injury


# === MAIN GENERATION FUNCTION ===
def generate_prediction_ensemble(scenario_text, max_len, temperature):
    """Gradio callback: run every ensemble member, then the CPU analyzer.

    Returns (cause, injury, raw_output_text) for the three output widgets.
    """
    if not scenario_text.strip():
        return "Please enter a scenario", "", ""

    prompt = format_input(scenario_text)
    raw_outputs = [
        generate_single_model_output(model, tokenizer, prompt, max_length=max_len, temperature=temperature)
        for model, tokenizer in zip(models, tokenizers)
    ]
    cause, injury = analyze_with_cpu_model(raw_outputs)
    raw_output_text = "\n\n".join(f"Model {i+1}:\n{resp}" for i, resp in enumerate(raw_outputs))
    return cause, injury, raw_output_text


# === FULL GRADIO INTERFACE ===
def create_interface():
    """Build and return the Gradio Blocks UI."""
    with gr.Blocks(title="Multi-Model Safety Risk Predictor") as interface:
        # NOTE(review): the HTML markup was stripped from the recovered source;
        # the tags below are a minimal reconstruction around the original text.
        gr.HTML(f"""
        <div style="text-align: center;">
            <h1>🚧 Multi-Model Safety Risk Predictor (CPU-Only)</h1>
            <p><b>System Overview:</b></p>
            <p>Models Loaded: {len(models)} / {len(MODEL_PATHS)}</p>
            <p>Base Model: {BASE_MODEL_ID}</p>
            <p>Analysis Method: CPU-Only (No external API calls)</p>
        </div>
        """)

        with gr.Row():
            with gr.Column():
                scenario_input = gr.Textbox(
                    lines=6,
                    label="Construction Scenario Description",
                    placeholder="Describe the workplace safety incident or scenario...",
                )
                gr.Markdown("**Quick Examples:**")
                with gr.Row():
                    ex1 = gr.Button("Chemical Exposure", size="sm")
                    ex2 = gr.Button("Fall Hazard", size="sm")
                    ex3 = gr.Button("Equipment Malfunction", size="sm")
                    ex4 = gr.Button("Fire Incident", size="sm")
                with gr.Row():
                    temperature = gr.Slider(0.1, 1.0, 0.7, 0.1, label="Model Creativity")
                    max_len = gr.Slider(100, 400, 300, 50, label="Response Length")
                predict_btn = gr.Button("🔍 Analyze with Multi-Model Ensemble", variant="primary")

            with gr.Column():
                cause_output = gr.Textbox(
                    label="📝 Integrated Cause Analysis",
                    lines=4,
                    info="CPU model's integrated analysis of all model outputs",
                )
                degree_output = gr.Textbox(
                    label="📈 Degree of Injury",
                    info="Based on zero-shot classification + model integration",
                )
                with gr.Accordion("📄 Individual Model Outputs", open=False):
                    raw_output = gr.Textbox(label="Raw Model Responses", lines=15)

        # Button action
        predict_btn.click(
            fn=generate_prediction_ensemble,
            inputs=[scenario_input, max_len, temperature],
            outputs=[cause_output, degree_output, raw_output],
        )

        # Quick examples
        ex1.click(
            fn=lambda: "An employee was working with chemical solvents in a poorly ventilated area without proper respiratory protection. The worker began experiencing dizziness and respiratory distress after prolonged exposure.",
            outputs=scenario_input,
        )
        ex2.click(
            fn=lambda: "A construction worker was installing roofing materials on a steep slope without proper fall protection equipment. The worker lost footing on wet materials and fell.",
            outputs=scenario_input,
        )
        ex3.click(
            fn=lambda: "During routine maintenance, a hydraulic press malfunctioned due to worn seals. The operator's hand was caught when the press unexpectedly activated.",
            outputs=scenario_input,
        )
        ex4.click(
            fn=lambda: "While welding in an area with flammable materials, proper fire safety protocols were not followed. Sparks ignited nearby combustible materials causing a flash fire.",
            outputs=scenario_input,
        )

        gr.HTML(f"""
        <div style="text-align: center;">
            <p>System Status: {len(models)} models loaded | CPU-optimized | No external APIs</p>
            <p>Built with Multi-Model Ensemble + CPU Analysis + Gradio</p>
        </div>
        """)

    return interface


# === RUN INTERFACE ===
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)