# Source: Hugging Face Space "FrAnKu34t23" — app.py (commit e9a167b, verified)
import torch
import re
import gradio as gr
import json
import traceback
import ast
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
AutoModelForSeq2SeqLM
)
from peft import PeftModel
# === CONFIGURATION ===
# LoRA/PEFT adapter repositories to load on top of the shared base model.
MODEL_PATHS = [
    "FrAnKu34t23/Construction_Risk_Prediction_Model_v3",
]
BASE_MODEL_ID = "distilgpt2"  # causal LM the adapters were trained against
ANALYSIS_MODEL_ID = "google/flan-t5-base"  # seq2seq model that merges the ensemble outputs
# === LOAD MODELS ===
# One (model, tokenizer) pair per adapter path; everything stays on CPU in eval mode.
models, tokenizers = [], []
for path in MODEL_PATHS:
    # The tokenizer comes from the base model, not the adapter repo.
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
    base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID)
    # Wrap the freshly loaded base model with the PEFT adapter stored at `path`.
    model = PeftModel.from_pretrained(base_model, path).to("cpu").eval()
    models.append(model)
    tokenizers.append(tokenizer)
# === FLAN-T5 for CPU ANALYSIS ===
flan_tokenizer = AutoTokenizer.from_pretrained(ANALYSIS_MODEL_ID)
flan_model = AutoModelForSeq2SeqLM.from_pretrained(ANALYSIS_MODEL_ID).to("cpu").eval()
# === FORMAT INPUT ===
def format_input(scenario_text):
    """Build the causal-LM prompt: instruction + ', '-prefixed scenario + EOS marker."""
    text = scenario_text.strip()
    if not text.startswith(", "):
        # Normalize any leading commas/spaces down to a single ", " prefix.
        text = ", " + text.lstrip(", ")
    return f"Based on the situation, predict potential hazards and injuries. {text}<|endoftext|>"
# === TEXT CLEANING + JSON EXTRACTION (FALLBACK USE) ===
def clean_raw_json_string(raw_text):
    """Normalize quote characters and delimiter spacing in a raw JSON-ish string.

    Curly/back-tick quote variants the model emits are mapped to plain ASCII
    quotes, then a single space is enforced after opening delimiters ({, [, ,)
    and before closing ones (}, ], ,) so downstream regex extraction sees a
    consistent shape.
    """
    cleaned = raw_text.replace("β€˜", "'").replace("’", "'")
    cleaned = cleaned.replace("β€œ", '"').replace("”", '"')
    cleaned = cleaned.replace("''", '"').replace("`", '"').replace("†", "")
    # Fix: the original classes were [{{\[,] / [}}\],] — the doubled braces are
    # leftover f-string escaping; inside a character class they are a redundant
    # duplicate literal, so a single brace is equivalent and not misleading.
    cleaned = re.sub(r'([{\[,])\s*"', r'\1 "', cleaned)
    cleaned = re.sub(r'"\s*([}\],])', r'" \1', cleaned)
    return cleaned
def extract_json_object(text):
    """Find and parse the first well-formed JSON object embedded in `text`.

    Model output sometimes emits hazards as stray one-element arrays (e.g.
    ["fall"]) inside the object without a key; those are collected, stripped
    out, and re-attached under a "Hazards" key before parsing.

    Returns:
        The first candidate that parses to a dict, or None if none does.
    """
    # Fix: the original pattern doubled the braces ({{ }}) inside the character
    # class — leftover f-string escaping. A single { } is equivalent and matches
    # the pattern already used by extract_json_only.
    pattern = r'\{(?:[^{}]|"[^"]*")*\}'
    matches = re.findall(pattern, text, re.DOTALL)
    for match in matches:
        try:
            cleaned = clean_raw_json_string(match)
            # Collect stray single-item arrays like ["slip"] as hazard entries.
            hazard_items = re.findall(r'\["([^"]+)"\]', cleaned)
            cleaned = re.sub(r'(\["[^"]+"\]\s*,?\s*)+', '', cleaned)
            if hazard_items and "Hazards" not in cleaned:
                # Re-open the object and append the collected hazards as a list.
                cleaned = cleaned.rstrip('} \n\t,')
                cleaned += ', "Hazards": ' + json.dumps(hazard_items) + '}'
            parsed = json.loads(cleaned)
            if isinstance(parsed, dict):
                return parsed
        except Exception as e:
            # Best-effort: log and try the next candidate object.
            print(f"⚠️ extract_json_object failed: {e}")
            continue
    return None
