# risk-atlas-nexus / executor.py
# Hugging Face Space file (commit f455c3d, "Update to the graph view", by ingelise).
import os
import pandas as pd
import gradio as gr
import datetime
from pathlib import Path
import json
from risk_atlas_nexus.blocks.inference import WMLInferenceEngine
from risk_atlas_nexus.blocks.inference.params import WMLInferenceEngineParams
from risk_atlas_nexus.library import RiskAtlasNexus
from functools import lru_cache, wraps
from dotenv import load_dotenv
load_dotenv(override=True)
# Load the taxonomies
ran = RiskAtlasNexus() # type: ignore
def clear_previous_risks():
    """Return fresh (empty/hidden) Gradio components that reset the risk view.

    Emits, in order: the risks heading, an empty risk-state list, a hidden
    risk Dataset, a hidden DownloadButton, an empty description string, a
    hidden related-risks Dataset, two hidden DataFrames (mitigations and
    AI evaluations) and two blank Markdown placeholders (status, subgraph).
    """
    heading = gr.Markdown("""<h2> Potential Risks </h2> """)
    risk_dataset = gr.Dataset(
        samples=[],
        sample_labels=[],
        samples_per_page=50,
        visible=False,
    )
    download_btn = gr.DownloadButton("Download JSON", visible=False, )
    related_dataset = gr.Dataset(samples=[], sample_labels=[], visible=False)
    mitigation_table = gr.DataFrame([], wrap=True, show_copy_button=True, show_search="search", visible=False)
    evaluation_table = gr.DataFrame([], wrap=True, show_copy_button=True, show_search="search", visible=False)
    return (heading, [], risk_dataset, download_btn, "",
            related_dataset, mitigation_table, evaluation_table,
            gr.Markdown(" "), gr.Markdown(" "))
def clear_previous_mitigations():
    """Return fresh (empty/hidden) Gradio components that reset the mitigation view.

    Emits, in order: an empty description string, a hidden related-risks
    Dataset, two hidden DataFrames (mitigations and AI evaluations) and two
    blank Markdown placeholders (status, subgraph).
    """
    hidden_mitigations = gr.DataFrame([], wrap=True, show_copy_button=True, show_search="search", visible=False)
    hidden_evaluations = gr.DataFrame([], wrap=True, show_copy_button=True, show_search="search", visible=False)
    return ("",
            gr.Dataset(samples=[], sample_labels=[], visible=False),
            hidden_mitigations,
            hidden_evaluations,
            gr.Markdown(" "),
            gr.Markdown(" "))
def generate_subgraph(usecase, risk):
    """Build a Mermaid flowchart (as Markdown) linking a use case to a risk.

    The diagram shows the use case node, the selected risk, its taxonomy and
    parent, plus (when present) its related risks, AI evaluations and
    mitigations as aggregated label nodes.

    Args:
        usecase: free-text use-case description, used as a node label.
        risk: a Risk Atlas Nexus risk object; this reads .id, .name,
            .isPartOf and .isDefinedByTaxonomy.

    Returns:
        gr.Markdown whose value is a fenced ```mermaid code block.
    """
    lines = ['```mermaid\n',
             '---\n'
             'config:\n'
             ' theme: mc\n'
             ' layout: dagre\n'
             ' look: classic\n'
             '---\n'
             'flowchart TB\n']
    # NOTE(review): the node id uc_173 is hard-coded; labels are interpolated
    # unescaped, so names containing '"' would break the Mermaid syntax.
    lines.append(f'uc_173@{{ label: "{usecase}" }} -- subClassOf --> AISystem["AISystem"]\n')
    lines.append(f'uc_173 -- hasRisk --> Risk2["{risk.name}"]\n')
    lines.append(f'Risk2 -- isPartOf --> {risk.isPartOf}\n')
    lines.append(f'Risk2 -- isDefinedByTaxonomy --> {risk.isDefinedByTaxonomy}\n')
    # Related risks: a single node listing all related risk names.
    related = ran.get_related_risks(id=risk.id)
    if related:
        related_names = ', '.join(rr.name for rr in related)
        lines.append(f'Risk2 -- hasRelatedRisks --> Risk3["{related_names}"]\n')
    # Related evaluations; label truncated so the diagram stays readable.
    evaluations = ran.get_related_evaluations(risk_id=risk.id)
    if evaluations:
        eval_names = ', '.join(e.name for e in evaluations)
        lines.append(f'Risk2 -- hasAiEvaluations --> Risk4["{eval_names[:100]}"] \n')
    # Mitigations (actions + controls + intrinsics), also truncated.
    mitigation_names = get_controls_and_actions(risk.id, risk.isDefinedByTaxonomy)
    if mitigation_names:
        mits = ', '.join(mitigation_names)
        lines.append(f'Risk2 -- hasMitigations --> Risk5["{mits[:100]}"] \n')
    lines.append("```")
    return gr.Markdown(value="".join(lines))
