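# Gradio Space: one locally loaded Qwen2.5-3B-Instruct model backs two Autogen
# ConversableAgents, an answering agent and an audit agent that labels both the
# user prompt and the model's reply as "safe" or "block". The model is loaded at
# query time and freed again after each request.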
import gradio as gr
import torch, gc, time
from autogen import ConversableAgent
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
# --------------------
# Config
# --------------------
MODEL_NAME = "Qwen/Qwen2.5-3B-Instruct"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
USE_4BIT = False
RELOAD_PER_QUERY = True
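# Note: RELOAD_PER_QUERY is declared but never read below; the model is in fact
# loaded inside process_query on every request and freed afterwards regardless.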
quantization_config = None
if USE_4BIT:
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=False,
        llm_int8_threshold=6.0,
    )
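# 4-bit loading requires the bitsandbytes package and, in practice, a CUDA GPU;
# with the default USE_4BIT = False above it stays disabled.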
# --------------------
# Load / unload utils
# --------------------
def load_model_and_tokenizer():
    tok = AutoTokenizer.from_pretrained(MODEL_NAME)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token if tok.eos_token else "[PAD]"
    mdl = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        quantization_config=quantization_config,
    )
    return mdl, tok
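# Note: device_map="auto" relies on the accelerate package to place the weights;
# it must be installed alongside transformers for the load above to work.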
def free_gpu(*objs):
    # `del o` only drops the local reference inside this function; the memory is
    # actually reclaimed once the caller's own references go out of scope and
    # gc.collect() / empty_cache() run.
    for o in objs:
        try:
            del o
        except NameError:
            pass
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
# --------------------
# Local LLM agent
# --------------------
class LocalLLMAgent(ConversableAgent):
    def __init__(self, name, model, tokenizer, **kwargs):
        super().__init__(name, **kwargs)
        self.model = model
        self.tokenizer = tokenizer

    def generate_reply(self, messages, **kwargs):
        buf = []
        if self.system_message:
            buf.append(f"SYSTEM: {self.system_message}")
        for m in messages:
            role = m.get("role", "user").upper()
            buf.append(f"{role}: {m['content']}")
        buf.append("ASSISTANT:")
        prompt = "\n".join(buf)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            out = self.model.generate(
                **inputs,
                max_new_tokens=256,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=self.tokenizer.eos_token_id,
                do_sample=True,
            )
        new_tokens = out[0][inputs["input_ids"].shape[1]:]
        text = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        # Cut the completion at the first role marker the model continues with.
        for cp in ["\nUSER:", "\nSYSTEM:", "\nASSISTANT:"]:
            if cp in text:
                text = text.split(cp)[0].strip()
        return text
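# A minimal alternative sketch (assuming the checkpoint ships a chat template,
# as the Qwen2.5-Instruct models do): let the tokenizer format the prompt
# instead of the manual "ROLE:" concatenation above:
#
#     chat = [{"role": "system", "content": self.system_message}, *messages]
#     prompt = self.tokenizer.apply_chat_template(
#         chat, tokenize=False, add_generation_prompt=True
#     )
#
# This keeps the prompt in the format the instruct model was fine-tuned on.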
# --------------------
# Audit agent
# --------------------
class LocalAuditAgent(ConversableAgent):
    def __init__(self, name, model, tokenizer, **kwargs):
        super().__init__(name, **kwargs)
        self.model = model
        self.tokenizer = tokenizer

    def generate_reply(self, messages, **kwargs):
        last = messages[-1]["content"]
        audit_prompt = (
            f"{self.system_message}\n\n"
            f"Text: {last}\n\n"
            "Answer with ONE WORD ONLY: safe OR block"
        )
        inputs = self.tokenizer(audit_prompt, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            out = self.model.generate(
                **inputs,
                max_new_tokens=3,
                do_sample=False,  # greedy decoding, so no temperature is needed
                pad_token_id=self.tokenizer.eos_token_id,
            )
        new_tokens = out[0][inputs["input_ids"].shape[1]:]
        decision = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip().lower()
        if "block" in decision:
            return "block"
        if "safe" in decision:
            return "safe"
        # Fail closed: anything the classifier does not clearly label is blocked.
        return "block"
# --------------------
# Pipeline
# --------------------
def make_agents():
    model, tok = load_model_and_tokenizer()
    llm_agent = LocalLLMAgent(
        "LLM_Agent", model, tok,
        system_message="You generate helpful, safe, compliant answers."
    )
    audit_agent = LocalAuditAgent(
        "Audit_Agent", model, tok,
        system_message="Classify ONLY as 'safe' or 'block'. If malware, hacking, or illegal activity => block. Else => safe."
    )
    return model, tok, llm_agent, audit_agent
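# process_query wires the agents together: audit the user prompt first, answer
# only if it passes, then audit the model's reply before returning it to the UI.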
def process_query(user_prompt):
    model, tok, llm_agent, audit_agent = make_agents()
    user_decision = audit_agent.generate_reply([{"role": "user", "content": user_prompt}])
    if user_decision == "block":
        free_gpu(model, tok, llm_agent, audit_agent)
        return user_prompt, "This content is blocked by the audit agent", "🔴🚫 block"
    llm_reply = llm_agent.generate_reply([{"role": "user", "content": user_prompt}])
    reply_decision = audit_agent.generate_reply([{"role": "user", "content": llm_reply}])
    icon = "🟢✅ safe" if reply_decision == "safe" else "🔴🚫 block"
    free_gpu(model, tok, llm_agent, audit_agent)
    return user_prompt, llm_reply, icon
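# Because make_agents() runs inside process_query, the full model is loaded from
# disk for every request and freed afterwards: frugal with GPU memory between
# requests, at the cost of noticeable per-query latency.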
# --------------------
# Gradio App
# --------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🚀 Autogen LLM + Audit Agent Playground")
    with gr.Row():
        with gr.Column():
            inp = gr.Textbox(label="Enter Prompt")
            btn = gr.Button("Run 🚀")
            status = gr.Label(value="⏳ Waiting...")
        with gr.Column():
            out_prompt = gr.Textbox(label="Prompt")
            out_llm = gr.Textbox(label="LLM Output")
            out_audit = gr.Textbox(label="Audit Decision")

    def wrapped(prompt):
        # Generator handler: the first yield shows an interim status while the
        # model loads and runs, the second delivers the final results.
        status_val = "⚡ Loading & inferring..."
        yield None, None, None, status_val
        time.sleep(0.3)
        p, r, a = process_query(prompt)
        yield p, r, a, "✅ Done"

    btn.click(
        wrapped,
        inputs=inp,
        outputs=[out_prompt, out_llm, out_audit, status],
    )

if __name__ == "__main__":
    demo.launch(debug=True)
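# Typical dependencies for this Space (assumption; the source does not pin them):
# gradio, torch, transformers, accelerate, pyautogen (plus bitsandbytes when
# USE_4BIT is enabled).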