# safety_agent / app.py
import gradio as gr
import torch, gc, time
from autogen import ConversableAgent
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
# --------------------
# Config
# --------------------
MODEL_NAME = "Qwen/Qwen2.5-3B-Instruct"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
USE_4BIT = False
RELOAD_PER_QUERY = True  # note: currently unused; the pipeline below always reloads per query
quantization_config = None
if USE_4BIT:
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=False,
        llm_int8_threshold=6.0,
    )
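# Note (assumption): 4-bit loading via BitsAndBytesConfig relies on the
# `bitsandbytes` package and generally requires a CUDA GPU; on CPU-only
# hosts, leave USE_4BIT = False so the model loads without quantization.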
# --------------------
# Load / unload utils
# --------------------
def load_model_and_tokenizer():
    tok = AutoTokenizer.from_pretrained(MODEL_NAME)
    if tok.pad_token is None:
        # Fall back to the EOS token for padding; some models ship without a pad token.
        tok.pad_token = tok.eos_token if tok.eos_token else "[PAD]"
    mdl = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",  # requires `accelerate`; places weights on GPU when available
        quantization_config=quantization_config,
    )
    return mdl, tok
def free_gpu(*objs):
    # Note: `del o` only drops the local binding inside this function; memory is
    # actually reclaimed once the caller's references are gone, via gc + empty_cache.
    for o in objs:
        try:
            del o
        except NameError:
            pass
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
# --------------------
# Local LLM agent
# --------------------
class LocalLLMAgent(ConversableAgent):
    def __init__(self, name, model, tokenizer, **kwargs):
        super().__init__(name, **kwargs)
        self.model = model
        self.tokenizer = tokenizer

    def generate_reply(self, messages, **kwargs):
        # Build a flat, role-tagged prompt from the message history.
        buf = []
        if self.system_message:
            buf.append(f"SYSTEM: {self.system_message}")
        for m in messages:
            role = m.get("role", "user").upper()
            buf.append(f"{role}: {m['content']}")
        buf.append("ASSISTANT:")
        prompt = "\n".join(buf)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            out = self.model.generate(
                **inputs,
                max_new_tokens=256,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=self.tokenizer.eos_token_id,
                do_sample=True,
            )
        # Decode only the newly generated tokens, then trim any hallucinated
        # follow-on turns the model may have appended.
        new_tokens = out[0][inputs["input_ids"].shape[1]:]
        text = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        for cp in ["\nUSER:", "\nSYSTEM:", "\nASSISTANT:"]:
            if cp in text:
                text = text.split(cp)[0].strip()
        return text
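# Example (sketch): for a single user turn, the flat prompt built above looks like:
#   SYSTEM: You generate helpful, safe, compliant answers.
#   USER: <user prompt>
#   ASSISTANT:
# The stop-phrase trimming then drops anything the model generates past its own turn.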
# --------------------
# Audit agent
# --------------------
class LocalAuditAgent(ConversableAgent):
    def __init__(self, name, model, tokenizer, **kwargs):
        super().__init__(name, **kwargs)
        self.model = model
        self.tokenizer = tokenizer

    def generate_reply(self, messages, **kwargs):
        last = messages[-1]["content"]
        audit_prompt = (
            f"{self.system_message}\n\n"
            f"Text: {last}\n\n"
            "Answer with ONE WORD ONLY: safe OR block"
        )
        inputs = self.tokenizer(audit_prompt, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            out = self.model.generate(
                **inputs,
                max_new_tokens=3,
                do_sample=False,  # greedy decoding; temperature is unused when sampling is off
                pad_token_id=self.tokenizer.eos_token_id,
            )
        new_tokens = out[0][inputs["input_ids"].shape[1]:]
        decision = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip().lower()
        if "block" in decision:
            return "block"
        if "safe" in decision:
            return "safe"
        # Fail closed: any unrecognized verdict is treated as blocked.
        return "block"
# --------------------
# Pipeline
# --------------------
def make_agents():
    model, tok = load_model_and_tokenizer()
    llm_agent = LocalLLMAgent(
        "LLM_Agent", model, tok,
        system_message="You generate helpful, safe, compliant answers."
    )
    audit_agent = LocalAuditAgent(
        "Audit_Agent", model, tok,
        system_message="Classify ONLY as 'safe' or 'block'. If malware, hacking, or illegal activity => block. Else => safe."
    )
    return model, tok, llm_agent, audit_agent
def process_query(user_prompt):
    model, tok, llm_agent, audit_agent = make_agents()
    # Audit the user prompt first; refuse before generating a reply at all.
    user_decision = audit_agent.generate_reply([{"role": "user", "content": user_prompt}])
    if user_decision == "block":
        free_gpu(model, tok, llm_agent, audit_agent)
        return user_prompt, "This content is blocked by the audit agent.", "🔴🚫 block"
    llm_reply = llm_agent.generate_reply([{"role": "user", "content": user_prompt}])
    # Audit the model's reply as well before surfacing it.
    reply_decision = audit_agent.generate_reply([{"role": "user", "content": llm_reply}])
    icon = "🟢✅ safe" if reply_decision == "safe" else "🔴🚫 block"
    free_gpu(model, tok, llm_agent, audit_agent)
    return user_prompt, llm_reply, icon
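# Example (sketch): exercising the pipeline directly, without the UI.
# The prompt and verdict below are illustrative only:
#   prompt, reply, verdict = process_query("Explain list comprehensions in Python.")
#   # verdict is "🟢✅ safe" or "🔴🚫 block"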
# --------------------
# Gradio App
# --------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🔍 Autogen LLM + Audit Agent Playground")
    with gr.Row():
        with gr.Column():
            inp = gr.Textbox(label="Enter Prompt")
            btn = gr.Button("Run 🚀")
            status = gr.Label(value="⏳ Waiting...")
        with gr.Column():
            out_prompt = gr.Textbox(label="Prompt")
            out_llm = gr.Textbox(label="LLM Output")
            out_audit = gr.Textbox(label="Audit Decision")

    def wrapped(prompt):
        # Generator handler: stream a status update first, then the real results.
        status_val = "⚡ Loading & inferring..."
        yield None, None, None, status_val
        time.sleep(0.3)
        p, r, a = process_query(prompt)
        yield p, r, a, "✅ Done"

    btn.click(
        wrapped,
        inputs=inp,
        outputs=[out_prompt, out_llm, out_audit, status]
    )
if __name__ == "__main__":
    demo.launch(debug=True)
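# Note (assumption): on older Gradio releases, generator handlers like `wrapped`
# only stream intermediate yields when the queue is enabled, e.g.
# `demo.queue().launch(debug=True)`; recent versions enable queuing by default.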