def custom_lru_cache(maxsize=128, exclude_values=(None, [], [[]])):
    """
    LRU cache decorator that never stores "empty" results.

    The decorated function must return a tuple whose element [2] is a Gradio
    component exposing ``constructor_args["samples"]`` (see risk_identifier).
    A result whose samples compare equal to any entry of ``exclude_values``
    is returned to the caller but NOT cached, so the next call recomputes it.

    This replaces the earlier ``functools.lru_cache`` wrapper, which invoked
    the function TWICE whenever a fresh result was empty (once inside the
    cache, once in the retry) and — because the empty result stayed cached —
    never cached a later non-empty result for the same arguments.

    Args:
        maxsize: maximum number of cached entries; the least recently used
            entry is evicted when exceeded.
        exclude_values: sample values treated as "empty" (compared with ==,
            so unhashable values such as [] are fine here).
    """
    def decorator(func):
        cache = {}  # insertion-ordered (py3.7+): first key == least recently used

        @wraps(func)
        def wrapper(*args, **kwargs):
            # Positional args must be hashable, same constraint as lru_cache.
            key = (args, tuple(sorted(kwargs.items())))
            if key in cache:
                cache[key] = cache.pop(key)  # re-insert to refresh recency
                return cache[key]
            result = func(*args, **kwargs)
            # Only cache results whose sample list is non-empty.
            if result[2].constructor_args["samples"] not in exclude_values:
                cache[key] = result
                if len(cache) > maxsize:
                    cache.pop(next(iter(cache)))  # evict least recently used
            return result
        return wrapper
    return decorator
@custom_lru_cache(exclude_values=(None, []))
def risk_identifier(usecase: str,
                    model_name_or_path: str = "meta-llama/llama-3-3-70b-instruct",
                    taxonomy: str = "ibm-risk-atlas"):
    """
    Ask an LLM (via watsonx.ai) which risks from `taxonomy` apply to `usecase`.

    Side effect: writes a JSON report to static/download.json for the
    DownloadButton to serve.

    Returns a 4-tuple of Gradio components: the risks heading Markdown, a
    gr.State holding the identified risk objects, a Dataset of risk ids
    (labelled by risk name), and the DownloadButton.

    Raises:
        KeyError: if WML_API_KEY / WML_API_URL / WML_PROJECT_ID are missing
            from the environment.
    """
    downloadable = False
    inference_engine = WMLInferenceEngine(
        model_name_or_path=model_name_or_path,
        credentials={
            "api_key": os.environ["WML_API_KEY"],
            "api_url": os.environ["WML_API_URL"],
            "project_id": os.environ["WML_PROJECT_ID"],
        },
        parameters=WMLInferenceEngineParams(
            max_new_tokens=1000, decoding_method="greedy", repetition_penalty=1
        ),  # type: ignore
    )
    risks_a = ran.identify_risks_from_usecases(  # type: ignore
        usecases=[usecase],
        inference_engine=inference_engine,
        taxonomy=taxonomy,
        zero_shot_only=True,
        max_risk=5,
    )
    risks = risks_a[0]
    # Fall back to the risk id when a risk has no human-readable name.
    # (Was `r.name if r else r.id`, which dereferenced a falsy r and crashed.)
    sample_labels = [r.name if r.name else r.id for r in risks]
    out_sec = gr.Markdown("""<h2> Potential Risks </h2> """)
    # Write out a JSON report of the run for download.
    data = {
        'time': str(datetime.datetime.now(datetime.timezone.utc)),
        'intent': usecase,
        'model': model_name_or_path,
        'taxonomy': taxonomy,
        'risks': [json.loads(r.json()) for r in risks],
    }
    file_path = Path("static/download.json")
    file_path.parent.mkdir(parents=True, exist_ok=True)  # tolerate a missing static/ dir
    file_path.write_text(json.dumps(data, indent=4))
    downloadable = True
    return (out_sec,
            gr.State(risks),
            gr.Dataset(samples=[r.id for r in risks],
                       sample_labels=sample_labels,
                       samples_per_page=50, visible=True,
                       label="Estimated by an LLM."),
            gr.DownloadButton("Download JSON", str(file_path),
                              visible=(downloadable and len(risks) > 0)))
def get_controls_and_actions(riskid, taxonomy):
    """
    Collect the names of all mitigations (actions, risk controls and
    intrinsics) associated with `riskid`.

    For the "ibm-risk-atlas" taxonomy, mitigations are attached to the risks
    RELATED to `riskid` (cross-taxonomy mapping); for any other taxonomy they
    are attached to the risk itself.

    Returns:
        A flat list of mitigation names (possibly empty).
    """
    action_ids = []
    control_ids = []
    intrinsic_ids = []
    if taxonomy == "ibm-risk-atlas":
        # Gather mitigations from every related risk; guard against None
        # returns from the lookup helpers.
        for related_id in (r.id for r in ran.get_related_risks(id=riskid)):
            action_ids += ran.get_related_actions(id=related_id) or []
            control_ids += ran.get_related_risk_controls(id=related_id) or []
            intrinsic_ids += ran.get_related_intrinsics(risk_id=related_id) or []
    else:
        # Use only mitigations attached to the primary risk.
        action_ids = ran.get_related_actions(id=riskid)
        control_ids = ran.get_related_risk_controls(id=riskid)
        intrinsic_ids = ran.get_related_intrinsics(risk_id=riskid)
    # NOTE: actions are referenced by id, while controls/intrinsics come back
    # as objects and are re-fetched by their .id — presumably an API quirk.
    return ([ran.get_action_by_id(i).name for i in action_ids]
            + [ran.get_risk_control(i.id).name for i in control_ids]
            + [ran.get_intrinsic(i.id).name for i in intrinsic_ids])  # type: ignore
