File size: 7,788 Bytes
bf6dbfa
 
 
 
 
 
 
9d2e886
 
 
 
 
bf6dbfa
 
9d2e886
 
 
 
 
 
 
 
 
 
 
 
 
 
bf6dbfa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9d2e886
 
 
 
 
 
 
 
 
 
 
 
bf6dbfa
 
 
 
9d2e886
 
 
 
bf6dbfa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9d2e886
 
 
bf6dbfa
 
 
 
 
 
 
 
9d2e886
 
 
 
 
 
 
 
 
 
 
bf6dbfa
 
 
 
9d2e886
 
 
 
bf6dbfa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
from typing import Optional
from pydantic import BaseModel, Field
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from agent.state import AgentState
from rag.retriever import retrieve_documents
from tools.lead_capture import mock_lead_capture
from langchain_huggingface import HuggingFacePipeline
from transformers import pipeline
import os

# Cached local fallback model; built lazily, at most once per process.
_local_llm = None

def get_llm():
    """Return the chat model to use for this process.

    Prefers OpenAI (gpt-4o-mini) when OPENAI_API_KEY is set; otherwise falls
    back to a small CPU-friendly Hugging Face pipeline, cached in the module
    global so the weights are only loaded once.
    """
    global _local_llm
    if os.environ.get("OPENAI_API_KEY"):
        return ChatOpenAI(model="gpt-4o-mini", temperature=0)
    if _local_llm is None:
        # Small instruct model that runs acceptably on CPU. return_full_text=False
        # keeps the prompt out of the generated completion.
        text_gen = pipeline(
            "text-generation",
            model="Qwen/Qwen2.5-0.5B-Instruct",
            max_new_tokens=512,
            device="cpu",
            trust_remote_code=True,
            return_full_text=False
        )
        _local_llm = HuggingFacePipeline(pipeline=text_gen)
    return _local_llm

class IntentResponse(BaseModel):
    """Structured-output schema for the intent-classification call in detect_intent."""
    # Free-form str rather than an Enum; validity relies on the model honoring
    # the category list in the field description.
    intent: str = Field(description="The intent of the user. Must be one of: GREETING, PRODUCT_QUERY, PRICING_QUERY, HIGH_INTENT_LEAD, UNKNOWN")
    # NOTE(review): confidence is requested from the model but detect_intent only
    # returns the intent — confirm this field is actually needed downstream.
    confidence: float = Field(description="Confidence score between 0 and 1")

class LeadExtractionResponse(BaseModel):
    """Structured-output schema for lead-detail extraction in process_lead.

    All fields are optional: the model returns None for anything the user has
    not provided yet, and process_lead prompts for the first missing field.
    """
    user_name: Optional[str] = Field(default=None, description="The name of the user if provided")
    user_email: Optional[str] = Field(default=None, description="The email address of the user if provided")
    creator_platform: Optional[str] = Field(default=None, description="The creator platform (e.g., YouTube, Instagram) if provided")

def detect_intent(state: AgentState) -> AgentState:
    """Classify the user's current message into one of the AutoStream intents.

    Uses the model's native structured output when available; otherwise falls
    back to a PydanticOutputParser with explicit format instructions appended
    to the system prompt (needed for the local pipeline model).

    Returns a partial state update: {"detected_intent": <intent string>}.
    """
    llm = get_llm()

    # Single source of truth for the classifier instructions (previously this
    # literal was duplicated across the two chain-construction branches, and
    # the structured-output capability was probed twice).
    system_instructions = "You are an intent classification assistant for AutoStream. Analyze the user's message and determine the intent. Categories: GREETING, PRODUCT_QUERY, PRICING_QUERY, HIGH_INTENT_LEAD, UNKNOWN. A 'HIGH_INTENT_LEAD' is when a user explicitly expresses interest in signing up, buying, or trying out a plan."

    if hasattr(llm, "with_structured_output"):
        prompt = ChatPromptTemplate.from_messages([
            ("system", system_instructions),
            ("user", "{message}")
        ])
        chain = prompt | llm.with_structured_output(IntentResponse)
        extra_inputs = {}
    else:
        # Fallback for models without structured output: ask for parseable JSON.
        from langchain.output_parsers import PydanticOutputParser
        parser = PydanticOutputParser(pydantic_object=IntentResponse)
        prompt = ChatPromptTemplate.from_messages([
            ("system", system_instructions + "\n\n{format_instructions}"),
            ("user", "{message}")
        ])
        chain = prompt | llm | parser
        extra_inputs = {"format_instructions": parser.get_format_instructions()}

    # Include a short window of recent turns so follow-up messages classify
    # correctly (e.g. "yes, sign me up" after a pricing answer).
    history_str = "\n".join(
        f"{msg['role']}: {msg['content']}"
        for msg in state.get("conversation_history", [])[-3:]
    )
    context_message = f"Recent history:\n{history_str}\n\nCurrent message:\n{state['current_message']}"

    response = chain.invoke({"message": context_message, **extra_inputs})
    return {"detected_intent": response.intent}

def handle_greeting(state: AgentState) -> AgentState:
    """Return the canned welcome message for a GREETING intent."""
    greeting = (
        "Hello! I'm the AutoStream assistant. I can answer questions about "
        "our features and pricing. How can I help you today?"
    )
    return {"response": greeting}

