google-labs-jules[bot] archc0der commited on
Commit
9d2e886
·
1 Parent(s): c987214

feat: integrate local HF fallback models and enhance Streamlit UI

Browse files

- Adds support for a local HuggingFace Qwen model (0.5B-Instruct) when OPENAI_API_KEY is not provided
- Adds support for local HuggingFace embeddings (sentence-transformers)
- Enhances Streamlit UI with a minimalist, modern dark theme
- Updates requirements to include transformers, pytorch, and HF dependencies

Co-authored-by: archc0der <119496494+archc0der@users.noreply.github.com>

Files changed (4) hide show
  1. agent/nodes.py +53 -5
  2. app.py +25 -21
  3. rag/embeddings.py +5 -3
  4. requirements.txt +5 -0
agent/nodes.py CHANGED
@@ -5,8 +5,27 @@ from langchain_core.prompts import ChatPromptTemplate
5
  from agent.state import AgentState
6
  from rag.retriever import retrieve_documents
7
  from tools.lead_capture import mock_lead_capture
 
 
 
 
 
8
 
9
  def get_llm():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  return ChatOpenAI(model="gpt-4o-mini", temperature=0)
11
 
12
  class IntentResponse(BaseModel):
@@ -25,12 +44,26 @@ def detect_intent(state: AgentState) -> AgentState:
25
  ("user", "{message}")
26
  ])
27
 
28
- chain = prompt | llm.with_structured_output(IntentResponse)
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  history_str = "\n".join([f"{msg['role']}: {msg['content']}" for msg in state.get("conversation_history", [])[-3:]])
31
  context_message = f"Recent history:\n{history_str}\n\nCurrent message:\n{state['current_message']}"
32
 
33
- response = chain.invoke({"message": context_message})
 
 
 
34
 
35
  return {"detected_intent": response.intent}
36
 
@@ -59,7 +92,9 @@ def generate_rag_response(state: AgentState) -> AgentState:
59
  "message": state["current_message"]
60
  })
61
 
62
- return {"response": response.content}
 
 
63
 
64
  def process_lead(state: AgentState) -> AgentState:
65
  llm = get_llm()
@@ -68,12 +103,25 @@ def process_lead(state: AgentState) -> AgentState:
68
  ("system", "Extract the user's name, email, and creator platform (e.g. YouTube, TikTok, Instagram) from the message if present. Return null for fields not found."),
69
  ("user", "{message}")
70
  ])
71
- extract_chain = extract_prompt | llm.with_structured_output(LeadExtractionResponse)
 
 
 
 
 
 
 
 
 
 
72
 
73
  history_str = "\n".join([f"{msg['role']}: {msg['content']}" for msg in state.get("conversation_history", [])[-3:]])
74
  context_message = f"Recent history:\n{history_str}\n\nCurrent message:\n{state['current_message']}"
75
 
76
- extracted = extract_chain.invoke({"message": context_message})
 
 
 
77
 
78
  updates = {}
79
  if extracted.user_name and not state.get("user_name"):
 
5
  from agent.state import AgentState
6
  from rag.retriever import retrieve_documents
7
  from tools.lead_capture import mock_lead_capture
8
+ from langchain_huggingface import HuggingFacePipeline
9
+ from transformers import pipeline
10
+ import os
11
+
12
+ _local_llm = None
13
 
14
  def get_llm():
15
+ global _local_llm
16
+ if not os.environ.get("OPENAI_API_KEY"):
17
+ if _local_llm is None:
18
+
19
+ pipe = pipeline(
20
+ "text-generation",
21
+ model="Qwen/Qwen2.5-0.5B-Instruct",
22
+ max_new_tokens=512,
23
+ device="cpu",
24
+ trust_remote_code=True,
25
+ return_full_text=False
26
+ )
27
+ _local_llm = HuggingFacePipeline(pipeline=pipe)
28
+ return _local_llm
29
  return ChatOpenAI(model="gpt-4o-mini", temperature=0)
30
 
31
  class IntentResponse(BaseModel):
 
44
  ("user", "{message}")
45
  ])
46
 
47
+ if hasattr(llm, "with_structured_output"):
48
+ chain = prompt | llm.with_structured_output(IntentResponse)
49
+ else:
50
+ from langchain.output_parsers import PydanticOutputParser
51
+ parser = PydanticOutputParser(pydantic_object=IntentResponse)
52
+
53
+ prompt = ChatPromptTemplate.from_messages([
54
+ ("system", "You are an intent classification assistant for AutoStream. Analyze the user's message and determine the intent. Categories: GREETING, PRODUCT_QUERY, PRICING_QUERY, HIGH_INTENT_LEAD, UNKNOWN. A 'HIGH_INTENT_LEAD' is when a user explicitly expresses interest in signing up, buying, or trying out a plan.\n\n{format_instructions}"),
55
+ ("user", "{message}")
56
+ ])
57
+
58
+ chain = prompt | llm | parser
59
 
60
  history_str = "\n".join([f"{msg['role']}: {msg['content']}" for msg in state.get("conversation_history", [])[-3:]])
