from typing import Optional
from pydantic import BaseModel, Field
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from agent.state import AgentState
from rag.retriever import retrieve_documents
from tools.lead_capture import mock_lead_capture
from langchain_huggingface import HuggingFacePipeline
from transformers import pipeline
import os
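
# Cache for the local fallback model so the HuggingFace pipeline is only built once per process.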
_local_llm = None
def get_llm():
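    """Return ChatOpenAI when an OPENAI_API_KEY is set; otherwise lazily build and
    cache a small local HuggingFace pipeline (Qwen2.5-0.5B-Instruct on CPU)."""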
    global _local_llm
    if not os.environ.get("OPENAI_API_KEY"):
        if _local_llm is None:
            pipe = pipeline(
                "text-generation",
                model="Qwen/Qwen2.5-0.5B-Instruct",
                max_new_tokens=512,
                device="cpu",
                trust_remote_code=True,
                return_full_text=False
            )
            _local_llm = HuggingFacePipeline(pipeline=pipe)
        return _local_llm
    return ChatOpenAI(model="gpt-4o-mini", temperature=0)

class IntentResponse(BaseModel):
    intent: str = Field(description="The intent of the user. Must be one of: GREETING, PRODUCT_QUERY, PRICING_QUERY, HIGH_INTENT_LEAD, UNKNOWN")
    confidence: float = Field(description="Confidence score between 0 and 1")

class LeadExtractionResponse(BaseModel):
    user_name: Optional[str] = Field(default=None, description="The name of the user if provided")
    user_email: Optional[str] = Field(default=None, description="The email address of the user if provided")
    creator_platform: Optional[str] = Field(default=None, description="The creator platform (e.g., YouTube, Instagram) if provided")

def detect_intent(state: AgentState) -> AgentState:
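    """Classify the current message (plus recent history) into one of the intent
    categories and return it as 'detected_intent'."""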
    llm = get_llm()
    system_prompt = "You are an intent classification assistant for AutoStream. Analyze the user's message and determine the intent. Categories: GREETING, PRODUCT_QUERY, PRICING_QUERY, HIGH_INTENT_LEAD, UNKNOWN. A 'HIGH_INTENT_LEAD' is when a user explicitly expresses interest in signing up, buying, or trying out a plan."
    # ChatOpenAI supports structured output natively; the local HuggingFace pipeline
    # does not, so fall back to a PydanticOutputParser for it.
    use_structured = isinstance(llm, ChatOpenAI)
    if use_structured:
        prompt = ChatPromptTemplate.from_messages([
            ("system", system_prompt),
            ("user", "{message}")
        ])
        chain = prompt | llm.with_structured_output(IntentResponse)
    else:
        from langchain_core.output_parsers import PydanticOutputParser
        parser = PydanticOutputParser(pydantic_object=IntentResponse)
        prompt = ChatPromptTemplate.from_messages([
            ("system", system_prompt + "\n\n{format_instructions}"),
            ("user", "{message}")
        ])
        chain = prompt | llm | parser
    history_str = "\n".join([f"{msg['role']}: {msg['content']}" for msg in state.get("conversation_history", [])[-3:]])
    context_message = f"Recent history:\n{history_str}\n\nCurrent message:\n{state['current_message']}"
    if use_structured:
        response = chain.invoke({"message": context_message})
    else:
        response = chain.invoke({"message": context_message, "format_instructions": parser.get_format_instructions()})
    return {"detected_intent": response.intent}

def handle_greeting(state: AgentState) -> AgentState:
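    """Return a canned greeting response."""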
    return {"response": "Hello! I'm the AutoStream assistant. I can answer questions about our features and pricing. How can I help you today?"}

def handle_unknown(state: AgentState) -> AgentState:
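    """Return a fallback response asking the user to clarify their question."""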
    return {"response": "I'm not quite sure how to help with that. Could you clarify your question about AutoStream?"}

def retrieve_knowledge(state: AgentState) -> AgentState:
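    """Retrieve knowledge-base documents relevant to the current message."""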
    docs = retrieve_documents(state["current_message"])
    return {"retrieved_documents": docs}

def generate_rag_response(state: AgentState) -> AgentState:
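    """Answer the user's question grounded strictly in the retrieved documents."""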
    llm = get_llm()
    prompt = ChatPromptTemplate.from_messages([
        ("system", "You are a helpful sales assistant for AutoStream. Answer the user's question based strictly on the following retrieved knowledge:\n\n{context}\n\nIf the answer is not in the context, say you don't know."),
        ("user", "{message}")
    ])
    context = "\n\n".join(state.get("retrieved_documents", []))
    chain = prompt | llm
    response = chain.invoke({
        "context": context,
        "message": state["current_message"]
    })
    content = response.content if hasattr(response, "content") else str(response)
    return {"response": content}

def process_lead(state: AgentState) -> AgentState:
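    """Extract lead details (name, email, creator platform) from the conversation,
    ask for whichever field is still missing, and mark the lead ready once all
    three fields are present."""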
    llm = get_llm()
    extract_system = "Extract the user's name, email, and creator platform (e.g. YouTube, TikTok, Instagram) from the message if present. Return null for fields not found."
    # ChatOpenAI supports structured output natively; the local HuggingFace pipeline
    # does not, so fall back to a PydanticOutputParser for it.
    use_structured = isinstance(llm, ChatOpenAI)
    if use_structured:
        extract_prompt = ChatPromptTemplate.from_messages([
            ("system", extract_system),
            ("user", "{message}")
        ])
        extract_chain = extract_prompt | llm.with_structured_output(LeadExtractionResponse)
    else:
        from langchain_core.output_parsers import PydanticOutputParser
        parser = PydanticOutputParser(pydantic_object=LeadExtractionResponse)
        extract_prompt = ChatPromptTemplate.from_messages([
            ("system", extract_system + "\n\n{format_instructions}"),
            ("user", "{message}")
        ])
        extract_chain = extract_prompt | llm | parser
    history_str = "\n".join([f"{msg['role']}: {msg['content']}" for msg in state.get("conversation_history", [])[-3:]])
    context_message = f"Recent history:\n{history_str}\n\nCurrent message:\n{state['current_message']}"
    if use_structured:
        extracted = extract_chain.invoke({"message": context_message})
    else:
        extracted = extract_chain.invoke({"message": context_message, "format_instructions": parser.get_format_instructions()})
    # Only fill in fields the user newly provided; never overwrite existing state.
    updates = {}
    if extracted.user_name and not state.get("user_name"):
        updates["user_name"] = extracted.user_name
    if extracted.user_email and not state.get("user_email"):
        updates["user_email"] = extracted.user_email
    if extracted.creator_platform and not state.get("creator_platform"):
        updates["creator_platform"] = extracted.creator_platform
    current_name = updates.get("user_name", state.get("user_name"))
    current_email = updates.get("user_email", state.get("user_email"))
    current_platform = updates.get("creator_platform", state.get("creator_platform"))
    # Ask for whichever required field is still missing, one at a time.
    if not current_name:
        updates["response"] = "Great! I can help with that. Could I have your name?"
    elif not current_email:
        updates["response"] = f"Thanks {current_name}! What is your email address?"
    elif not current_platform:
        updates["response"] = "Got it. And what creator platform do you primarily use (e.g., YouTube, TikTok)?"
    else:
        updates["lead_ready"] = True
    return updates

def execute_tool(state: AgentState) -> AgentState:
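    """Call the mock lead-capture tool once all required lead fields have been collected."""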
    if state.get("lead_ready") and state.get("user_name") and state.get("user_email") and state.get("creator_platform"):
        mock_lead_capture(
            state["user_name"],
            state["user_email"],
            state["creator_platform"]
        )
        return {"response": f"Thanks {state['user_name']}! I've successfully collected your information for your {state['creator_platform']} channel. Our team will reach out to {state['user_email']} shortly."}
    else:
        return {"response": "Error: Tried to execute lead capture tool without all required fields."}