def handle_unknown(state: AgentState) -> AgentState:
    """Return the fallback reply for messages whose intent could not be determined."""
    fallback = (
        "I'm not quite sure how to help with that. Could you clarify your "
        "question about AutoStream?"
    )
    return {"response": fallback}

def retrieve_knowledge(state: AgentState) -> AgentState:
    """Fetch knowledge-base documents relevant to the current user message."""
    question = state["current_message"]
    return {"retrieved_documents": retrieve_documents(question)}

def generate_rag_response(state: AgentState) -> AgentState:
    """Answer the user's question grounded strictly in the retrieved documents."""
    llm = get_llm()
    rag_prompt = ChatPromptTemplate.from_messages([
        ("system", "You are a helpful sales assistant for AutoStream. Answer the user's question based strictly on the following retrieved knowledge:\n\n{context}\n\nIf the answer is not in the context, say you don't know."),
        ("user", "{message}")
    ])

    # Concatenate the retrieved snippets into one context blob for the prompt.
    knowledge = "\n\n".join(state.get("retrieved_documents", []))

    result = (rag_prompt | llm).invoke({
        "context": knowledge,
        "message": state["current_message"]
    })

    # Chat models return a message object with .content; the local pipeline
    # fallback returns a plain string.
    answer = getattr(result, "content", str(result)) if hasattr(result, "content") else str(result)

    return {"response": answer}

def process_lead(state: AgentState) -> AgentState:
    """Drive lead-capture slot filling for a HIGH_INTENT_LEAD conversation.

    Extracts name / email / creator platform from the recent conversation,
    fills only the slots that are still empty, then either asks for the first
    missing field or marks the lead ready for execute_tool.

    Returns a partial state update (new slot values, and either "response" or
    "lead_ready").
    """
    llm = get_llm()

    # Single source of truth for the extraction instructions (previously this
    # literal was duplicated across the two branches, and the structured-output
    # capability was probed twice — once at construction, once at invocation).
    system_instructions = "Extract the user's name, email, and creator platform (e.g. YouTube, TikTok, Instagram) from the message if present. Return null for fields not found."

    if hasattr(llm, "with_structured_output"):
        extract_prompt = ChatPromptTemplate.from_messages([
            ("system", system_instructions),
            ("user", "{message}")
        ])
        extract_chain = extract_prompt | llm.with_structured_output(LeadExtractionResponse)
        extra_inputs = {}
    else:
        # Fallback for models without structured output: ask for parseable JSON.
        from langchain.output_parsers import PydanticOutputParser
        parser = PydanticOutputParser(pydantic_object=LeadExtractionResponse)
        extract_prompt = ChatPromptTemplate.from_messages([
            ("system", system_instructions + "\n\n{format_instructions}"),
            ("user", "{message}")
        ])
        extract_chain = extract_prompt | llm | parser
        extra_inputs = {"format_instructions": parser.get_format_instructions()}

    # A short history window lets the model pick up details given in earlier
    # turns (e.g. the name from the previous message).
    history_str = "\n".join(
        f"{msg['role']}: {msg['content']}"
        for msg in state.get("conversation_history", [])[-3:]
    )
    context_message = f"Recent history:\n{history_str}\n\nCurrent message:\n{state['current_message']}"

    extracted = extract_chain.invoke({"message": context_message, **extra_inputs})

    # Only fill slots that are still empty so earlier answers are never
    # overwritten by a later (possibly hallucinated) extraction.
    updates = {}
    if extracted.user_name and not state.get("user_name"):
        updates["user_name"] = extracted.user_name
    if extracted.user_email and not state.get("user_email"):
        updates["user_email"] = extracted.user_email
    if extracted.creator_platform and not state.get("creator_platform"):
        updates["creator_platform"] = extracted.creator_platform

    # Effective slot values after this turn's extraction.
    current_name = updates.get("user_name", state.get("user_name"))
    current_email = updates.get("user_email", state.get("user_email"))
    current_platform = updates.get("creator_platform", state.get("creator_platform"))

    # Ask for the first missing field; once all three are present, flag the
    # lead ready so the graph routes to execute_tool.
    if not current_name:
        updates["response"] = "Great! I can help with that. Could I have your name?"
    elif not current_email:
        updates["response"] = f"Thanks {current_name}! What is your email address?"
    elif not current_platform:
        updates["response"] = "Got it. And what creator platform do you primarily use (e.g., YouTube, TikTok)?"
    else:
        updates["lead_ready"] = True
    return updates

def execute_tool(state: AgentState) -> AgentState:
    """Fire the lead-capture tool once every required contact field is collected."""
    required = ("user_name", "user_email", "creator_platform")

    # Guard clause: refuse to run the tool unless the lead was flagged ready
    # and all three contact fields are populated.
    if not state.get("lead_ready") or not all(state.get(key) for key in required):
        return {"response": "Error: Tried to execute lead capture tool without all required fields."}

    name, email, platform = (state[key] for key in required)
    mock_lead_capture(name, email, platform)
    confirmation = (
        f"Thanks {name}! I've successfully collected your information for your "
        f"{platform} channel. Our team will reach out to {email} shortly."
    )
    return {"response": confirmation}