# Hugging Face Space: Multi-Model Construction Safety Risk Predictor (CPU-only demo)
| import torch | |
| import re | |
| import gradio as gr | |
| import json | |
| import traceback | |
| import ast | |
| from transformers import ( | |
| AutoTokenizer, | |
| AutoModelForCausalLM, | |
| AutoModelForSeq2SeqLM | |
| ) | |
| from peft import PeftModel | |
# === CONFIGURATION ===
MODEL_PATHS = [
    "FrAnKu34t23/Construction_Risk_Prediction_Model_v3",
]
BASE_MODEL_ID = "distilgpt2"
ANALYSIS_MODEL_ID = "google/flan-t5-base"

# === LOAD MODELS ===
# One (tokenizer, PEFT-adapted causal LM) pair per entry in MODEL_PATHS,
# each pinned to CPU and switched to eval mode.
models, tokenizers = [], []
for adapter_path in MODEL_PATHS:
    tok = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
    backbone = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID)
    adapted = PeftModel.from_pretrained(backbone, adapter_path).to("cpu").eval()
    models.append(adapted)
    tokenizers.append(tok)

# === FLAN-T5 for CPU ANALYSIS ===
flan_tokenizer = AutoTokenizer.from_pretrained(ANALYSIS_MODEL_ID)
flan_model = AutoModelForSeq2SeqLM.from_pretrained(ANALYSIS_MODEL_ID).to("cpu").eval()
| # === FORMAT INPUT === | |
| def format_input(scenario_text): | |
| scenario = scenario_text.strip() | |
| if not scenario.startswith(", "): | |
| scenario = ", " + scenario.lstrip(", ") | |
| return f"Based on the situation, predict potential hazards and injuries. {scenario}<|endoftext|>" | |
| # === TEXT CLEANING + JSON EXTRACTION (FALLBACK USE) === | |
# === TEXT CLEANING + JSON EXTRACTION (FALLBACK USE) ===
def clean_raw_json_string(raw_text):
    """Normalize model-emitted pseudo-JSON toward parseable JSON.

    Maps curly/smart quotes to their ASCII equivalents, collapses doubled
    single quotes and backticks into double quotes, and regularizes the
    spacing between brackets/commas and quotes.

    NOTE(fix): the original had mojibake here — all four smart-quote
    literals had been corrupted to the same character, so the curly
    double-quote replacements could never fire. Restored to the intended
    Unicode quote characters (written as escapes for robustness).
    """
    cleaned = raw_text.replace("\u2018", "'").replace("\u2019", "'")   # ‘ ’ -> '
    cleaned = cleaned.replace("\u201c", '"').replace("\u201d", '"')    # “ ” -> "
    # Doubled single quotes / backticks frequently stand in for a double quote.
    # Also drop the Unicode replacement character left by bad decoding
    # (the original stripped a mojibake artifact here — target unknowable).
    cleaned = cleaned.replace("''", '"').replace("`", '"').replace("\ufffd", "")
    # Ensure exactly one space between an opening brace/bracket/comma and a
    # quote, and between a quote and a closing brace/bracket/comma.
    cleaned = re.sub(r'([{\[,])\s*"', r'\1 "', cleaned)
    cleaned = re.sub(r'"\s*([}\],])', r'" \1', cleaned)
    return cleaned
def extract_json_object(text):
    """Find the first brace-delimited span in *text* that parses into a dict.

    Also salvages hazard lists the model emitted as stray `["..."]`
    fragments: they are pulled out of the candidate span and re-attached
    under a "Hazards" key before parsing.

    Returns the parsed dict, or None if no candidate span parses.

    NOTE(fix): the original pattern used a doubled-brace character class
    (`[^{{}}]`) — behaviorally redundant; normalized to `[^{}]` to match
    the sibling `extract_json_only`. The warning emoji in the log message
    was mojibake and has been restored.
    """
    # Non-nested {...} span: plain non-brace runs or quoted strings.
    pattern = r'\{(?:[^{}]|"[^"]*")*\}'
    matches = re.findall(pattern, text, re.DOTALL)
    for match in matches:
        try:
            cleaned = clean_raw_json_string(match)
            # Collect stray bracketed items like ["hazard"] ...
            hazard_items = re.findall(r'\["([^"]+)"\]', cleaned)
            # ... and strip them out of the candidate JSON text.
            cleaned = re.sub(r'(\["[^"]+"\]\s*,?\s*)+', '', cleaned)
            if hazard_items and "Hazards" not in cleaned:
                cleaned = cleaned.rstrip('} \n\t,')
                cleaned += ', "Hazards": ' + json.dumps(hazard_items) + '}'
            parsed = json.loads(cleaned)
            if isinstance(parsed, dict):
                return parsed
        except Exception as e:
            # Best-effort parsing: log and try the next candidate span.
            print(f"⚠️ extract_json_object failed: {e}")
            continue
    return None