61
  context_message = f"Recent history:\n{history_str}\n\nCurrent message:\n{state['current_message']}"
62
 
63
+ if hasattr(llm, "with_structured_output"):
64
+ response = chain.invoke({"message": context_message})
65
+ else:
66
+ response = chain.invoke({"message": context_message, "format_instructions": parser.get_format_instructions()})
67
 
68
  return {"detected_intent": response.intent}
69
 
 
92
  "message": state["current_message"]
93
  })
94
 
95
+ content = response.content if hasattr(response, "content") else str(response)
96
+
97
+ return {"response": content}
98
 
99
  def process_lead(state: AgentState) -> AgentState:
100
  llm = get_llm()
 
103
  ("system", "Extract the user's name, email, and creator platform (e.g. YouTube, TikTok, Instagram) from the message if present. Return null for fields not found."),
104
  ("user", "{message}")
105
  ])
106
+
107
+ if hasattr(llm, "with_structured_output"):
108
+ extract_chain = extract_prompt | llm.with_structured_output(LeadExtractionResponse)
109
+ else:
110
+ from langchain.output_parsers import PydanticOutputParser
111
+ parser = PydanticOutputParser(pydantic_object=LeadExtractionResponse)
112
+ extract_prompt = ChatPromptTemplate.from_messages([
113
+ ("system", "Extract the user's name, email, and creator platform (e.g. YouTube, TikTok, Instagram) from the message if present. Return null for fields not found.\n\n{format_instructions}"),
114
+ ("user", "{message}")
115
+ ])
116
+ extract_chain = extract_prompt | llm | parser
117
 
118
  history_str = "\n".join([f"{msg['role']}: {msg['content']}" for msg in state.get("conversation_history", [])[-3:]])
119
  context_message = f"Recent history:\n{history_str}\n\nCurrent message:\n{state['current_message']}"
120
 
121
+ if hasattr(llm, "with_structured_output"):
122
+ extracted = extract_chain.invoke({"message": context_message})
123
+ else:
124
+ extracted = extract_chain.invoke({"message": context_message, "format_instructions": parser.get_format_instructions()})
125
 
126
  updates = {}
127
  if extracted.user_name and not state.get("user_name"):
app.py CHANGED
@@ -4,22 +4,38 @@ from dotenv import load_dotenv
4
  from agent.graph import app
5
  from agent.state import AgentState
6
 
7
-
8
  load_dotenv()
9
 
 
10
 
11
- st.set_page_config(page_title="AutoStream AI Sales Assistant", page_icon="🤖", layout="centered")
12
-
13
-
14
  st.markdown("""
15
  <style>
16
  .stChatFloatingInputContainer {
17
  bottom: 20px;
18
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  </style>
20
  """, unsafe_allow_html=True)
21
 
22
-
23
  if "messages" not in st.session_state:
24
  st.session_state.messages = []
25
 
@@ -35,54 +51,43 @@ if "messages" not in st.session_state:
35
  response=""
36
  )
37
 
38
-
39
  st.session_state.messages.append({"role": "assistant", "content": "Hello! I'm the AutoStream assistant. I can answer questions about our features and pricing. How can I help you today?"})
40
 
 
 
41
 
42
  if not os.environ.get("OPENAI_API_KEY"):
43
- st.warning("️ OPENAI_API_KEY is not set. Please set it in your environment to use the agent.")
44
-
45
- st.title("🤖 AutoStream AI Sales Assistant")
46
- st.markdown("Ask me about AutoStream features and pricing, or sign up for a plan!")
47
-
48
 
49
  for message in st.session_state.messages:
50
  with st.chat_message(message["role"]):
51
  st.markdown(message["content"])
52
 
53
-
54
  if prompt := st.chat_input("What would you like to know?"):
55
 
56
  st.session_state.messages.append({"role": "user", "content": prompt})
57
  with st.chat_message("user"):
58
  st.markdown(prompt)
59
 
60
-
61
  st.session_state.agent_state["current_message"] = prompt
62
 
63
-
64
  with st.chat_message("assistant"):
65
  with st.spinner("Thinking..."):
66
  try:
67
-
68
  result_state = app.invoke(st.session_state.agent_state)
69
  st.session_state.agent_state = result_state
70
 
71
  response = result_state["response"]
72
 
73
-
74
  st.session_state.agent_state["conversation_history"].append({"role": "user", "content": prompt})
75
  st.session_state.agent_state["conversation_history"].append({"role": "assistant", "content": response})
76
 
77
-
78
  if len(st.session_state.agent_state["conversation_history"]) > 12:
79
  st.session_state.agent_state["conversation_history"] = st.session_state.agent_state["conversation_history"][-12:]
80
 
81
-
82
  st.markdown(response)
83
 
84
-
85
- with st.expander("Agent Reasoning & State"):
86
  st.write(f"**Detected Intent:** `{result_state.get('detected_intent', 'UNKNOWN')}`")
87
  if result_state.get("retrieved_documents") and result_state.get("detected_intent") in ["PRODUCT_QUERY", "PRICING_QUERY"]:
