File size: 2,783 Bytes
515f392
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
"""
LangSmith ์ถ”์ (tracing) ์œ ํ‹ธ๋ฆฌํ‹ฐ ๋ชจ๋“ˆ.

LangGraph ๋…ธ๋“œ ์‹คํ–‰์„ LangSmith์—์„œ ์ถ”์ ํ•˜๊ณ  ๋ชจ๋‹ˆํ„ฐ๋งํ•˜๊ธฐ ์œ„ํ•œ ๋„๊ตฌ๋ฅผ ์ œ๊ณตํ•ฉ๋‹ˆ๋‹ค.
๊ณต์‹ ๋ฌธ์„œ: https://docs.langchain.com/langsmith/trace-with-langgraph
"""

import os
import logging
import asyncio
from functools import wraps
from typing import Any, Callable
from inspect import iscoroutinefunction

from langsmith import traceable

logger = logging.getLogger(__name__)


def ensure_tracing_enabled() -> bool:
    """
    LangSmith ์ถ”์ ์ด ์˜ฌ๋ฐ”๋ฅด๊ฒŒ ์„ค์ •๋˜์—ˆ๋Š”์ง€ ํ™•์ธํ•ฉ๋‹ˆ๋‹ค.
    
    Returns:
        bool: ์ถ”์ ์ด ํ™œ์„ฑํ™”๋˜์–ด ์žˆ์œผ๋ฉด True, ๊ทธ๋ ‡์ง€ ์•Š์œผ๋ฉด False
    """
    required_vars = ["LANGCHAIN_TRACING_V2", "LANGCHAIN_API_KEY"]
    
    missing_vars = [var for var in required_vars if not os.getenv(var)]
    
    if missing_vars:
        logger.warning(
            "LangSmith ์ถ”์ ์ด ๋น„ํ™œ์„ฑํ™”๋˜์—ˆ์Šต๋‹ˆ๋‹ค. ๋ˆ„๋ฝ๋œ ํ™˜๊ฒฝ๋ณ€์ˆ˜: %s",
            ", ".join(missing_vars)
        )
        return False
    
    return True


def trace_node(node_name: str) -> Callable:
    """
    LangGraph ๋…ธ๋“œ ์‹คํ–‰์„ ์ถ”์ ํ•˜๋Š” ๋ฐ์ฝ”๋ ˆ์ดํ„ฐ.
    
    ์ด ๋ฐ์ฝ”๋ ˆ์ดํ„ฐ๋Š” ๊ฐ ๋…ธ๋“œ์˜ ์ž…๋ ฅ/์ถœ๋ ฅ, ์‹คํ–‰ ์‹œ๊ฐ„, ์—๋Ÿฌ๋ฅผ 
    LangSmith ๋Œ€์‹œ๋ณด๋“œ์— ์ž๋™์œผ๋กœ ๊ธฐ๋กํ•ฉ๋‹ˆ๋‹ค.
    ๋™๊ธฐ ๋ฐ ๋น„๋™๊ธฐ ํ•จ์ˆ˜ ๋ชจ๋‘ ์ง€์›ํ•ฉ๋‹ˆ๋‹ค.
    
    Args:
        node_name: LangSmith์— ํ‘œ์‹œ๋  ๋…ธ๋“œ ์ด๋ฆ„
    
    Returns:
        Callable: ๋ฐ์ฝ”๋ ˆ์ดํŠธ๋œ ํ•จ์ˆ˜
    
    Example:
        @trace_node("check_cache")
        async def check_cache_node(state: AgentState) -> AgentState:
            # ๋…ธ๋“œ ๋กœ์ง
            return state
    """
    def decorator(func: Callable) -> Callable:
        # async ํ•จ์ˆ˜์ธ์ง€ ํ™•์ธ
        if iscoroutinefunction(func):
            @wraps(func)
            @traceable(name=node_name, run_type="chain")
            async def async_wrapper(*args, **kwargs) -> Any:
                try:
                    result = await func(*args, **kwargs)
                    return result
                except Exception as e:
                    logger.error("๐Ÿ”ด ๋…ธ๋“œ ์‹คํŒจ: %s - %s", node_name, str(e))
                    raise
            return async_wrapper
        else:
            @wraps(func)
            @traceable(name=node_name, run_type="chain")
            def sync_wrapper(*args, **kwargs) -> Any:
                try:
                    result = func(*args, **kwargs)
                    return result
                except Exception as e:
                    logger.error("๐Ÿ”ด ๋…ธ๋“œ ์‹คํŒจ: %s - %s", node_name, str(e))
                    raise
            return sync_wrapper
    return decorator


# ๋ชจ๋“ˆ import ์‹œ ์ž๋™์œผ๋กœ ์ถ”์  ์„ค์ • ํ™•์ธ
ensure_tracing_enabled()