def extract_fields(text):
    """Regex-extract hazards, cause, and injury degree from raw model text.

    Returns:
        (hazards, cause, injury, structured_json_string) — missing fields fall
        back to "Unknown" / an empty hazard list.
    """
    def _ascii_clean(raw):
        # Map mojibake quote variants to ASCII, then drop any non-ASCII residue.
        raw = raw.replace("β€˜", "'").replace("’", "'").replace("β€œ", '"').replace("”", '"')
        raw = raw.replace("''", '"').replace("`", '"').replace("†", "").replace("Β΄", "")
        return re.sub(r"[^\x00-\x7F]+", "", raw)

    normalized = _ascii_clean(text)
    cause, injury, hazards = "Unknown", "Unknown", []

    cause_hit = re.search(r'"?Cause of Accident"?\s*:\s*"([^"]+)",?', normalized, re.IGNORECASE)
    if cause_hit:
        cause = cause_hit.group(1).strip()

    injury_hit = re.search(r'"?Degree of Injury"?\s*:\s*"(Low|Medium|High)"', normalized, re.IGNORECASE)
    if injury_hit:
        injury = injury_hit.group(1).capitalize()

    hazard_hit = re.search(r'"?Hazards"?\s*:\s*(\[[^\]]+\])', normalized, re.IGNORECASE)
    if hazard_hit:
        try:
            # literal_eval handles both single- and double-quoted list items.
            items = ast.literal_eval(_ascii_clean(hazard_hit.group(1)))
            hazards = [str(item).strip().strip('"').strip("'") for item in items]
        except Exception as e:
            print("⚠️ Hazard parsing failed:", e)
            hazards = []

    structured = {
        "Hazards": hazards,
        "Cause of Accident": cause,
        "Degree of Injury": injury
    }
    return hazards, cause, injury, json.dumps(structured, indent=2)
def extract_json_only(text):
    """Return the first brace-delimited JSON-looking object in `text`, or ""."""
    found = re.findall(r'\{(?:[^{}]|"[^"]*")*\}', text, re.DOTALL)
    if not found:
        return ""
    return found[0]
# === GENERATION FROM EACH MODEL ===
def generate_single_model_output(model, tokenizer, prompt, max_length=300, temperature=0.7):
    """Sample one completion for `prompt` from a single (model, tokenizer) pair.

    Args:
        model: causal LM (PEFT-wrapped), already on CPU in eval mode.
        tokenizer: matching tokenizer; its EOS token doubles as padding.
        prompt: fully formatted input string (see format_input).
        max_length: number of NEW tokens to generate after the prompt.
        temperature: sampling temperature.

    Returns:
        Decoded text (prompt + continuation) with special tokens stripped.
    """
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to("cpu")
    with torch.no_grad():
        output = model.generate(
            **inputs,
            # Fix: use max_new_tokens instead of hand-computing
            # prompt_len + max_length — same token budget, but uses the
            # documented parameter rather than the deprecated max_length idiom.
            max_new_tokens=max_length,
            temperature=temperature,
            top_p=0.9,
            top_k=50,
            repetition_penalty=1.1,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True
        )
    return tokenizer.decode(output[0], skip_special_tokens=True).strip()
# === ANALYSIS WITH FLAN-T5 ===
def analyze_with_cpu_model(raw_outputs):
    """Merge the JSON fragments from all ensemble outputs via FLAN-T5.

    Returns:
        (cause, injury) parsed from the analysis model's reply; falls back to
        "Cause not found." / "Unknown" when the reply lacks a field.
    """
    blobs = []
    for idx, text in enumerate(raw_outputs):
        fragment = extract_json_only(text)
        if fragment:
            blobs.append(f"Model {idx+1} JSON:\n{fragment}")
    summary = "\n\n".join(blobs)
    prompt = (
        f"The following are JSON outputs from multiple hazard prediction models:\n\n"
        f"{summary}\n\n"
        f"Please analyze all JSON outputs and return:\n"
        f"Cause of Accident: <natural language summary of the most likely cause>\n"
        f"Degree of Injury: <Low | Medium | High>"
    )
    inputs = flan_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to("cpu")
    with torch.no_grad():
        output = flan_model.generate(
            **inputs,
            max_length=128,
            temperature=0.5,
            top_p=0.9,
            do_sample=True
        )
    decoded = flan_tokenizer.decode(output[0], skip_special_tokens=True).strip()
    # Pull the two labeled lines back out of the free-text reply.
    cause_hit = re.search(r"(?i)cause of accident:\s*(.+)", decoded)
    injury_hit = re.search(r"(?i)degree of injury:\s*(low|medium|high)", decoded)
    cause = cause_hit.group(1).strip() if cause_hit else "Cause not found."
    injury = injury_hit.group(1).capitalize() if injury_hit else "Unknown"
    return cause, injury
# === MAIN GENERATION FUNCTION ===
def generate_prediction_ensemble(scenario_text, max_len, temperature):
    """Run every ensemble member on the scenario, then merge via FLAN-T5.

    Returns:
        (cause, injury, raw_output_text); a placeholder triple when the
        scenario is blank.
    """
    if not scenario_text.strip():
        return "Please enter a scenario", "", ""
    prompt = format_input(scenario_text)
    raw_outputs = []
    for member, tok in zip(models, tokenizers):
        raw_outputs.append(
            generate_single_model_output(member, tok, prompt, max_length=max_len, temperature=temperature)
        )
    cause, injury = analyze_with_cpu_model(raw_outputs)
    combined = "\n\n".join(f"Model {n+1}:\n{resp}" for n, resp in enumerate(raw_outputs))
    return cause, injury, combined
