# NOTE(review): removed non-Python artifact lines ("Spaces:" / "Sleeping") that
# appear to be hosting-status text accidentally captured during export.
| from typing import Optional | |
| from pydantic import BaseModel, Field | |
| from langchain_openai import ChatOpenAI | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from agent.state import AgentState | |
| from rag.retriever import retrieve_documents | |
| from tools.lead_capture import mock_lead_capture | |
| from langchain_huggingface import HuggingFacePipeline | |
| from transformers import pipeline | |
| import os | |
# Lazily-built local HuggingFace pipeline, cached for the process lifetime.
_local_llm = None


def get_llm():
    """Return the chat model to use.

    Prefers OpenAI (gpt-4o-mini) when OPENAI_API_KEY is set; otherwise falls
    back to a small local Qwen pipeline on CPU, constructed once and cached
    in the module-level `_local_llm`.
    """
    global _local_llm
    if os.environ.get("OPENAI_API_KEY"):
        return ChatOpenAI(model="gpt-4o-mini", temperature=0)
    if _local_llm is None:
        text_gen = pipeline(
            "text-generation",
            model="Qwen/Qwen2.5-0.5B-Instruct",
            max_new_tokens=512,
            device="cpu",
            trust_remote_code=True,
            return_full_text=False,
        )
        _local_llm = HuggingFacePipeline(pipeline=text_gen)
    return _local_llm
class IntentResponse(BaseModel):
    """Structured classification result produced by the intent-detection chain."""

    # One of: GREETING, PRODUCT_QUERY, PRICING_QUERY, HIGH_INTENT_LEAD, UNKNOWN.
    intent: str = Field(
        description="The intent of the user. Must be one of: GREETING, PRODUCT_QUERY, PRICING_QUERY, HIGH_INTENT_LEAD, UNKNOWN"
    )
    # Model-reported confidence; not currently consumed downstream.
    confidence: float = Field(description="Confidence score between 0 and 1")
class LeadExtractionResponse(BaseModel):
    """Lead details extracted from a user message; any field may be absent."""

    user_name: Optional[str] = Field(
        default=None, description="The name of the user if provided"
    )
    user_email: Optional[str] = Field(
        default=None, description="The email address of the user if provided"
    )
    creator_platform: Optional[str] = Field(
        default=None,
        description="The creator platform (e.g., YouTube, Instagram) if provided",
    )
def detect_intent(state: AgentState) -> AgentState:
    """Classify the user's current message into one of the supported intents.

    Uses the LLM's structured-output API when available; otherwise falls back
    to a PydanticOutputParser, with its format instructions baked into the
    prompt via `partial()` so both paths share a single invoke call.

    Returns a partial state update: {"detected_intent": <intent string>}.
    """
    llm = get_llm()
    # Shared classifier instructions (identical across both code paths).
    system_text = "You are an intent classification assistant for AutoStream. Analyze the user's message and determine the intent. Categories: GREETING, PRODUCT_QUERY, PRICING_QUERY, HIGH_INTENT_LEAD, UNKNOWN. A 'HIGH_INTENT_LEAD' is when a user explicitly expresses interest in signing up, buying, or trying out a plan."
    if hasattr(llm, "with_structured_output"):
        prompt = ChatPromptTemplate.from_messages([
            ("system", system_text),
            ("user", "{message}"),
        ])
        chain = prompt | llm.with_structured_output(IntentResponse)
    else:
        # Local pipelines lack structured output; parse the raw text instead.
        from langchain.output_parsers import PydanticOutputParser
        parser = PydanticOutputParser(pydantic_object=IntentResponse)
        prompt = ChatPromptTemplate.from_messages([
            ("system", system_text + "\n\n{format_instructions}"),
            ("user", "{message}"),
        ]).partial(format_instructions=parser.get_format_instructions())
        chain = prompt | llm | parser
    # Give the classifier a short window of recent turns for context.
    recent = state.get("conversation_history", [])[-3:]
    history_str = "\n".join(f"{msg['role']}: {msg['content']}" for msg in recent)
    context_message = f"Recent history:\n{history_str}\n\nCurrent message:\n{state['current_message']}"
    response = chain.invoke({"message": context_message})
    return {"detected_intent": response.intent}
def handle_greeting(state: AgentState) -> AgentState:
    """Reply to a greeting with the canned welcome message."""
    welcome = (
        "Hello! I'm the AutoStream assistant. I can answer questions about "
        "our features and pricing. How can I help you today?"
    )
    return {"response": welcome}
