Priyansh Saxena committed on
Commit
4c68cfa
·
1 Parent(s): 9c92c35

feat: add local llm

Browse files

Signed-off-by: Priyansh Saxena <priyena.programming@gmail.com>

Files changed (3) hide show
  1. agent/nodes.py +0 -25
  2. app.py +25 -21
  3. rag/embeddings.py +5 -3
agent/nodes.py CHANGED
@@ -5,7 +5,6 @@ from langchain_core.prompts import ChatPromptTemplate
5
  from agent.state import AgentState
6
  from rag.retriever import retrieve_documents
7
  from tools.lead_capture import mock_lead_capture
8
- <<<<<<< HEAD
9
  from langchain_huggingface import HuggingFacePipeline
10
  from transformers import pipeline
11
  import os
@@ -27,10 +26,6 @@ def get_llm():
27
  )
28
  _local_llm = HuggingFacePipeline(pipeline=pipe)
29
  return _local_llm
30
- =======
31
-
32
- def get_llm():
33
- >>>>>>> 128a106 (Migration Phase 6: Autonomous agent core modules)
34
  return ChatOpenAI(model="gpt-4o-mini", temperature=0)
35
 
36
  class IntentResponse(BaseModel):
@@ -49,7 +44,6 @@ def detect_intent(state: AgentState) -> AgentState:
49
  ("user", "{message}")
50
  ])
51
 
52
- <<<<<<< HEAD
53
  if hasattr(llm, "with_structured_output"):
54
  chain = prompt | llm.with_structured_output(IntentResponse)
55
  else:
@@ -62,21 +56,14 @@ def detect_intent(state: AgentState) -> AgentState:
62
  ])
63
 
64
  chain = prompt | llm | parser
65
- =======
66
- chain = prompt | llm.with_structured_output(IntentResponse)
67
- >>>>>>> 128a106 (Migration Phase 6: Autonomous agent core modules)
68
 
69
  history_str = "\n".join([f"{msg['role']}: {msg['content']}" for msg in state.get("conversation_history", [])[-3:]])
70
  context_message = f"Recent history:\n{history_str}\n\nCurrent message:\n{state['current_message']}"
71
 
72
- <<<<<<< HEAD
73
  if hasattr(llm, "with_structured_output"):
74
  response = chain.invoke({"message": context_message})
75
  else:
76
  response = chain.invoke({"message": context_message, "format_instructions": parser.get_format_instructions()})
77
- =======
78
- response = chain.invoke({"message": context_message})
79
- >>>>>>> 128a106 (Migration Phase 6: Autonomous agent core modules)
80
 
81
  return {"detected_intent": response.intent}
82
 
@@ -105,13 +92,9 @@ def generate_rag_response(state: AgentState) -> AgentState:
105
  "message": state["current_message"]
106
  })
107
 
108
- <<<<<<< HEAD
109
  content = response.content if hasattr(response, "content") else str(response)
110
 
111
  return {"response": content}
112
- =======
113
- return {"response": response.content}
114
- >>>>>>> 128a106 (Migration Phase 6: Autonomous agent core modules)
115
 
116
  def process_lead(state: AgentState) -> AgentState:
117
  llm = get_llm()
@@ -120,7 +103,6 @@ def process_lead(state: AgentState) -> AgentState:
120
  ("system", "Extract the user's name, email, and creator platform (e.g. YouTube, TikTok, Instagram) from the message if present. Return null for fields not found."),
121
  ("user", "{message}")
122
  ])
123
- <<<<<<< HEAD
124
 
125
  if hasattr(llm, "with_structured_output"):
126
  extract_chain = extract_prompt | llm.with_structured_output(LeadExtractionResponse)
@@ -132,21 +114,14 @@ def process_lead(state: AgentState) -> AgentState:
132
  ("user", "{message}")
133
  ])
134
  extract_chain = extract_prompt | llm | parser
135
- =======
136
- extract_chain = extract_prompt | llm.with_structured_output(LeadExtractionResponse)
137
- >>>>>>> 128a106 (Migration Phase 6: Autonomous agent core modules)
138
 
139
  history_str = "\n".join([f"{msg['role']}: {msg['content']}" for msg in state.get("conversation_history", [])[-3:]])
140
  context_message = f"Recent history:\n{history_str}\n\nCurrent message:\n{state['current_message']}"
141
 
142
- <<<<<<< HEAD
143
  if hasattr(llm, "with_structured_output"):
144
  extracted = extract_chain.invoke({"message": context_message})
145
  else:
146
  extracted = extract_chain.invoke({"message": context_message, "format_instructions": parser.get_format_instructions()})
147
- =======
148
- extracted = extract_chain.invoke({"message": context_message})
149
- >>>>>>> 128a106 (Migration Phase 6: Autonomous agent core modules)
150
 
151
  updates = {}
152
  if extracted.user_name and not state.get("user_name"):
 
5
  from agent.state import AgentState
6
  from rag.retriever import retrieve_documents
7
  from tools.lead_capture import mock_lead_capture
 
8
  from langchain_huggingface import HuggingFacePipeline
