# social-agent / agent / nodes.py
# Priyansh Saxena
# feat: add local llm
# 4c68cfa
from typing import Optional
from pydantic import BaseModel, Field
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from agent.state import AgentState
from rag.retriever import retrieve_documents
from tools.lead_capture import mock_lead_capture
from langchain_huggingface import HuggingFacePipeline
from transformers import pipeline
import os
# Module-level cache for the local Hugging Face pipeline wrapper. It is filled
# in by get_llm() the first time the local model is needed.
_local_llm = None


def get_llm():
    """Return the language model the agent nodes should talk to.

    If the ``OPENAI_API_KEY`` environment variable holds a non-empty value, a
    fresh ``ChatOpenAI`` client (``gpt-4o-mini``, temperature 0) is returned
    on every call. Otherwise a ``Qwen/Qwen2.5-0.5B-Instruct`` text-generation
    pipeline on CPU is built on first use, wrapped in ``HuggingFacePipeline``,
    stored in the module-level ``_local_llm`` and reused by later calls.
    """
    global _local_llm

    # A configured API key always selects the hosted model.
    if os.environ.get("OPENAI_API_KEY"):
        return ChatOpenAI(model="gpt-4o-mini", temperature=0)

    # Local model already built by an earlier call: reuse it.
    if _local_llm is not None:
        return _local_llm

    generation_settings = {
        "model": "Qwen/Qwen2.5-0.5B-Instruct",
        "max_new_tokens": 512,
        "device": "cpu",
        # NOTE(review): this lets code shipped with the model repo run
        # locally -- confirm it is actually required for this checkpoint.
        "trust_remote_code": True,
        "return_full_text": False,
    }
    local_pipeline = pipeline("text-generation", **generation_settings)
    _local_llm = HuggingFacePipeline(pipeline=local_pipeline)
    return _local_llm
class IntentResponse(BaseModel):
    # Schema for the intent classifier's structured answer (see detect_intent).
    # Documented with comments rather than a class docstring on purpose:
    # pydantic copies a class docstring into the generated schema description,
    # which would change what is sent to the model.

    # Category label chosen by the model.
    intent: str = Field(
        description="The intent of the user. Must be one of: GREETING, PRODUCT_QUERY, PRICING_QUERY, HIGH_INTENT_LEAD, UNKNOWN",
    )
    # Self-reported confidence of the classification.
    confidence: float = Field(
        description="Confidence score between 0 and 1",
    )
class LeadExtractionResponse(BaseModel):
    # Schema for the lead details pulled out of the conversation (see
    # process_lead). Every field is optional and defaults to None.
    # Documented with comments rather than a class docstring on purpose:
    # pydantic copies a class docstring into the generated schema description,
    # which would change what is sent to the model.

    user_name: Optional[str] = Field(
        default=None,
        description="The name of the user if provided",
    )
    user_email: Optional[str] = Field(
        default=None,
        description="The email address of the user if provided",
    )
    creator_platform: Optional[str] = Field(
        default=None,
        description="The creator platform (e.g., YouTube, Instagram) if provided",
    )