# === FULL GRADIO INTERFACE ===
def create_interface():
    """Build the Gradio Blocks UI and wire it to generate_prediction_ensemble.

    Layout: header HTML, a left column (scenario input, quick-example buttons,
    sampling sliders, run button) and a right column (cause / injury outputs
    plus a collapsible raw-output accordion), then a status footer.
    """
    with gr.Blocks(title="Multi-Model Safety Risk Predictor") as interface:
        # Header — static overview; model counts are interpolated at build time.
        gr.HTML(f"""
        <h1>🚧 Multi-Model Safety Risk Predictor (CPU-Only)</h1>
        <p><strong>System Overview:</strong></p>
        <ul>
        <li>Loads {len(MODEL_PATHS)} specialized safety prediction models</li>
        <li>Each model analyzes the scenario independently</li>
        <li>CPU-only analysis model integrates all results using advanced reasoning</li>
        <li>Handles conflicting predictions through pattern analysis and majority consensus</li>
        <li>Fully optimized for CPU-only Hugging Face Spaces</li>
        </ul>
        <p><strong>Models Loaded:</strong> {len(models)} / {len(MODEL_PATHS)}</p>
        <p><strong>Base Model:</strong> {BASE_MODEL_ID}</p>
        <p><strong>Analysis Method:</strong> CPU-Only (No external API calls)</p>
        """)
        with gr.Row():
            with gr.Column():
                # Left column: scenario entry and generation controls.
                scenario_input = gr.Textbox(
                    lines=6,
                    label="Construction Scenario Description",
                    placeholder="Describe the workplace safety incident or scenario..."
                )
                gr.Markdown("**Quick Examples:**")
                with gr.Row():
                    ex1 = gr.Button("Chemical Exposure", size="sm")
                    ex2 = gr.Button("Fall Hazard", size="sm")
                    ex3 = gr.Button("Equipment Malfunction", size="sm")
                    ex4 = gr.Button("Fire Incident", size="sm")
                with gr.Row():
                    # Sliders map directly to the temperature / max_len
                    # parameters of generate_prediction_ensemble.
                    temperature = gr.Slider(0.1, 1.0, 0.7, 0.1, label="Model Creativity")
                    max_len = gr.Slider(100, 400, 300, 50, label="Response Length")
                predict_btn = gr.Button("πŸ” Analyze with Multi-Model Ensemble", variant="primary")
            with gr.Column():
                # Right column: merged analysis results.
                cause_output = gr.Textbox(
                    label="πŸ“ Integrated Cause Analysis",
                    lines=4,
                    info="CPU model's integrated analysis of all model outputs"
                )
                degree_output = gr.Textbox(
                    label="πŸ“ˆ Degree of Injury",
                    info="Based on zero-shot classification + model integration"
                )
                with gr.Accordion("πŸ“„ Individual Model Outputs", open=False):
                    raw_output = gr.Textbox(label="Raw Model Responses", lines=15)
        # Button action
        predict_btn.click(
            fn=generate_prediction_ensemble,
            inputs=[scenario_input, max_len, temperature],
            outputs=[cause_output, degree_output, raw_output]
        )
        # Quick examples — each button fills the scenario box with a canned incident.
        ex1.click(fn=lambda: "An employee was working with chemical solvents in a poorly ventilated area without proper respiratory protection. The worker began experiencing dizziness and respiratory distress after prolonged exposure.", outputs=scenario_input)
        ex2.click(fn=lambda: "A construction worker was installing roofing materials on a steep slope without proper fall protection equipment. The worker lost footing on wet materials and fell.", outputs=scenario_input)
        ex3.click(fn=lambda: "During routine maintenance, a hydraulic press malfunctioned due to worn seals. The operator's hand was caught when the press unexpectedly activated.", outputs=scenario_input)
        ex4.click(fn=lambda: "While welding in an area with flammable materials, proper fire safety protocols were not followed. Sparks ignited nearby combustible materials causing a flash fire.", outputs=scenario_input)
        # Footer — static status line.
        gr.HTML(f"""
        <div style='text-align:center; margin-top:20px;'>
        <p><strong>System Status:</strong> {len(models)} models loaded | CPU-optimized | No external APIs</p>
        <p><em>Built with Multi-Model Ensemble + CPU Analysis + Gradio</em></p>
        </div>
        """)
    return interface
# === RUN INTERFACE ===
if __name__ == "__main__":
    demo = create_interface()
    # Bind on all interfaces at the standard Spaces port; share=True also
    # requests a public gradio.live tunnel when run outside Spaces.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)