9
  from transformers import pipeline
10
  import os
 
26
  )
27
  _local_llm = HuggingFacePipeline(pipeline=pipe)
28
  return _local_llm
 
 
 
 
29
  return ChatOpenAI(model="gpt-4o-mini", temperature=0)
30
 
31
  class IntentResponse(BaseModel):
 
44
  ("user", "{message}")
45
  ])
46
 
 
47
  if hasattr(llm, "with_structured_output"):
48
  chain = prompt | llm.with_structured_output(IntentResponse)
49
  else:
 
56
  ])
57
 
58
  chain = prompt | llm | parser
 
 
 
59
 
60
  history_str = "\n".join([f"{msg['role']}: {msg['content']}" for msg in state.get("conversation_history", [])[-3:]])
61
  context_message = f"Recent history:\n{history_str}\n\nCurrent message:\n{state['current_message']}"
62
 
 
63
  if hasattr(llm, "with_structured_output"):
64
  response = chain.invoke({"message": context_message})
65
  else:
66
  response = chain.invoke({"message": context_message, "format_instructions": parser.get_format_instructions()})
 
 
 
67
 
68
  return {"detected_intent": response.intent}
69
 
 
92
  "message": state["current_message"]
93
  })
94
 
 
95
  content = response.content if hasattr(response, "content") else str(response)
96
 
97
  return {"response": content}
 
 
 
98
 
99
  def process_lead(state: AgentState) -> AgentState:
100
  llm = get_llm()
 
103
  ("system", "Extract the user's name, email, and creator platform (e.g. YouTube, TikTok, Instagram) from the message if present. Return null for fields not found."),
104
  ("user", "{message}")
105
  ])
 
106
 
107
  if hasattr(llm, "with_structured_output"):
108
  extract_chain = extract_prompt | llm.with_structured_output(LeadExtractionResponse)
 
114
  ("user", "{message}")
115
  ])
116
  extract_chain = extract_prompt | llm | parser
 
 
 
117
 
118
  history_str = "\n".join([f"{msg['role']}: {msg['content']}" for msg in state.get("conversation_history", [])[-3:]])
119
  context_message = f"Recent history:\n{history_str}\n\nCurrent message:\n{state['current_message']}"
120
 
 
121
  if hasattr(llm, "with_structured_output"):
122
  extracted = extract_chain.invoke({"message": context_message})
123
  else:
124
  extracted = extract_chain.invoke({"message": context_message, "format_instructions": parser.get_format_instructions()})
 
 
 
125
 
126
  updates = {}
127
  if extracted.user_name and not state.get("user_name"):
app.py CHANGED
@@ -4,22 +4,38 @@ from dotenv import load_dotenv
4
  from agent.graph import app
5
  from agent.state import AgentState
6
 
7
-
8
  load_dotenv()
9
 
 
10
 
11
- st.set_page_config(page_title="AutoStream AI Sales Assistant", page_icon="🤖", layout="centered")
12
-
13
-
14
  st.markdown("""
15
  <style>
16
  .stChatFloatingInputContainer {
17
  bottom: 20px;
18
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  </style>
20
  """, unsafe_allow_html=True)
21
 
22
-
23
  if "messages" not in st.session_state:
24
  st.session_state.messages = []
25
 
@@ -35,54 +51,43 @@ if "messages" not in st.session_state:
35
  response=""
36
  )
37
 
38
-
39
  st.session_state.messages.append({"role": "assistant", "content": "Hello! I'm the AutoStream assistant. I can answer questions about our features and pricing. How can I help you today?"})
40
 
 
 
41
 
42
  if not os.environ.get("OPENAI_API_KEY"):
43
- st.warning("️ OPENAI_API_KEY is not set. Please set it in your environment to use the agent.")
44
-
45
- st.title("🤖 AutoStream AI Sales Assistant")
46
- st.markdown("Ask me about AutoStream features and pricing, or sign up for a plan!")
47
-
48
 
49
  for message in st.session_state.messages:
50
  with st.chat_message(message["role"]):
51
  st.markdown(message["content"])
52
 
53
-
54
  if prompt := st.chat_input("What would you like to know?"):
55
 
56
  st.session_state.messages.append({"role": "user", "content": prompt})
57
  with st.chat_message("user"):
58
  st.markdown(prompt)
59
 
60
-
61
  st.session_state.agent_state["current_message"] = prompt
62
 
63
-
64
  with st.chat_message("assistant"):
65
  with st.spinner("Thinking..."):
66
  try:
67
-
68
  result_state = app.invoke(st.session_state.agent_state)
69
  st.session_state.agent_state = result_state
70
 
71
  response = result_state["response"]
72
 
73
-
74
  st.session_state.agent_state["conversation_history"].append({"role": "user", "content": prompt})
75
  st.session_state.agent_state["conversation_history"].append({"role": "assistant", "content": response})
76
 
77
-
78
  if len(st.session_state.agent_state["conversation_history"]) > 12:
79
  st.session_state.agent_state["conversation_history"] = st.session_state.agent_state["conversation_history"][-12:]
80
 
81
-
82
  st.markdown(response)