def detect_intent(state: AgentState) -> AgentState:
    """Classify the current user message into one of the AutoStream intents.

    The model is shown the last three ``conversation_history`` entries
    followed by ``current_message``. Structured output is requested through
    ``llm.with_structured_output`` when the model supports it; otherwise the
    system prompt is extended with format instructions and the raw completion
    is parsed with ``PydanticOutputParser``.

    Args:
        state: Agent state. Reads ``current_message`` and, when present,
            ``conversation_history`` (entries with ``role`` and ``content``).

    Returns:
        A partial state update ``{"detected_intent": <category>}`` where the
        category is one of GREETING, PRODUCT_QUERY, PRICING_QUERY,
        HIGH_INTENT_LEAD or UNKNOWN.
    """
    allowed_intents = {
        "GREETING",
        "PRODUCT_QUERY",
        "PRICING_QUERY",
        "HIGH_INTENT_LEAD",
        "UNKNOWN",
    }
    system_text = "You are an intent classification assistant for AutoStream. Analyze the user's message and determine the intent. Categories: GREETING, PRODUCT_QUERY, PRICING_QUERY, HIGH_INTENT_LEAD, UNKNOWN. A 'HIGH_INTENT_LEAD' is when a user explicitly expresses interest in signing up, buying, or trying out a plan."

    llm = get_llm()

    # hasattr() is not a usable capability probe here: LangChain's base
    # language-model class defines with_structured_output() as a stub that
    # raises NotImplementedError, so plain LLM wrappers such as
    # HuggingFacePipeline have the attribute but cannot use it. Try the call
    # and fall back to prompt-based parsing when it is unsupported.
    structured_llm = None
    bind_schema = getattr(llm, "with_structured_output", None)
    if bind_schema is not None:
        try:
            structured_llm = bind_schema(IntentResponse)
        except NotImplementedError:
            structured_llm = None

    # "or []" also covers a history key that is present but set to None.
    recent_turns = (state.get("conversation_history") or [])[-3:]
    history_str = "\n".join(
        f"{msg['role']}: {msg['content']}" for msg in recent_turns
    )
    context_message = (
        f"Recent history:\n{history_str}\n\n"
        f"Current message:\n{state['current_message']}"
    )

    if structured_llm is not None:
        prompt = ChatPromptTemplate.from_messages([
            ("system", system_text),
            ("user", "{message}"),
        ])
        response = (prompt | structured_llm).invoke(
            {"message": context_message}
        )
    else:
        # Imported from langchain_core (already a dependency of this module)
        # so the fallback does not require the separate "langchain" package.
        from langchain_core.output_parsers import PydanticOutputParser

        parser = PydanticOutputParser(pydantic_object=IntentResponse)
        prompt = ChatPromptTemplate.from_messages([
            ("system", system_text + "\n\n{format_instructions}"),
            ("user", "{message}"),
        ])
        response = (prompt | llm | parser).invoke({
            "message": context_message,
            "format_instructions": parser.get_format_instructions(),
        })

    # ``intent`` is a free-form string in the schema, so normalise case and
    # whitespace, and map anything outside the documented set to UNKNOWN.
    intent = str(response.intent).strip().upper()
    if intent not in allowed_intents:
        intent = "UNKNOWN"
    return {"detected_intent": intent}
def handle_greeting(state: AgentState) -> AgentState:
    """Return the fixed greeting reply; ``state`` is not read."""
    greeting = "Hello! I'm the AutoStream assistant. I can answer questions about our features and pricing. How can I help you today?"
    return {"response": greeting}
def handle_unknown(state: AgentState) -> AgentState:
    """Return the fixed clarification request; ``state`` is not read."""
    clarification = "I'm not quite sure how to help with that. Could you clarify your question about AutoStream?"
    return {"response": clarification}
def retrieve_knowledge(state: AgentState) -> AgentState:
    """Look up knowledge-base documents for the current message.

    Passes ``state["current_message"]`` to ``retrieve_documents`` and returns
    whatever it yields under the ``retrieved_documents`` key.
    """
    query = state["current_message"]
    return {"retrieved_documents": retrieve_documents(query)}
def generate_rag_response(state: AgentState) -> AgentState:
    """Answer the current message using only the retrieved documents.

    The entries of ``retrieved_documents`` (empty when the key is missing) are
    joined with blank lines and placed in the system prompt as context; the
    user turn is ``current_message``.

    Returns:
        ``{"response": <text>}`` -- the ``content`` attribute of the model
        result when it has one, otherwise ``str()`` of the result.
    """
    model = get_llm()
    rag_prompt = ChatPromptTemplate.from_messages([
        (
            "system",
            "You are a helpful sales assistant for AutoStream. Answer the user's question based strictly on the following retrieved knowledge:\n\n{context}\n\nIf the answer is not in the context, say you don't know.",
        ),
        ("user", "{message}"),
    ])
    knowledge = "\n\n".join(state.get("retrieved_documents", []))

    rag_chain = rag_prompt | model
    answer = rag_chain.invoke(
        {"context": knowledge, "message": state["current_message"]}
    )

    # Message-like results expose ``content``; anything else is stringified.
    if hasattr(answer, "content"):
        return {"response": answer.content}
    return {"response": str(answer)}
