Spaces:
Runtime error
| import os | |
| from transformers import AutoProcessor, Llama4ForConditionalGeneration | |
| from PIL import Image | |
| import requests | |
| import torch | |
| import gradio as gr | |
| import re | |
| import spaces | |
# Llama Guard 4 checkpoint, used for moderating both text and image inputs.
model_id = "meta-llama/Llama-Guard-4-12B"

# The processor provides chat templating plus image preprocessing for Llama 4.
processor = AutoProcessor.from_pretrained(model_id)

# Load the guard model in bfloat16 on the first CUDA device.
model = Llama4ForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="cuda:0"
)
# Human-readable names for the Llama Guard hazard category codes (S1-S13)
# emitted by the model when it flags a conversation as unsafe.
labels_dict = dict(
    S1="Violent Crimes",
    S2="Non-Violent Crimes",
    S3="Sex Crimes",
    S4="Child Exploitation",
    S5="Defamation",
    S6="Specialized Advice",
    S7="Privacy",
    S8="Intellectual Property",
    S9="Indiscriminate Weapons",
    S10="Hate",
    S11="Self-Harm",
    S12="Sexual Content",
    S13="Elections",
)
def infer(image, text_input, model_output, exclude_categories):
    """Run Llama Guard 4 moderation over a (possibly multimodal) conversation.

    Args:
        image: Optional image file path (Gradio ``type="filepath"``); attached
            to the user turn when provided.
        text_input: Required user message text.
        model_output: Optional hypothetical assistant reply, appended as an
            assistant turn so the full exchange is moderated.
        exclude_categories: Hazard codes (e.g. ``["S1"]``) forwarded to the chat
            template via ``exclude_category_keys`` so they are not flagged.

    Returns:
        A 2-tuple ``(messages, verdict)`` matching the two bound Gradio output
        components: the constructed chat messages and a human-readable verdict.
    """
    if not text_input:
        # Must return TWO values here: the click handler binds two output
        # components, so a bare single-string return would break the callback.
        return "", "Please provide at least text input."

    # Build the user turn; attach the image alongside the text when given.
    content = [{"type": "text", "text": text_input}]
    if image is not None:
        content.append({"type": "image", "url": image})
    messages = [{"role": "user", "content": content}]

    # Optionally moderate a hypothetical assistant response as well.
    if model_output:
        messages.append(
            {
                "role": "assistant",
                "content": [{"type": "text", "text": model_output}],
            }
        )

    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        exclude_category_keys=exclude_categories,
    ).to(model.device)

    # Greedy decoding: moderation verdicts should be deterministic.
    outputs = model.generate(
        **inputs,
        max_new_tokens=100,
        do_sample=False,
    )
    # Decode only the newly generated tokens, not the echoed prompt.
    response = processor.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])[0]

    if "unsafe" in response:
        # Guard output lists violated categories as S<number> codes.
        match = re.search(r"S(\d+)", response)
        if match:
            category = labels_dict.get(f"S{match.group(1)}", "Unknown Category")
            response = f"This content is unsafe under category {category}"
    elif "safe<|eot|>" in response:
        response = "This conversation is safe."
    return messages, response
# Gradio UI: collects an optional image, required text, an optional
# hypothetical model reply, and categories to exclude, then shows the
# moderated conversation and the safety verdict from ``infer``.
with gr.Blocks() as demo:
    gr.Markdown("## Llama Guard 4 for Multimodal Safety")
    gr.Markdown(
        "Llama Guard 4 is a safety moderation model for both large language and vision language models. "
        "It can detect unsafe images and text inputs. To use it, simply input text or images along with text. "
        # NOTE: trailing space above fixes the original "text.You" run-on in
        # the rendered description.
        "You can also provide hypothetical LLM outputs separately, as this model can take in entire conversation."
    )
    with gr.Column():
        image = gr.Image(label="Image Input (Optional)", type="filepath")
        text_input = gr.Textbox(label="Text Input")
        model_output = gr.Textbox(label="Model Output")
        with gr.Accordion("Exclude Safety Categories", open=False):
            exclude_categories = gr.CheckboxGroup(
                choices=[k for k in labels_dict.keys()],
                label="Select categories to EXCLUDE from moderation",
                info="Selected categories will NOT be flagged."
            )
        btn = gr.Button("Submit")
        # Two output components — ``infer`` must always return a 2-tuple.
        complete_conversation = gr.Textbox(label="Complete Conversation Provided by User")
        outputs = [complete_conversation, gr.Textbox(label="Safety Category")]
        btn.click(
            fn=infer,
            inputs=[image, text_input, model_output, exclude_categories],
            outputs=outputs
        )
        gr.Examples(
            examples=[
                [None, "How to make a bomb?", "", []],
                ["./fruit_knife.png", "How to use this?", "", []],
                [None, "How to make a bomb?", "Sorry I can't respond to this.", ["S1"]],
            ],
            inputs=[image, text_input, model_output, exclude_categories],
            outputs=outputs
        )

demo.launch()