@lru_cache(maxsize=128)  # bounded: the key space (user text) is unbounded
def mitigations(usecase: str, riskid: str, taxonomy: str) -> tuple[gr.Markdown, gr.Dataset, gr.DataFrame, gr.DataFrame, gr.Markdown, gr.Markdown]:
    """
    For a specific risk (riskid), returns
    (a) a risk description
    (b) related risks - as a dataset
    (c) mitigations
    (d) related AI evaluations
    (e) A subgraph of risk to mitigations
    """
    # Keep selected_risk bound even when the lookup fails: it is read again
    # at the end of the function (previously this raised NameError if
    # ran.get_risk itself raised AttributeError).
    selected_risk = None
    try:
        selected_risk = ran.get_risk(id=riskid)
        risk_desc = selected_risk.description  # type: ignore
        risk_sec = f"<h3>Description: </h3> {risk_desc}"
    except AttributeError:
        risk_sec = ""
    related_risks = ran.get_related_risks(id=riskid)  # fetched once, reused below
    related_risk_ids = [r.id for r in related_risks]
    related_ai_eval_ids = [ai_eval.id for ai_eval in ran.get_related_evaluations(risk_id=riskid)]
    action_ids = []
    control_ids = []
    intrinsic_ids = []
    if taxonomy == "ibm-risk-atlas":
        # Mitigations are attached to the risks related to this one.
        for i in related_risk_ids:
            action_ids += ran.get_related_actions(id=i) or []
            control_ids += ran.get_related_risk_controls(id=i) or []
            intrinsic_ids += ran.get_related_intrinsics(risk_id=i) or []
    else:
        # Use only actions related to primary risks.
        action_ids = ran.get_related_actions(id=riskid)
        control_ids = ran.get_related_risk_controls(id=riskid)
        intrinsic_ids = ran.get_related_intrinsics(risk_id=riskid)
    # Sanitize outputs for the related-risks Dataset.
    if not related_risk_ids:
        label = "No related risks found."
        samples = None
        sample_labels = None
    else:
        label = f"Risks from other taxonomies related to {riskid}"
        samples = related_risk_ids
        sample_labels = [i.name for i in related_risks]  # type: ignore
    # Mitigations table: actions are referenced by id, controls/intrinsics
    # by object (.id). The Control row keeps .name in the Description column,
    # matching the original behaviour (controls may lack a description).
    if not action_ids and not control_ids and not intrinsic_ids:
        alabel = "No mitigations found."
        mitdf = pd.DataFrame()
    else:
        alabel = f"Mitigation actions and controls related to risk {riskid}."
        descriptions = ([ran.get_action_by_id(i).description for i in action_ids]
                        + [ran.get_risk_control(i.id).name for i in control_ids]
                        + [ran.get_intrinsic(i.id).description for i in intrinsic_ids])  # type: ignore
        names = ([ran.get_action_by_id(i).name for i in action_ids]
                 + [ran.get_risk_control(i.id).name for i in control_ids]
                 + [ran.get_intrinsic(i.id).name for i in intrinsic_ids])  # type: ignore
        types = (["Action"] * len(action_ids)
                 + ["Control"] * len(control_ids)
                 + ["Intrinsic"] * len(intrinsic_ids))
        mitdf = pd.DataFrame({"Type": types, "Mitigation": names, "Description": descriptions})
    # AI evaluations table.
    if not related_ai_eval_ids:
        blabel = "No related AI evaluations found."
        aievalsdf = pd.DataFrame()
    else:
        blabel = f"AI Evaluations related to {riskid}"
        evals = [ran.get_evaluation(i) for i in related_ai_eval_ids]  # type: ignore
        aievalsdf = pd.DataFrame({"AI Evaluation": [e.name for e in evals],
                                  "Description": [e.description for e in evals]})
    status = gr.Markdown(" ") if len(mitdf) > 0 else gr.Markdown("No mitigations found.")
    fig = gr.Markdown(" ") if not selected_risk else generate_subgraph(usecase, selected_risk)
    return (gr.Markdown(risk_sec),
            gr.Dataset(samples=samples, label=label, sample_labels=sample_labels, visible=True),
            gr.DataFrame(mitdf, wrap=True, show_copy_button=True, show_search="search", label=alabel, visible=True),
            gr.DataFrame(aievalsdf, wrap=True, show_copy_button=True, show_search="search", label=blabel, visible=True),
            status, fig)