def handle_unknown(state: AgentState) -> AgentState:
    """Fallback reply when the user's intent could not be determined."""
    fallback = (
        "I'm not quite sure how to help with that. Could you clarify your "
        "question about AutoStream?"
    )
    return {"response": fallback}
def retrieve_knowledge(state: AgentState) -> AgentState:
    """Look up knowledge-base documents relevant to the current user message."""
    matches = retrieve_documents(state["current_message"])
    return {"retrieved_documents": matches}
def generate_rag_response(state: AgentState) -> AgentState:
    """Answer the user's question grounded strictly in the retrieved documents."""
    llm = get_llm()
    rag_prompt = ChatPromptTemplate.from_messages([
        ("system", "You are a helpful sales assistant for AutoStream. Answer the user's question based strictly on the following retrieved knowledge:\n\n{context}\n\nIf the answer is not in the context, say you don't know."),
        ("user", "{message}"),
    ])
    knowledge = "\n\n".join(state.get("retrieved_documents", []))
    raw = (rag_prompt | llm).invoke({
        "context": knowledge,
        "message": state["current_message"],
    })
    # Chat models return a message object exposing .content; a local text
    # pipeline returns a plain string, so fall back to str() in that case.
    text = getattr(raw, "content", str(raw))
    return {"response": text}
def process_lead(state: AgentState) -> AgentState:
    """Extract lead details from the conversation and drive slot-filling.

    Runs the extraction LLM once (structured output when supported, otherwise
    a PydanticOutputParser with format instructions pre-bound via `partial()`),
    merges any newly-found fields into the state without overwriting existing
    values, then either asks for the next missing field or marks the lead
    ready for the capture tool.

    Returns a partial state update containing any newly extracted fields plus
    either a "response" prompt or "lead_ready": True.
    """
    llm = get_llm()
    # Shared extraction instructions (identical across both code paths).
    system_text = "Extract the user's name, email, and creator platform (e.g. YouTube, TikTok, Instagram) from the message if present. Return null for fields not found."
    if hasattr(llm, "with_structured_output"):
        extract_prompt = ChatPromptTemplate.from_messages([
            ("system", system_text),
            ("user", "{message}"),
        ])
        extract_chain = extract_prompt | llm.with_structured_output(LeadExtractionResponse)
    else:
        # Local pipelines lack structured output; parse the raw text instead.
        from langchain.output_parsers import PydanticOutputParser
        parser = PydanticOutputParser(pydantic_object=LeadExtractionResponse)
        extract_prompt = ChatPromptTemplate.from_messages([
            ("system", system_text + "\n\n{format_instructions}"),
            ("user", "{message}"),
        ]).partial(format_instructions=parser.get_format_instructions())
        extract_chain = extract_prompt | llm | parser
    recent = state.get("conversation_history", [])[-3:]
    history_str = "\n".join(f"{msg['role']}: {msg['content']}" for msg in recent)
    context_message = f"Recent history:\n{history_str}\n\nCurrent message:\n{state['current_message']}"
    extracted = extract_chain.invoke({"message": context_message})

    # Merge newly extracted values; never overwrite a field already in state.
    updates = {}
    for field in ("user_name", "user_email", "creator_platform"):
        value = getattr(extracted, field)
        if value and not state.get(field):
            updates[field] = value

    current_name = updates.get("user_name", state.get("user_name"))
    current_email = updates.get("user_email", state.get("user_email"))
    current_platform = updates.get("creator_platform", state.get("creator_platform"))

    # Ask for the first missing field, in a fixed order; when all three are
    # present, signal that the lead-capture tool may run.
    if not current_name:
        updates["response"] = "Great! I can help with that. Could I have your name?"
    elif not current_email:
        updates["response"] = f"Thanks {current_name}! What is your email address?"
    elif not current_platform:
        updates["response"] = "Got it. And what creator platform do you primarily use (e.g., YouTube, TikTok)?"
    else:
        updates["lead_ready"] = True
    return updates
def execute_tool(state: AgentState) -> AgentState:
    """Fire the lead-capture tool once all required lead fields are present."""
    name = state.get("user_name")
    email = state.get("user_email")
    platform = state.get("creator_platform")
    # Guard: only capture when the lead is flagged ready and fully populated.
    if not (state.get("lead_ready") and name and email and platform):
        return {"response": "Error: Tried to execute lead capture tool without all required fields."}
    mock_lead_capture(name, email, platform)
    return {"response": f"Thanks {name}! I've successfully collected your information for your {platform} channel. Our team will reach out to {email} shortly."}