88
  st.write(f"**RAG Retrieval:** Found {len(result_state['retrieved_documents'])} relevant knowledge chunks.")
@@ -99,5 +104,4 @@ if prompt := st.chat_input("What would you like to know?"):
99
  response = f"An error occurred: {str(e)}"
100
  st.error(response)
101
 
102
-
103
  st.session_state.messages.append({"role": "assistant", "content": response})
 
4
  from agent.graph import app
5
  from agent.state import AgentState
6
 
 
7
  load_dotenv()
8
 
9
+ st.set_page_config(page_title="AutoStream AI Sales Assistant", page_icon="🎬", layout="centered")
10
 
11
+ # Custom CSS for minimalist, cooler UI
 
 
12
  st.markdown("""
13
  <style>
14
  .stChatFloatingInputContainer {
15
  bottom: 20px;
16
  }
17
+ .main {
18
+ background-color: #0E1117;
19
+ }
20
+ h1 {
21
+ color: #E2E8F0;
22
+ font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif;
23
+ font-weight: 700;
24
+ text-align: center;
25
+ margin-bottom: 2rem;
26
+ }
27
+ .subtitle {
28
+ color: #94A3B8;
29
+ text-align: center;
30
+ margin-bottom: 2rem;
31
+ font-size: 1.1rem;
32
+ }
33
+ .stAlert {
34
+ border-radius: 8px;
35
+ }
36
  </style>
37
  """, unsafe_allow_html=True)
38
 
 
39
  if "messages" not in st.session_state:
40
  st.session_state.messages = []
41
 
 
51
  response=""
52
  )
53
 
 
54
  st.session_state.messages.append({"role": "assistant", "content": "Hello! I'm the AutoStream assistant. I can answer questions about our features and pricing. How can I help you today?"})
55
 
56
+ st.markdown("<h1>🎬 AutoStream Assistant</h1>", unsafe_allow_html=True)
57
+ st.markdown("<div class='subtitle'>Ask about features and pricing, or sign up for a plan instantly!</div>", unsafe_allow_html=True)
58
 
59
  if not os.environ.get("OPENAI_API_KEY"):
60
+ st.info("️ OPENAI_API_KEY is not set. The system will fall back to a local Qwen model and HuggingFace embeddings.")
 
 
 
 
61
 
62
  for message in st.session_state.messages:
63
  with st.chat_message(message["role"]):
64
  st.markdown(message["content"])
65
 
 
66
  if prompt := st.chat_input("What would you like to know?"):
67
 
68
  st.session_state.messages.append({"role": "user", "content": prompt})
69
  with st.chat_message("user"):
70
  st.markdown(prompt)
71
 
 
72
  st.session_state.agent_state["current_message"] = prompt
73
 
 
74
  with st.chat_message("assistant"):
75
  with st.spinner("Thinking..."):
76
  try:
 
77
  result_state = app.invoke(st.session_state.agent_state)
78
  st.session_state.agent_state = result_state
79
 
80
  response = result_state["response"]
81
 
 
82
  st.session_state.agent_state["conversation_history"].append({"role": "user", "content": prompt})
83
  st.session_state.agent_state["conversation_history"].append({"role": "assistant", "content": response})
84
 
 
85
  if len(st.session_state.agent_state["conversation_history"]) > 12:
86
  st.session_state.agent_state["conversation_history"] = st.session_state.agent_state["conversation_history"][-12:]
87
 
 
88
  st.markdown(response)
89
 
90
+ with st.expander("Agent Reasoning & State", expanded=False):
 
91
  st.write(f"**Detected Intent:** `{result_state.get('detected_intent', 'UNKNOWN')}`")
92
  if result_state.get("retrieved_documents") and result_state.get("detected_intent") in ["PRODUCT_QUERY", "PRICING_QUERY"]:
93
  st.write(f"**RAG Retrieval:** Found {len(result_state['retrieved_documents'])} relevant knowledge chunks.")
 
104
  response = f"An error occurred: {str(e)}"
105
  st.error(response)
106
 
 
107
  st.session_state.messages.append({"role": "assistant", "content": response})
rag/embeddings.py CHANGED
@@ -1,7 +1,9 @@
1
  from langchain_openai import OpenAIEmbeddings
 
 
2
 
3
  def get_embeddings():
4
- """
5
- Returns the embedding model used for the RAG pipeline.
6
- """
7
  return OpenAIEmbeddings()
 
1
  from langchain_openai import OpenAIEmbeddings
2
+ from langchain_huggingface import HuggingFaceEmbeddings
3
+ import os
4
 
5
  def get_embeddings():
6
+
7
+ if not os.environ.get("OPENAI_API_KEY"):
8
+ return HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
9
  return OpenAIEmbeddings()
requirements.txt CHANGED
@@ -9,3 +9,8 @@ pydantic
9
  pytest
10
  pytest-mock
11
  streamlit
 
 
 
 
 
 
9
  pytest
10
  pytest-mock
11
  streamlit
12
+ transformers
13
+ langchain-huggingface
14
+ huggingface-hub
15
+ sentence-transformers
16
+ torch