def extract_fields(text):
    """Regex-scrape cause / injury / hazards from loosely formatted model text.

    Returns a 4-tuple ``(hazards, cause, injury, structured_json)`` where
    ``structured_json`` is an indented JSON string of the three fields.
    Unparseable fields fall back to "Unknown" (strings) or [] (hazards).

    NOTE(fix): the smart-quote literals in ``clean_text`` were mojibake in
    the original (all corrupted to the same character, making the later
    replacements no-ops); restored to the intended Unicode characters.
    """
    def clean_text(t):
        # Normalize smart quotes / acute accent before the ASCII filter below.
        t = t.replace("\u2018", "'").replace("\u2019", "'")
        t = t.replace("\u201c", '"').replace("\u201d", '"')
        t = t.replace("''", '"').replace("`", '"').replace("\u00b4", "")
        # Anything still non-ASCII is treated as model noise and dropped.
        t = re.sub(r"[^\x00-\x7F]+", "", t)
        return t

    cleaned = clean_text(text)
    cause = "Unknown"
    injury = "Unknown"
    hazards = []

    match = re.search(r'"?Cause of Accident"?\s*:\s*"([^"]+)",?', cleaned, re.IGNORECASE)
    if match:
        cause = match.group(1).strip()

    match = re.search(r'"?Degree of Injury"?\s*:\s*"(Low|Medium|High)"', cleaned, re.IGNORECASE)
    if match:
        injury = match.group(1).capitalize()

    match = re.search(r'"?Hazards"?\s*:\s*(\[[^\]]+\])', cleaned, re.IGNORECASE)
    if match:
        try:
            hazards_raw = clean_text(match.group(1))
            # literal_eval is safe here: the regex restricts input to a
            # bracketed list shape and it is already ASCII-only.
            hazards = ast.literal_eval(hazards_raw)
            hazards = [str(h).strip().strip('"').strip("'") for h in hazards]
        except Exception as e:
            print("⚠️ Hazard parsing failed:", e)
            hazards = []

    structured = {
        "Hazards": hazards,
        "Cause of Accident": cause,
        "Degree of Injury": injury,
    }
    return hazards, cause, injury, json.dumps(structured, indent=2)
def extract_json_only(text):
    """Return the first non-nested {...} span found in *text*, or "" if none."""
    candidates = re.findall(r'\{(?:[^{}]|"[^"]*")*\}', text, re.DOTALL)
    if candidates:
        return candidates[0]
    return ""