83
 
84
-
85
- with st.expander("Agent Reasoning & State"):
86
  st.write(f"**Detected Intent:** `{result_state.get('detected_intent', 'UNKNOWN')}`")
87
  if result_state.get("retrieved_documents") and result_state.get("detected_intent") in ["PRODUCT_QUERY", "PRICING_QUERY"]:
88
  st.write(f"**RAG Retrieval:** Found {len(result_state['retrieved_documents'])} relevant knowledge chunks.")
@@ -99,5 +104,4 @@ if prompt := st.chat_input("What would you like to know?"):
99
  response = f"An error occurred: {str(e)}"
100
  st.error(response)
101
 
102
-
103
  st.session_state.messages.append({"role": "assistant", "content": response})
 
4
  from agent.graph import app
5
  from agent.state import AgentState
6
 
 
7
  load_dotenv()
8
 
9
+ st.set_page_config(page_title="AutoStream AI Sales Assistant", page_icon="🎬", layout="centered")
10
 
11
+ # Custom CSS for minimalist, cooler UI
 
 
12
  st.markdown("""
13
  <style>
14
  .stChatFloatingInputContainer {
15
  bottom: 20px;
16
  }
17
+ .main {
18
+ background-color: #0E1117;
19
+ }
20
+ h1 {
21
+ color: #E2E8F0;
22
+ font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif;
23
+ font-weight: 700;
24
+ text-align: center;
25
+ margin-bottom: 2rem;
26
+ }
27
+ .subtitle {
28
+ color: #94A3B8;
29
+ text-align: center;
30
+ margin-bottom: 2rem;
31
+ font-size: 1.1rem;
32
+ }
33
+ .stAlert {
34
+ border-radius: 8px;
35
+ }
36
  </style>
37
  """, unsafe_allow_html=True)
38
 
 
39
  if "messages" not in st.session_state:
40
  st.session_state.messages = []
41
 
 
51
  response=""
52
  )
53
 
 
54
  st.session_state.messages.append({"role": "assistant", "content": "Hello! I'm the AutoStream assistant. I can answer questions about our features and pricing. How can I help you today?"})
55
 
56
+ st.markdown("<h1>🎬 AutoStream Assistant</h1>", unsafe_allow_html=True)
57
+ st.markdown("<div class='subtitle'>Ask about features and pricing, or sign up for a plan instantly!</div>", unsafe_allow_html=True)
58
 
59
  if not os.environ.get("OPENAI_API_KEY"):
60
+ st.info("️ OPENAI_API_KEY is not set. The system will fall back to a local Qwen model and HuggingFace embeddings.")
 
 
 
 
61
 
62
  for message in st.session_state.messages:
63
  with st.chat_message(message["role"]):
64
  st.markdown(message["content"])
65
 
 
66
  if prompt := st.chat_input("What would you like to know?"):
67
 
68
  st.session_state.messages.append({"role": "user", "content": prompt})
69
  with st.chat_message("user"):
70
  st.markdown(prompt)
71
 
 
72
  st.session_state.agent_state["current_message"] = prompt
73
 
 
74
  with st.chat_message("assistant"):
75
  with st.spinner("Thinking..."):
76
  try:
 
77
  result_state = app.invoke(st.session_state.agent_state)
78
  st.session_state.agent_state = result_state
79
 
80
  response = result_state["response"]
81
 
 
82
  st.session_state.agent_state["conversation_history"].append({"role": "user", "content": prompt})
83
  st.session_state.agent_state["conversation_history"].append({"role": "assistant", "content": response})
84
 
 
85
  if len(st.session_state.agent_state["conversation_history"]) > 12:
86
  st.session_state.agent_state["conversation_history"] = st.session_state.agent_state["conversation_history"][-12:]
87
 
 
88
  st.markdown(response)
89
 
90
+ with st.expander("Agent Reasoning & State", expanded=False):
 
91
  st.write(f"**Detected Intent:** `{result_state.get('detected_intent', 'UNKNOWN')}`")
92
  if result_state.get("retrieved_documents") and result_state.get("detected_intent") in ["PRODUCT_QUERY", "PRICING_QUERY"]:
93
  st.write(f"**RAG Retrieval:** Found {len(result_state['retrieved_documents'])} relevant knowledge chunks.")
 
104
  response = f"An error occurred: {str(e)}"
105
  st.error(response)
106
 
 
107
  st.session_state.messages.append({"role": "assistant", "content": response})
rag/embeddings.py CHANGED
@@ -1,7 +1,9 @@
1
  from langchain_openai import OpenAIEmbeddings
 
 
2
 
3
  def get_embeddings():
4
- """
5
- Returns the embedding model used for the RAG pipeline.
6
- """
7
  return OpenAIEmbeddings()
 
1
  from langchain_openai import OpenAIEmbeddings
2
+ from langchain_huggingface import HuggingFaceEmbeddings
3
+ import os
4
 
5
  def get_embeddings():
6
+
7
+ if not os.environ.get("OPENAI_API_KEY"):
8
+ return HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
9
  return OpenAIEmbeddings()