Spaces:
Running
Running
| """ | |
| HaramGuard โ CoordinatorAgent | |
| ================================ | |
| AISA Layer : Reasoning + Governance | |
| Design Pattern : ReAct (Reason โ Act โ Observe) + Output Guardrails | |
| ReAct Pattern Implementation (Required by Capstone Rubric): | |
| This agent implements the ReAct (Reasoning-Acting-Observing) design pattern | |
| as explicitly required by the capstone project rubric. The ReAct loop enables | |
| iterative self-correction through: | |
| 1. REASON: Analyze the situation and compose a structured prompt | |
| - Input: RiskResult, Decision, recent frames, feedback (if any) | |
| - Output: Contextualized prompt for LLM | |
| 2. ACT: Execute the action (call LLM to generate action plan) | |
| - Model: Groq API (GPT-OSS-120B or similar) | |
| - Output: Raw JSON plan from LLM | |
| 3. OBSERVE: Validate the output using guardrails | |
| - Run 5 validation checks (GR-C1 through GR-C5) | |
| - If issues found: generate feedback and loop back to REASON | |
| - If valid: return the plan | |
| The loop continues up to MAX_REACT_ITERS (3) times, ensuring the agent | |
| can self-correct errors in its reasoning or output format. | |
| Responsibilities: | |
| - Called on ALL decisions (P0/P1/P2) โ not just critical emergencies | |
| - Implements a ReAct loop (max 3 iterations): | |
| Reason : analyse the crowd situation and compose a prompt | |
| Act : call LLM (Groq) to generate a structured action plan | |
| Observe : run 6 guardrails to validate the output | |
| โ if validation fails, feed back issues and Reason again | |
| - Guardrails: | |
| GR-C1: Required fields check | |
| GR-C2: Valid threat level (CRITICAL/HIGH/MEDIUM/LOW only) | |
| GR-C3: Confidence score in [0, 1] | |
| GR-C4: Consistency check (low risk score โ CRITICAL threat) | |
| GR-C5: Arabic alert fallback if empty | |
| GR-C6: selected_gates must be a non-empty list | |
| """ | |
| import json | |
| import numpy as np | |
| from typing import Optional, Tuple | |
| from groq import Groq | |
| from openai import OpenAI | |
| from core.models import RiskResult, Decision | |
class CoordinatorAgent:
    """ReAct (Reason -> Act -> Observe) planning agent with output guardrails.

    Invoked on every decision (P0/P1/P2), not only critical emergencies.
    Each ReAct iteration:
      Reason  -- compose a structured Arabic prompt from the risk data
                 (plus guardrail feedback from a previous failed attempt),
      Act     -- call the Groq LLM and parse its JSON action plan,
      Observe -- run guardrails GR-C1..GR-C6; on failure, feed the issue
                 list back into the next Reason step.
    The loop runs at most MAX_REACT_ITERS times, then returns best effort.
    """

    # GR-C1: every plan returned by the LLM must contain these keys.
    REQUIRED_FIELDS = {
        'threat_level', 'executive_summary', 'selected_gates',
        'immediate_actions', 'actions_justification',
        'arabic_alert', 'confidence_score'
    }
    # GR-C2: the only accepted threat-level labels.
    VALID_THREATS = {'CRITICAL', 'HIGH', 'MEDIUM', 'LOW'}
    MAX_REACT_ITERS = 3  # ReAct: maximum reasoning iterations

    # Real gate list -- injected into every LLM prompt so the agent can choose.
    HARAM_GATES = [
        'ุจุงุจ ุงูู ูู ุนุจุฏุงูุนุฒูุฒ',  # South, main entrance, highest traffic
        'ุจุงุจ ุงูู ูู ููุฏ',  # North, large capacity
        'ุจุงุจ ุงูุณูุงู ',  # East, historic, medium traffic
        'ุจุงุจ ุงููุชุญ',  # West, medium capacity
        'ุจุงุจ ุงูุนู ุฑุฉ',  # West, Umrah pilgrims
        'ุจุงุจ ุงูู ูู ุนุจุฏุงููู',  # South-West, high capacity
        'ุจุงุจ ุงูุตูุง',  # East, leads to Safa-Marwa
        'ุจุงุจ ุนูู',  # North-East, smaller gate
        'ุจุงุจ ุงูุฒูุงุฏุฉ',  # North, overflow gate
        'ุจุงุจ ุงูู ุฑูุฉ',  # East, leads to Marwa
    ]

    def __init__(self, groq_api_key: str):
        """Create the agent and bind a Groq client.

        Args:
            groq_api_key: API key passed straight to the Groq SDK client.
        """
        self.name = 'CoordinatorAgent'
        self.aisa_layer = 'Reasoning + Governance (ReAct)'
        self._groq_client = Groq(api_key=groq_api_key)
        self._active_backend = 'groq'
        self._active_model = 'llama-3.3-70b-versatile'
        print(f'๐ง [CoordinatorAgent] Ready โ backend=groq model={self._active_model} | ReAct loop')

    # -- Guardrails ----------------------------------------------------
    def _validate(self, plan: dict, risk_score: float) -> Tuple[dict, list]:
        """Observe step: validate and sanitise model output in place.

        Args:
            plan: Parsed JSON object from the LLM (may be malformed).
            risk_score: Numeric risk in [0, 1] used by the GR-C4
                consistency check.

        Returns:
            (corrected_plan, list_of_issues). The issues list is empty
            when the plan was fully valid; otherwise each entry names the
            guardrail that fired so it can be fed back to the model.
        """
        issues = []
        if not isinstance(plan, dict):
            plan = {}
            issues.append('GR_C1_invalid_json_object')
        # GR-C1: Required fields -- fill missing keys with a placeholder so
        # downstream checks can normalise them.
        for field_name in self.REQUIRED_FIELDS:
            if field_name not in plan:
                plan[field_name] = 'N/A'
                issues.append(f'GR_C1_missing:{field_name}')
        # GR-C2: Valid threat level (normalised to upper case).
        tl = str(plan.get('threat_level', '')).upper()
        if tl not in self.VALID_THREATS:
            issues.append(f'GR_C2_invalid_threat:{tl}->HIGH')
            plan['threat_level'] = 'HIGH'
        else:
            plan['threat_level'] = tl
        # GR-C3: Confidence in [0, 1].
        # FIX: bool is a subclass of int, so True/False previously passed the
        # isinstance + range check unnoticed; booleans are now rejected.
        cs = plan.get('confidence_score', 0)
        if isinstance(cs, bool) or not isinstance(cs, (int, float)) or not (0 <= cs <= 1):
            issues.append(f'GR_C3_invalid_confidence:{cs}->0.5')
            plan['confidence_score'] = 0.5
        # GR-C4: Consistency -- threat_level must match risk_score thresholds:
        #   risk > 0.80 -> HIGH, risk > 0.20 -> MEDIUM, otherwise LOW.
        # NOTE: CRITICAL is never "expected" here, so a CRITICAL label from
        # the model is always downgraded by this check.
        expected = 'HIGH' if risk_score > 0.80 else 'MEDIUM' if risk_score > 0.20 else 'LOW'
        tl_current = plan['threat_level']
        if tl_current != expected:
            issues.append(f'GR_C4_threat_corrected:{tl_current}->{expected}')
            plan['threat_level'] = expected
        # GR-C5: Arabic alert fallback when the field is empty/whitespace.
        if not str(plan.get('arabic_alert', '')).strip():
            plan['arabic_alert'] = (
                'ุชูุจูู ุฃู ูู: ููุฑุฌู ู ุฑุงูุจุฉ ููุงุท ุงูุชุฌู ุน ูุงุชุฎุงุฐ ุงูุฅุฌุฑุงุกุงุช ุงูููุงุฆูุฉ ุงููุงุฒู ุฉ.'
            )
            issues.append('GR_C5_arabic_fallback')
        # GR-C1 extra: immediate_actions must be a non-empty list.
        ia = plan.get('immediate_actions', [])
        if not isinstance(ia, list) or not ia:
            plan['immediate_actions'] = ['ุฒูุงุฏุฉ ุงูู ุฑุงูุจุฉ ุงูู ูุฏุงููุฉ', 'ุฅุฑุณุงู ูุญุฏุงุช ุฅูู ููุทุฉ ุงูุงุฒุฏุญุงู ']
            issues.append('GR_C1_immediate_actions_fixed')
        # GR-C6: selected_gates -- count enforced by (post-GR-C4) threat level:
        #   LOW  -> 0 gates, MEDIUM -> exactly 1, HIGH/CRITICAL -> exactly 2.
        tl_now = plan.get('threat_level', 'LOW')
        sg = plan.get('selected_gates', [])
        if not isinstance(sg, list):
            sg = []
        if tl_now == 'LOW':
            plan['selected_gates'] = []  # no action, no gates shown
        elif tl_now == 'MEDIUM':
            if not sg:
                plan['selected_gates'] = ['ุจุงุจ ุงูู ูู ุนุจุฏุงูุนุฒูุฒ']
                issues.append('GR_C6_medium_fallback')
            else:
                plan['selected_gates'] = sg[:1]  # cap at 1
        else:  # HIGH or CRITICAL
            if len(sg) < 2:
                fallback = ['ุจุงุจ ุงูู ูู ุนุจุฏุงูุนุฒูุฒ', 'ุจุงุจ ุงูุณูุงู ']
                plan['selected_gates'] = (sg + fallback)[:2]
                issues.append('GR_C6_high_padded')
            else:
                plan['selected_gates'] = sg[:2]  # cap at 2
        plan['_guardrail_issues'] = issues
        return plan, issues

    # -- ReAct helpers -------------------------------------------------
    def _build_prompt(
        self,
        rr: RiskResult,
        decision: Decision,
        recent_frames: list,
        feedback: str = ''
    ) -> str:
        """Reason step: compose the model prompt.

        Args:
            rr: Current risk assessment (level, score, trend, window max).
            decision: Priority decision (P0/P1/P2) for this frame.
            recent_frames: Recent frame records; each exposes person_count.
            feedback: Issue summary from a previous failed Observe step;
                when non-empty it is embedded so the model can self-correct.

        Returns:
            The full Arabic prompt string, ending with the exact JSON
            schema the model must return.
        """
        avg_p = np.mean([f.person_count for f in recent_frames]) if recent_frames else 0
        cur_count = recent_frames[-1].person_count if recent_frames else 0
        gates_list = '\n'.join(
            f'{i+1}. {g}' for i, g in enumerate(self.HARAM_GATES)
        )
        base = (
            'ุฃูุช ู ูุณู ูุธุงู ุญุงุฑุณ ุงูุญุฑู ูุฅุฏุงุฑุฉ ุณูุงู ุฉ ุงูุญุดูุฏ ูู ุงูู ุณุฌุฏ ุงูุญุฑุงู .\n'
            'ู ูู ุชู: ุฅูุชุงุฌ ุฎุทุฉ ุชุดุบูููุฉ ูุงุถุญุฉ ูู ูุฌุฒุฉ ููู ุดุบููู ุจูุงุกู ุนูู ุจูุงูุงุช ุงูุญุดูุฏ ุงูุญุงููุฉ.\n\n'
            f'ู ุณุชูู ุงูุฎุทุฑ : {rr.risk_level} (ุงูุฏุฑุฌุฉ: {rr.risk_score:.3f})\n'
            f'ุงูุฃููููุฉ : {decision.priority} (P0=ุทุงุฑุฆ ุญุฑุฌุ P1=ุชุญุฐูุฑ ููุงุฆูุ P2=ู ุฑุงูุจุฉ ุฑูุชูููุฉ)\n'
            f'ุงุชุฌุงู ุงูุญุดูุฏ : {rr.trend}\n'
            f'ุงูุนุฏุฏ ุงูุญุงูู : {cur_count} ุดุฎุต | ุงูู ุชูุณุท (ุขุฎุฑ 30 ุฅุทุงุฑ): {avg_p:.0f} | ุงูุฐุฑูุฉ: {rr.window_max}\n\n'
            'ุงูุจูุงุจุงุช ุงูู ุชุงุญุฉ ูู ุงูู ุณุฌุฏ ุงูุญุฑุงู โ ุงุฎุชุฑ ุงูุฃูุณุจ ู ููุง:\n'
            f'{gates_list}\n\n'
            'ุฅุฑุดุงุฏุงุช ุงุฎุชูุงุฑ ุงูุจูุงุจุงุช:\n'
            ' P0 (ุทุงุฑุฆ ุญุฑุฌ): ุงูุชุญ ุจูุงุจุงุช ุงูุฅุฎูุงุก ุนุงููุฉ ุงูุณุนุฉ + ุฃุบูู ุงูู ุฏุงุฎู ุงูู ุฒุฏุญู ุฉ\n'
            ' P1 (ุชุญุฐูุฑ) : ูุนูู ุงููุงูุชุงุช ุงูุฅุฑุดุงุฏูุฉ ูุญู ุงูุจูุงุจุงุช ุงูุฃูู ุงุฒุฏุญุงู ุงู\n'
            ' P2 (ุฑูุชููู) : ุฑุงูุจ ุงูุจูุงุจุงุช ุงูุฃูุซุฑ ุญุฑูุฉ ููุท\n\n'
        )
        if feedback:
            # Self-correction: surface the previous guardrail failures.
            base += (
                'ุชุตุญูุญ ู ุทููุจ โ ุงูุฅุฌุงุจุฉ ุงูุณุงุจูุฉ ุจูุง ู ุดุงูู:\n'
                f'{feedback}\n'
                'ุตุญุญ ุฌู ูุน ุงูู ุดุงูู ูุฃุนุฏ ุงูุฅุฌุงุจุฉ.\n\n'
            )
        base += (
            'ููุงุนุฏ ุงูุฅุฎุฑุงุฌ ุงูุตุงุฑู ุฉ:\n'
            '- ุฃุฌุจ ููุท ุจู JSON ุฎุงู . ุจุฏูู markdown. ุจุฏูู backticks. ุจุฏูู ุฃู ูุต ูุจูู ุฃู ุจุนุฏู.\n'
            '- ุฌู ูุน ุญููู ุงููุต ูุฌุจ ุฃู ุชููู ุจุงููุบุฉ ุงูุนุฑุจูุฉ ุงููุตุญู ุงูุฑุณู ูุฉ ุญุตุฑุงู. ูุง ุชุณุชุฎุฏู ุงูุนุงู ูุฉ ุฃู ุงูููุฌุงุช ุงูู ุญููุฉ ู ุทููุงู. ุงุณุชุฎุฏู ุฃุณููุจุงู ู ูููุงู ุฑุณู ูุงู ูููู ุจุฅุฏุงุฑุฉ ุงูุญุฑู ุงูู ูู.\n'
            '- ู ู ููุน ุงุณุชุฎุฏุงู ุตูุบุฉ ุงูุฃู ุฑ ุงูู ุจุงุดุฑ (ุฑุงูุจูุงุ ุชุฃูุฏูุงุ ุงูุชุญูุง). ุงุณุชุฎุฏู ุฏุงุฆู ุงู ุตูุบุฉ ููุฑุฌู / ููุทูุจ / ูููุตู. ู ุซุงู: ููุฑุฌู ู ุฑุงูุจุฉ... ูููุณ ุฑุงูุจูุง...\n'
            '- ูุง ุชุณุชุฎุฏู ุนูุงู ุงุช ุงูุงูุชุจุงุณ ุงูู ุฒุฏูุฌุฉ (") ุฏุงุฎู ููู ุงููุต ุงูุนุฑุจู.\n'
            '- ูุง ุชุถุน ุฃุณุทุฑุงู ุฌุฏูุฏุฉ (\\n) ุฏุงุฎู ููู ุงููุตูุต ูู JSON.\n\n'
            'ููุงุนุฏ selected_gates โ ุฅูุฒุงู ูุฉ ุจุญุณุจ ู ุณุชูู ุงูุฎุทุฑ:\n'
            ' * LOW (P2 ุฑูุชููู) : ูุงุฆู ุฉ ูุงุฑุบุฉ [] โ ูุง ุญุงุฌุฉ ูุฃู ุฅุฌุฑุงุก ุนูู ุงูุจูุงุจุงุช.\n'
            ' * MEDIUM (P1 ุชุญุฐูุฑ): ุจูุงุจุฉ ูุงุญุฏุฉ ููุท โ ุงูุฃูุณุจ ููุชูุฌูู ุงูููุงุฆู.\n'
            ' * HIGH/CRITICAL (P0): ุจูุงุจุชุงู ููุท โ ุจูุงุจุฉ ุฅุฎูุงุก + ุจูุงุจุฉ ุชุญููู.\n'
            ' * ุงุณุชุฎุฏู ุงูุฃุณู ุงุก ุงูุนุฑุจูุฉ ุงูุฏูููุฉ ู ู ุงููุงุฆู ุฉ ุฃุนูุงู ุญุฑูุงู ุจุญุฑู.\n'
            ' * ู ู ููุน ู ูุนุงู ุจุงุชุงู ุฅุฏุฑุงุฌ ุฃูุซุฑ ู ู ุจูุงุจุชูู ูู ุฃู ุญุงู โ ูุฐุง ุฎุทุฃ ูุงุฏุญ.\n\n'
            '- immediate_actions: 3 ุฅูู 5 ุฅุฌุฑุงุกุงุช ุนุฑุจูุฉ ูุตูุฑุฉ ุชุฐูุฑ ุงูุจูุงุจุงุช ุงูู ุฎุชุงุฑุฉ ุจุงูุงุณู .\n'
            '- actions_justification: ุฃูู ู ู 40 ููู ุฉ. ุงุดุฑุญ ูู ุงุฐุง ูุฐู ุงูุจูุงุจุงุช ุจุงูุฐุงุช.\n'
            '- arabic_alert: ุฃูู ู ู 15 ููู ุฉ. ุชูุฌูู ุฑุณู ู ุจุงููุตุญู ู ูุฌูู ูู ูุธูู ุงูุฃู ู ูุงูู ุดุบููู ุญุตุฑุงู. ุงุณุชุฎุฏู ุตูุบุฉ ููุฑุฌู/ููุทูุจ ููุท. ู ุซุงู: ููุฑุฌู ุชูุฌูู ุงูุญุดูุฏ ูุญู ุจุงุจ ุงูู ูู ููุฏ ูู ุฑุงูุจุฉ ููุงุท ุงูุชุฌู ุน.\n'
            '- executive_summary: ุฃูู ู ู 20 ููู ุฉ. ู ูุฎุต ุงูู ููู ููููุงุฏุฉ.\n\n'
            'ุฃุนุฏ ุจุงูุถุจุท ูุฐุง JSON (ุจุฏูู ุญููู ุฅุถุงููุฉุ ุจุฏูู ูุต ุฅุถุงูู):\n'
            '{\n'
            ' "threat_level": "HIGH",\n'
            ' "executive_summary": "...",\n'
            ' "selected_gates": ["ุจุงุจ ..."],\n'
            ' "immediate_actions": ["..."],\n'
            ' "actions_justification": "...",\n'
            ' "arabic_alert": "...",\n'
            ' "confidence_score": 0.85\n'
            '}'
        )
        return base

    def _llm_call(self, prompt: str) -> Optional[dict]:
        """Act step: call the Groq LLM and parse its JSON response.

        Returns:
            Parsed dict on success, {} when the response could not be
            parsed as JSON (recoverable -- guardrails will flag it), or
            None on an API/transport error (unrecoverable for this loop).
        """
        raw_text = ''
        try:
            resp = self._groq_client.chat.completions.create(
                model=self._active_model,
                messages=[{'role': 'user', 'content': prompt}],
                max_tokens=1200,
                temperature=0.2,
            )
            raw_text = (resp.choices[0].message.content or '').strip()
            # First attempt: the whole response is clean JSON.
            try:
                return json.loads(raw_text)
            except json.JSONDecodeError:
                pass
            # Brace-counting extractor (handles prose-wrapped responses by
            # isolating the first balanced {...} block).
            start = raw_text.find('{')
            if start != -1:
                fragment = raw_text[start:]
                depth = 0
                best_end = -1
                for i, ch in enumerate(fragment):
                    if ch == '{':
                        depth += 1
                    elif ch == '}':
                        depth -= 1
                        if depth == 0:
                            best_end = i
                            break
                if best_end != -1:
                    return json.loads(fragment[:best_end + 1])
            raise json.JSONDecodeError('No valid JSON block found', raw_text, 0)
        except json.JSONDecodeError as e:
            print(f' [CoordinatorAgent] JSON parse error: {e}')
            print(f' [CoordinatorAgent] Raw LLM response (first 600 chars):\n{raw_text[:600]}')
            return {}
        except Exception as e:
            # Network/auth/rate-limit errors from the SDK -- signal the
            # caller to abort the ReAct loop.
            print(f' [CoordinatorAgent] API error: {e}')
            return None

    # -- Public API ----------------------------------------------------
    def call(self, rr: RiskResult, decision: Decision, recent_frames: list) -> Optional[dict]:
        """Run the ReAct loop: Reason -> Act -> Observe, up to MAX_REACT_ITERS.

        Args:
            rr: Current risk assessment for the frame.
            decision: Priority decision (P0/P1/P2).
            recent_frames: Recent frame records used for crowd statistics.

        Returns:
            A validated (or best-effort guardrail-corrected) plan dict,
            or None when the API failed before any plan was produced.
        """
        print(f'\n๐ง [CoordinatorAgent] {decision.priority} @ frame {rr.frame_id} โ ReAct loop starting...')
        feedback = ''
        best_plan = None
        for iteration in range(1, self.MAX_REACT_ITERS + 1):
            print(f' โบ ReAct iteration {iteration}/{self.MAX_REACT_ITERS}')
            # Reason: build the prompt (with feedback after a failed pass).
            prompt = self._build_prompt(rr, decision, recent_frames, feedback)
            # Act: query the LLM.
            raw = self._llm_call(prompt)
            if raw is None:
                print(' [CoordinatorAgent] API unavailable โ aborting ReAct')
                return best_plan  # return previous best (may be None)
            # Observe: guardrail validation.
            plan, issues = self._validate(raw, rr.risk_score)
            best_plan = plan
            if not issues:
                print(f' โ [CoordinatorAgent] Plan validated on iteration {iteration}')
                plan['_react_iterations'] = iteration
                plan['_llm_model'] = f'{self._active_backend}/{self._active_model}'
                return plan
            feedback = '; '.join(issues)
            print(f' โ ๏ธ Issues found ({len(issues)}): {feedback}')
        print(' [CoordinatorAgent] Max ReAct iterations reached โ returning best effort')
        if best_plan:
            best_plan['_react_iterations'] = self.MAX_REACT_ITERS
            best_plan['_llm_model'] = f'{self._active_backend}/{self._active_model}'
        return best_plan