File size: 2,666 Bytes
849c690
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
"""
Test script for the LangGraph agent pipeline.
Runs several queries with a short wait between them to verify the full flow.
Requires gemini_api in environment for real LLM calls; otherwise only tests prepare_generation (no API).
"""

import os
import time
from rag_engine import RAGEngine
from agent import build_agent_graph, run_stream


def _preview(text, limit):
    """Return *text* cut to *limit* characters, appending "..." only if it was cut."""
    return (text[:limit] + "...") if len(text) > limit else text


def _run_prepare_generation_tests(engine):
    """Exercise ``engine.prepare_generation`` for a few queries (no LLM/API calls).

    Exceptions raised by ``prepare_generation`` are not caught, so they abort
    the run (and give a non-zero exit status) just as before.
    """
    test_queries = [
        "Tell me about the Audi RS3",
        "Compare Audi RS3 vs Hyundai Elantra N",
        # Hebrew: "What do you think about the BMW X5?" -- should trigger refusal.
        "מה דעתך על BMW X5?",
    ]
    for i, query in enumerate(test_queries, 1):
        print(f"--- Test {i}: prepare_generation ---")
        print(f"Query: {query!r}")
        refusal, sys_p, user_p, steps = engine.prepare_generation(query)
        if refusal:
            print(f"Refusal (expected for unsupported car): {_preview(refusal, 150)}")
        else:
            print(
                f"Steps: {len(steps)}; system_prompt length: {len(sys_p or '')}; "
                f"user_prompt length: {len(user_p or '')}"
            )
        print()
    print("Done (prepare_generation only). Set gemini_api to run full agent.")


def _run_agent_tests(engine, graph, api_key, *, wait_seconds=8):
    """Stream each test query through the full agent graph.

    Each query is isolated: an exception is reported and the run moves on to
    the next query. Between queries the function sleeps ``wait_seconds``.

    Returns:
        The number of queries that raised or yielded no (truthy) final output.
    """
    test_queries = [
        "Tell me about the Audi RS3",
        "Compare Audi RS3 vs Hyundai Elantra N",
        # Hebrew: "What are the advantages of the Kia EV9?"
        "מה היתרונות של קיה EV9?",
        # Hebrew: "What do you think about the BMW X5?" -- should trigger refusal
        # (unsupported model).
        "מה דעתך על BMW X5?",
    ]
    failures = 0

    for i, query in enumerate(test_queries, 1):
        print(f"--- Test {i}/{len(test_queries)} ---")
        print(f"Query: {query!r}")
        last_output = None
        step_count = 0
        try:
            # Only the last value yielded by the stream is treated as the final answer.
            for out in run_stream(engine, graph, query, api_key):
                last_output = out
                step_count += 1
            if last_output:
                print(f"Steps yielded: {step_count}; final length: {len(last_output)}")
                print(f"Final preview:\n{_preview(last_output, 400)}\n")
            else:
                print("No output received.\n")
                failures += 1
        except Exception as e:  # test-harness boundary: record, report, continue
            print(f"Error: {type(e).__name__}: {e}\n")
            failures += 1
        if i < len(test_queries):
            # Pause between queries (presumably to stay under API rate limits).
            print(f"Waiting {wait_seconds}s before next query...")
            time.sleep(wait_seconds)

    return failures


def main():
    """Build the RAG engine and agent graph, then run the smoke tests.

    If the ``gemini_api`` environment variable is unset or empty, only
    ``engine.prepare_generation`` is exercised (no LLM calls). Otherwise every
    test query is streamed through the full agent via ``run_stream``.

    Returns:
        Process exit status: 0 on success, 1 if any agent query raised or
        produced no output. Returning an int is backward compatible for
        callers that ignore the result.
    """
    print("Loading RAG Engine and building agent graph...")
    engine = RAGEngine()
    graph = build_agent_graph(engine)
    print("OK.\n")

    api_key = os.environ.get("gemini_api")
    if not api_key:
        print("⚠️  gemini_api not set. Testing only prepare_generation (no LLM calls).\n")
        _run_prepare_generation_tests(engine)
        return 0

    failures = _run_agent_tests(engine, graph, api_key)
    print("All tests finished.")
    if failures:
        print(f"{failures} query(ies) failed.")
    return 1 if failures else 0


if __name__ == "__main__":
    # Propagate main()'s status so CI can detect failed queries.
    raise SystemExit(main())