File size: 2,048 Bytes
18b52eb
 
 
 
 
 
2cb39a9
18b52eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2cb39a9
18b52eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59

import json
import logging
import os

from agents import Agent, OpenAIChatCompletionsModel, Runner, GuardrailFunctionOutput
from openai import AsyncOpenAI
from pydantic import BaseModel

from core.model import get_model_client

# Structured verdict produced by the input-validation guardrail agent; used as
# that agent's `output_type` and as the guardrail's `output_info`.
# NOTE(review): documented with comments rather than a docstring on purpose —
# pydantic may surface a class docstring as the JSON-schema description sent
# with `output_type`, which would alter the model-facing schema. Confirm before
# adding one.
class ValidatedOutput(BaseModel):
    is_valid: bool  # False means the guardrail tripwire is triggered for this input
    reasoning: str  # the model's explanation for the verdict

# Guardrail agent that judges whether a user input contains unparliamentary
# language and returns its verdict as a structured `ValidatedOutput`.
input_validation_agent = Agent(
    name="Guardrail Input Validation Agent",
    # The structured `output_type` below is what enforces the response shape;
    # the "Output Format" section restates it so the model knows what to emit.
    # The Rules section refers to unparliamentary language explicitly: a
    # previous version said "any of these" without any list to refer to.
    instructions="""
        You are a highly efficient and specialized **Agent** 🌐. Your sole function is to validate the user inputs.

        ## Core Directives & Priorities
        1. Flag the input ONLY if it contains unparliamentary language; do not flag it for any other reason.
        2. You MUST give reasoning for your decision.

        ## Rules
        - If the input contains unparliamentary language, mark `"is_valid": false` and explain **why** in `"reasoning"`.
        - Otherwise, mark `"is_valid": true` with reasoning like "The input follows respectful communication guidelines."


        ## Output Format (MANDATORY)
        * Return a JSON object with the following structure:
        {
            "is_valid": <boolean>,
            "reasoning": <string>
        }
    """,
    model=get_model_client(),
    output_type=ValidatedOutput,
)
# Human-readable summary attached after construction.
# NOTE(review): `description` is an ad-hoc attribute here (presumably read by
# project code elsewhere); the Agents SDK's own field for this purpose is
# `handoff_description` — confirm which one consumers rely on.
input_validation_agent.description = "A guardrail agent that validates user input for unparliamentary language."

logger = logging.getLogger(__name__)


def _coerce_validated_output(raw_output):
    """Convert the validation agent's final output into a ``ValidatedOutput``.

    Accepts an already-parsed ``ValidatedOutput``, a ``dict`` with the same
    fields, or a JSON string/bytes encoding that object. Any other type, or a
    value that fails schema validation, yields an invalid verdict so the
    guardrail fails closed rather than letting unchecked input through.

    Args:
        raw_output: The ``final_output`` of the nested validation run.

    Returns:
        A ``ValidatedOutput`` instance.
    """
    if isinstance(raw_output, ValidatedOutput):
        return raw_output
    try:
        if isinstance(raw_output, dict):
            return ValidatedOutput.model_validate(raw_output)
        if isinstance(raw_output, (str, bytes, bytearray)):
            return ValidatedOutput.model_validate_json(raw_output)
    except ValueError as err:
        # pydantic's ValidationError (including invalid-JSON errors) subclasses
        # ValueError; treat a malformed verdict as a failed validation.
        return ValidatedOutput(
            is_valid=False,
            reasoning=f"Malformed guardrail output: {err}",
        )
    return ValidatedOutput(
        is_valid=False,
        reasoning=f"Unexpected output type: {type(raw_output)}",
    )


async def input_validation_guardrail(ctx, agent, input_data):
    """Run the input-validation agent on ``input_data`` and report its verdict.

    Args:
        ctx: Run context wrapper; its ``.context`` is forwarded to the nested
            validation run.
        agent: The agent being guarded. Not used here — presumably required by
            the guardrail callback signature (confirm against the caller).
        input_data: The user input to validate.

    Returns:
        ``GuardrailFunctionOutput`` whose ``output_info`` is the
        ``ValidatedOutput`` verdict and whose tripwire is triggered when the
        input is judged invalid (including when the verdict is unparseable).
    """
    result = await Runner.run(input_validation_agent, input_data, context=ctx.context)
    final_output = _coerce_validated_output(result.final_output)
    # Debug-level logging instead of print(): library code should not write to
    # stdout on every request.
    logger.debug("Input validation result: %s", final_output)

    return GuardrailFunctionOutput(
        output_info=final_output,
        tripwire_triggered=not final_output.is_valid,
    )

# Names exported by `from <module> import *`: the guardrail agent, its
# callback, and the structured verdict model.
__all__ = [
    "input_validation_agent",
    "input_validation_guardrail",
    "ValidatedOutput",
]