File size: 2,041 Bytes
356e05f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import traceback
from agents import RunContextWrapper, Runner, GuardrailFunctionOutput, input_guardrail
from guardrails.input_guardrails import guardrail_agent

@input_guardrail
async def guardrail_input_function(ctx: RunContextWrapper, agent, user_input: str):
    """Screen ``user_input`` with ``guardrail_agent`` and report a tripwire decision.

    The guardrail agent is run on the raw user input with the caller's
    context. The tripwire fires only when the agent's final output exposes
    an ``is_query_about_jobobike`` attribute and that attribute is falsy.

    This guardrail deliberately fails open: a malformed run result, a final
    output without the expected attribute, or any exception raised while
    running the agent all result in the query being allowed through, with
    a diagnostic printed to stdout.

    Args:
        ctx: Run context wrapper; only ``ctx.context`` is forwarded.
        agent: The agent being guarded. Not used by this function.
        user_input: The user's message to classify.

    Returns:
        A ``GuardrailFunctionOutput`` whose ``output_info`` is the guardrail
        agent's final output (or ``None`` when unavailable).
    """

    def _pass_through(info=None):
        # Single place that builds the "let the query proceed" answer.
        return GuardrailFunctionOutput(
            output_info=info,
            tripwire_triggered=False,
        )

    try:
        run_result = await Runner.run(
            guardrail_agent,
            input=user_input,
            context=ctx.context,
        )

        # A falsy result or one lacking ``final_output`` cannot be judged.
        if not (run_result and hasattr(run_result, 'final_output')):
            print(f"Warning: Guardrail agent returned unexpected result: {run_result}")
            return _pass_through()

        verdict = run_result.final_output

        # Without the flag there is nothing to base a block on; keep the
        # raw output around for whoever inspects the guardrail result.
        if not hasattr(verdict, 'is_query_about_jobobike'):
            print(f"Warning: Guardrail output missing is_query_about_jobobike attribute: {verdict}")
            return _pass_through(verdict)

        # Off-topic queries (flag falsy) trip the wire.
        return GuardrailFunctionOutput(
            output_info=verdict,
            tripwire_triggered=not verdict.is_query_about_jobobike,
        )
    except Exception as exc:
        message = str(exc)
        # Heuristic match on the error text for credential-related failures;
        # these get a short one-line notice instead of a full traceback.
        looks_like_key_problem = (
            "API key" in message
            or "expired" in message.lower()
            or "INVALID_ARGUMENT" in message
        )
        if looks_like_key_problem:
            print(f"API key error in guardrail - allowing query through: {message[:100]}")
        else:
            print(f"Error in guardrail_input_function: {message[:200]}")
            print(traceback.format_exc())
        # A broken guardrail must never block the user's query.
        return _pass_through()