| # === GENERATION FROM EACH MODEL === | |
| def generate_single_model_output(model, tokenizer, prompt, max_length=300, temperature=0.7): | |
| inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to("cpu") | |
| with torch.no_grad(): | |
| output = model.generate( | |
| **inputs, | |
| max_length=inputs["input_ids"].shape[1] + max_length, | |
| temperature=temperature, | |
| top_p=0.9, | |
| top_k=50, | |
| repetition_penalty=1.1, | |
| pad_token_id=tokenizer.eos_token_id, | |
| do_sample=True | |
| ) | |
| return tokenizer.decode(output[0], skip_special_tokens=True).strip() | |
| # === ANALYSIS WITH FLAN-T5 === | |
| def analyze_with_cpu_model(raw_outputs): | |
| json_blobs = [] | |
| for i, text in enumerate(raw_outputs): | |
| json_part = extract_json_only(text) | |
| if json_part: | |
| json_blobs.append(f"Model {i+1} JSON:\n{json_part}") | |
| summary = "\n\n".join(json_blobs) | |
| prompt = ( | |
| f"The following are JSON outputs from multiple hazard prediction models:\n\n" | |
| f"{summary}\n\n" | |
| f"Please analyze all JSON outputs and return:\n" | |
| f"Cause of Accident: <natural language summary of the most likely cause>\n" | |
| f"Degree of Injury: <Low | Medium | High>" | |
| ) | |
| inputs = flan_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to("cpu") | |
| with torch.no_grad(): | |
| output = flan_model.generate( | |
| **inputs, | |
| max_length=128, | |
| temperature=0.5, | |
| top_p=0.9, | |
| do_sample=True | |
| ) | |
| decoded = flan_tokenizer.decode(output[0], skip_special_tokens=True).strip() | |
| match_cause = re.search(r"(?i)cause of accident:\s*(.+)", decoded) | |
| match_injury = re.search(r"(?i)degree of injury:\s*(low|medium|high)", decoded) | |
| cause = match_cause.group(1).strip() if match_cause else "Cause not found." | |
| injury = match_injury.group(1).capitalize() if match_injury else "Unknown" | |
| return cause, injury | |
| # === MAIN GENERATION FUNCTION === | |
| def generate_prediction_ensemble(scenario_text, max_len, temperature): | |
| if not scenario_text.strip(): | |
| return "Please enter a scenario", "", "" | |
| prompt = format_input(scenario_text) | |
| raw_outputs = [ | |
| generate_single_model_output(model, tokenizer, prompt, max_length=max_len, temperature=temperature) | |
| for model, tokenizer in zip(models, tokenizers) | |
| ] | |
| cause, injury = analyze_with_cpu_model(raw_outputs) | |
| raw_output_text = "\n\n".join([f"Model {i+1}:\n{resp}" for i, resp in enumerate(raw_outputs)]) | |
| return cause, injury, raw_output_text | |
| # === FULL GRADIO INTERFACE === | |
| def create_interface(): | |
| with gr.Blocks(title="Multi-Model Safety Risk Predictor") as interface: | |
| gr.HTML(f""" | |
| <h1>π§ Multi-Model Safety Risk Predictor (CPU-Only)</h1> | |
| <p><strong>System Overview:</strong></p> | |
| <ul> | |
| <li>Loads {len(MODEL_PATHS)} specialized safety prediction models</li> | |
| <li>Each model analyzes the scenario independently</li> | |
| <li>CPU-only analysis model integrates all results using advanced reasoning</li> | |
| <li>Handles conflicting predictions through pattern analysis and majority consensus</li> | |
| <li>Fully optimized for CPU-only Hugging Face Spaces</li> | |
| </ul> | |
| <p><strong>Models Loaded:</strong> {len(models)} / {len(MODEL_PATHS)}</p> | |
| <p><strong>Base Model:</strong> {BASE_MODEL_ID}</p> | |
| <p><strong>Analysis Method:</strong> CPU-Only (No external API calls)</p> | |
| """) | |
| with gr.Row(): | |
| with gr.Column(): | |
| scenario_input = gr.Textbox( | |
| lines=6, | |
| label="Construction Scenario Description", | |
| placeholder="Describe the workplace safety incident or scenario..." | |
| ) | |
| gr.Markdown("**Quick Examples:**") | |
| with gr.Row(): | |
| ex1 = gr.Button("Chemical Exposure", size="sm") | |
| ex2 = gr.Button("Fall Hazard", size="sm") | |
| ex3 = gr.Button("Equipment Malfunction", size="sm") | |
| ex4 = gr.Button("Fire Incident", size="sm") | |
| with gr.Row(): | |
| temperature = gr.Slider(0.1, 1.0, 0.7, 0.1, label="Model Creativity") | |
| max_len = gr.Slider(100, 400, 300, 50, label="Response Length") | |
| predict_btn = gr.Button("π Analyze with Multi-Model Ensemble", variant="primary") | |
| with gr.Column(): | |
| cause_output = gr.Textbox( | |
| label="π Integrated Cause Analysis", | |
| lines=4, | |
| info="CPU model's integrated analysis of all model outputs" | |
| ) | |
| degree_output = gr.Textbox( | |
| label="π Degree of Injury", | |
| info="Based on zero-shot classification + model integration" | |
| ) | |
| with gr.Accordion("π Individual Model Outputs", open=False): | |
| raw_output = gr.Textbox(label="Raw Model Responses", lines=15) | |
| # Button action | |
| predict_btn.click( | |
| fn=generate_prediction_ensemble, | |
| inputs=[scenario_input, max_len, temperature], | |
| outputs=[cause_output, degree_output, raw_output] | |
| ) | |
| # Quick examples | |
| ex1.click(fn=lambda: "An employee was working with chemical solvents in a poorly ventilated area without proper respiratory protection. The worker began experiencing dizziness and respiratory distress after prolonged exposure.", outputs=scenario_input) | |
| ex2.click(fn=lambda: "A construction worker was installing roofing materials on a steep slope without proper fall protection equipment. The worker lost footing on wet materials and fell.", outputs=scenario_input) | |
| ex3.click(fn=lambda: "During routine maintenance, a hydraulic press malfunctioned due to worn seals. The operator's hand was caught when the press unexpectedly activated.", outputs=scenario_input) | |
| ex4.click(fn=lambda: "While welding in an area with flammable materials, proper fire safety protocols were not followed. Sparks ignited nearby combustible materials causing a flash fire.", outputs=scenario_input) | |
| gr.HTML(f""" | |
| <div style='text-align:center; margin-top:20px;'> | |
| <p><strong>System Status:</strong> {len(models)} models loaded | CPU-optimized | No external APIs</p> | |
| <p><em>Built with Multi-Model Ensemble + CPU Analysis + Gradio</em></p> | |
| </div> | |
| """) | |
| return interface | |
| # === RUN INTERFACE === | |
| if __name__ == "__main__": | |
| demo = create_interface() | |
| demo.launch(server_name="0.0.0.0", server_port=7860, share=True) | |