def process_lead(state: AgentState) -> AgentState:
    """Collect name, email and creator platform for a high-intent lead.

    The model extracts whichever of the three fields appear in the last three
    ``conversation_history`` entries plus ``current_message``. Extracted
    values only fill fields that are still empty in ``state``. The function
    then asks for the first missing field (name, then email, then platform),
    or flags the lead as ready once all three are known.

    Structured output is requested through ``llm.with_structured_output``
    when the model supports it; otherwise the system prompt is extended with
    format instructions and the raw completion is parsed with
    ``PydanticOutputParser``.

    Args:
        state: Agent state. Reads ``current_message`` and, when present,
            ``conversation_history``, ``user_name``, ``user_email`` and
            ``creator_platform``.

    Returns:
        A partial state update containing any newly extracted fields plus
        either a ``response`` asking for the next missing field or
        ``lead_ready=True`` (with no ``response``) when nothing is missing.
    """
    system_text = "Extract the user's name, email, and creator platform (e.g. YouTube, TikTok, Instagram) from the message if present. Return null for fields not found."

    llm = get_llm()

    # hasattr() is not a usable capability probe here: LangChain's base
    # language-model class defines with_structured_output() as a stub that
    # raises NotImplementedError, so plain LLM wrappers such as
    # HuggingFacePipeline have the attribute but cannot use it. Try the call
    # and fall back to prompt-based parsing when it is unsupported.
    structured_llm = None
    bind_schema = getattr(llm, "with_structured_output", None)
    if bind_schema is not None:
        try:
            structured_llm = bind_schema(LeadExtractionResponse)
        except NotImplementedError:
            structured_llm = None

    # "or []" also covers a history key that is present but set to None.
    recent_turns = (state.get("conversation_history") or [])[-3:]
    history_str = "\n".join(
        f"{msg['role']}: {msg['content']}" for msg in recent_turns
    )
    context_message = (
        f"Recent history:\n{history_str}\n\n"
        f"Current message:\n{state['current_message']}"
    )

    if structured_llm is not None:
        extract_prompt = ChatPromptTemplate.from_messages([
            ("system", system_text),
            ("user", "{message}"),
        ])
        extracted = (extract_prompt | structured_llm).invoke(
            {"message": context_message}
        )
    else:
        # Imported from langchain_core (already a dependency of this module)
        # so the fallback does not require the separate "langchain" package.
        from langchain_core.output_parsers import PydanticOutputParser

        parser = PydanticOutputParser(pydantic_object=LeadExtractionResponse)
        extract_prompt = ChatPromptTemplate.from_messages([
            ("system", system_text + "\n\n{format_instructions}"),
            ("user", "{message}"),
        ])
        extracted = (extract_prompt | llm | parser).invoke({
            "message": context_message,
            "format_instructions": parser.get_format_instructions(),
        })

    # Values already stored in the state win over fresh extractions, so an
    # earlier answer is never overwritten by a later guess.
    updates = {}
    for field_name in ("user_name", "user_email", "creator_platform"):
        value = getattr(extracted, field_name)
        if value and not state.get(field_name):
            updates[field_name] = value

    current_name = updates.get("user_name", state.get("user_name"))
    current_email = updates.get("user_email", state.get("user_email"))
    current_platform = updates.get(
        "creator_platform", state.get("creator_platform")
    )

    # Ask for exactly one missing field per turn, in a fixed order.
    if not current_name:
        updates["response"] = "Great! I can help with that. Could I have your name?"
    elif not current_email:
        updates["response"] = f"Thanks {current_name}! What is your email address?"
    elif not current_platform:
        updates["response"] = "Got it. And what creator platform do you primarily use (e.g., YouTube, TikTok)?"
    else:
        updates["lead_ready"] = True
    return updates
def execute_tool(state: AgentState) -> AgentState:
    """Submit the collected lead and return a confirmation message.

    ``mock_lead_capture`` is called only when ``lead_ready``, ``user_name``,
    ``user_email`` and ``creator_platform`` are all truthy in ``state``.
    Otherwise nothing is submitted and an error message is returned instead.

    Returns:
        ``{"response": <confirmation or error text>}``.
    """
    required_keys = ("lead_ready", "user_name", "user_email", "creator_platform")
    if not all(state.get(key) for key in required_keys):
        return {"response": "Error: Tried to execute lead capture tool without all required fields."}

    name = state["user_name"]
    email = state["user_email"]
    platform = state["creator_platform"]
    mock_lead_capture(name, email, platform)

    confirmation = f"Thanks {name}! I've successfully collected your information for your {platform} channel. Our team will reach out to {email} shortly."
    return {"response": confirmation}