genaitiwari committed
Commit 3392ab1 · Parent(s): ac85c1d

Added cache seed and CAG in AutoGen use case

.gitignore CHANGED
@@ -6,3 +6,4 @@ codegen/tmp_code_3e1806a0bf22b99c6c5d2b77650fe9a8.py
 /codegen
 /tmp/chromadb
 /tmp/db
+/.cache
README.md CHANGED
@@ -55,6 +55,16 @@ Requirements
 #### Basic Example
 ![alt text](basic_example.png)

+### Chat with CAG
+prompt1: what is dotnet
+prompt2: what is python
+prompt3: what is python
+prompt4: what is dotnet
+prompt5: what is python
+
+![alt text](cag_chat.png)
+
+
 #### MultiAgent Chat
 prompt : As a user , create a asp.net form with razor view page for health insaurance feedback page
 ![alt text](multiagent_chat.png)
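The repeated prompts above are the point of the demo: prompt3 and prompt5 should come back as cache hits. The flow the new `src/cag` modules implement is roughly: normalize the query, look for an exact cache hit, then an embedding-similarity hit, and only then fall back to the fresh LLM answer, storing it for next time. A condensed sketch of that flow (the function name is illustrative; the method names follow the modules added below):

```python
# Condensed view of the cache-augmented flow (see src/cag/* below for the real code).
def answer(query, llm_answer, cache, embedder, threshold=0.8):
    key = query.strip().lower()                        # CacheManager.normalize_key
    hit = cache.get_from_cache(key)                    # 1. exact-match lookup
    if hit:
        return f"Cache Hit! {hit}"

    emb = embedder.generate_embedding(query)           # 2. embedding-similarity lookup
    for cached_key in list(cache.cache):
        cached_emb = cache.get_embedding(cached_key)
        if cached_emb is not None and embedder.calculate_similarity(emb, cached_emb) >= threshold:
            return f"Cache Hit! {cache.get_from_cache(cached_key)}"

    cache.add_to_cache(key, llm_answer, embedding=emb)  # 3. miss: store the fresh answer
    return f"Cache Miss! {llm_answer}"
```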
app.py CHANGED
@@ -1,5 +1,6 @@
 import streamlit as st

+from src.cag.main import CAGLLM
 from configfile import Config
 from src.streamlitui.loadui import LoadStreamlitUI
 from src.usecases.multiagentschat import MultiAgentChat
@@ -9,6 +10,7 @@ from src.usecases.agentchatsqlspider import AgentChatSqlSpider
 from src.LLMS.groqllm import GroqLLM
 from src.usecases.multiagentragchat import MultiAgentRAGChat
 from src.usecases.basicexample import BasicExample
+from src.usecases.cag_chat import CAGLLMChat


 # MAIN Function START
@@ -67,5 +69,15 @@ if __name__ == "__main__":
 problem=problem)
 obj_basic_example.run()

+elif user_input['selected_usecase'] == "Chat with CAG":
+
+    obj_chat = CAGLLMChat(llm_config=llm_config, problem=problem)
+    response = obj_chat.start_chat()
+
+    obj_cag_llm = CAGLLM(problem, response)
+
+    obj_cag_llm.process_cag_llm()
cag_chat.png ADDED
configfile.ini CHANGED
@@ -1,6 +1,6 @@
 [DEFAULT]
 PAGE_TITLE = AUTOGEN IN ACTION
 LLM_OPTIONS = Groq, Huggingface
-USECASE_OPTIONS = Basic Example, MultiAgent Chat, MultiAgent Code Execution, RAG Chat, With LLamaIndex Tool
+USECASE_OPTIONS = Basic Example, Chat with CAG, MultiAgent Chat, MultiAgent Code Execution, RAG Chat, With LLamaIndex Tool
 GROQ_MODEL_OPTIONS = llama-3.3-70b-versatile, mixtral-8x7b-32768, llama3-8b-8192, llama3-70b-8192, gemma2-9b-it
requirements.txt CHANGED
@@ -1,6 +1,6 @@
 streamlit
-pyautogen
-groq
+pyautogen==0.2.32
+groq==0.9.0
 llama-index
 llama-index-tools-wikipedia
 llama-index-readers-wikipedia
@@ -10,4 +10,19 @@ spider-env
 pyautogen[retrievechat]
 pyautogen[retrievechat-qdrant]
 flaml[automl]
-sentence_transformers
+# Core Libraries
+python-dotenv
+streamlit>=1.30.0
+numpy>=1.24.0
+scikit-learn>=1.2.2
+plotly>=5.17.0
+pandas>=2.0.0
+requests
+
+# Streamlit Extensions for Enhanced UI
+streamlit-extras>=0.2.0
+
+# LLM Integration and Embedding Tools
+torch>=2.0.0
+transformers>=4.35.0
+sentence-transformers>=2.2.2
src/LLMS/groqllm.py CHANGED
@@ -9,14 +9,25 @@ class GroqLLM:
         self.user_controls_input = user_controls_input

     def groq_llm_config(self):
-        config_list = [
-            {
-                "api_type": 'groq',
-                "model": self.user_controls_input['selected_groq_model'],
-                "api_key": st.session_state["GROQ_API_KEY"],
-                "cache_seed": None
-            }
-        ]
+        if st.session_state["Cache_Seed"]:
+            config_list = [
+                {
+                    "api_type": 'groq',
+                    "model": self.user_controls_input['selected_groq_model'],
+                    "api_key": st.session_state["GROQ_API_KEY"],
+                    "cache_seed": 41
+                }
+            ]
+
+        else:
+            config_list = [
+                {
+                    "api_type": 'groq',
+                    "model": self.user_controls_input['selected_groq_model'],
+                    "api_key": st.session_state["GROQ_API_KEY"],
+                    "cache_seed": None
+                }
+            ]

         llm_config = {"config_list": config_list, "request_timeout": 60}
         st.session_state['llm_config'] = llm_config
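In pyautogen 0.2.x, `cache_seed` keys AutoGen's built-in response cache on disk under `.cache/` (which the `.gitignore` change above now excludes), and `None` disables caching, so the toggle simply switches between a fixed seed of 41 and no cache. Since the two branches differ only in that one value, the same config can be built without duplication; a sketch, assuming the same session-state keys (`build_groq_llm_config` is a hypothetical helper name):

```python
import streamlit as st

def build_groq_llm_config(selected_model: str) -> dict:
    # An int enables AutoGen's disk cache (stored under .cache/<seed>); None disables it.
    cache_seed = 41 if st.session_state.get("Cache_Seed") else None
    config_list = [{
        "api_type": "groq",
        "model": selected_model,
        "api_key": st.session_state["GROQ_API_KEY"],
        "cache_seed": cache_seed,
    }]
    return {"config_list": config_list, "request_timeout": 60}
```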
src/cag/__init__.py ADDED
(empty file)
src/cag/cache_manager.py ADDED
@@ -0,0 +1,44 @@
+import time
+import threading
+
+class CacheManager:
+    _instance = None
+    _lock = threading.Lock()
+
+    def __new__(cls, max_cache_size=100):
+        if cls._instance is None:
+            with cls._lock:
+                if cls._instance is None:
+                    cls._instance = super(CacheManager, cls).__new__(cls)
+                    cls._instance.cache = {}
+                    cls._instance.max_cache_size = max_cache_size
+        return cls._instance
+
+    def normalize_key(self, key):
+        return key.strip().lower()
+
+    def add_to_cache(self, key, value, embedding=None):
+        normalized_key = self.normalize_key(key)
+        if len(self.cache) >= self.max_cache_size:
+            self.evict_cache()
+        self.cache[normalized_key] = {
+            "response": value,
+            "timestamp": time.time(),
+            "embedding": embedding
+        }
+
+    def get_from_cache(self, key):
+        normalized_key = self.normalize_key(key)
+        return self.cache.get(normalized_key, {}).get("response", None)
+
+    def get_embedding(self, key):
+        normalized_key = self.normalize_key(key)
+        return self.cache.get(normalized_key, {}).get("embedding", None)
+
+    def evict_cache(self):
+        if self.cache:
+            oldest_key = min(self.cache, key=lambda k: self.cache[k]["timestamp"])
+            del self.cache[oldest_key]
+
+    def clear_cache(self):
+        self.cache.clear()
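A short usage sketch of the cache above. Because `CacheManager` is a singleton, every later `CacheManager()` call returns the same instance (with the `max_cache_size` it was first created with), keys are normalized with strip + lowercase, and the oldest entry is evicted once the size limit is reached:

```python
from src.cag.cache_manager import CacheManager

cache = CacheManager(max_cache_size=2)
cache.add_to_cache("What is Python?", "Python is a programming language.")
cache.add_to_cache("what is dotnet", ".NET is a developer platform.")

# Keys are normalized (strip + lowercase), so this is an exact hit:
print(cache.get_from_cache("  WHAT IS PYTHON?  "))  # Python is a programming language.

# The singleton means a "new" manager sees the same entries:
assert CacheManager() is cache

# Adding a third entry exceeds max_cache_size, so the oldest entry is evicted:
cache.add_to_cache("what is java", "Java is a programming language.")
print(cache.get_from_cache("What is Python?"))  # None
```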
src/cag/embedding_utils.py ADDED
@@ -0,0 +1,61 @@
+from sentence_transformers import SentenceTransformer
+from sklearn.metrics.pairwise import cosine_similarity
+import numpy as np
+
+class EmbeddingUtils:
+    def __init__(self, model_name="all-MiniLM-L6-v2"):
+        """
+        Initialize the embedding utility with a pre-trained model.
+
+        Args:
+        - model_name (str): Name of the sentence-transformers model.
+        """
+        self.model = SentenceTransformer(model_name)
+
+    def generate_embedding(self, text):
+        """
+        Generate embedding for a given text.
+
+        Args:
+        - text (str): Input text to generate embedding for.
+
+        Returns:
+        - np.ndarray: Embedding vector.
+        """
+        return self.model.encode([text])[0]  # encode returns a list; we extract the first item
+
+    def calculate_similarity(self, embedding1, embedding2):
+        """
+        Calculate cosine similarity between two embeddings.
+
+        Args:
+        - embedding1 (np.ndarray): First embedding vector.
+        - embedding2 (np.ndarray): Second embedding vector.
+
+        Returns:
+        - float: Cosine similarity score.
+        """
+        return cosine_similarity([embedding1], [embedding2])[0][0]
+
+    def find_best_match(self, query_embedding, cache_embeddings, threshold=0.8):
+        """
+        Find the best match for a query embedding from a list of cached embeddings.
+
+        Args:
+        - query_embedding (np.ndarray): Embedding of the input query.
+        - cache_embeddings (list of np.ndarray): List of cached embeddings.
+        - threshold (float): Minimum similarity score to consider a match.
+
+        Returns:
+        - int: Index of the best match if above threshold, otherwise -1.
+        """
+        if not cache_embeddings:
+            return -1  # No cached embeddings to compare
+
+        similarities = cosine_similarity([query_embedding], cache_embeddings)[0]
+        best_match_index = np.argmax(similarities)
+        best_match_score = similarities[best_match_index]
+
+        if best_match_score >= threshold:
+            return best_match_index
+        return -1
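A small usage sketch of the embedding helpers (the default `all-MiniLM-L6-v2` model is downloaded by sentence-transformers on first use; the example queries are illustrative):

```python
from src.cag.embedding_utils import EmbeddingUtils

utils = EmbeddingUtils()  # all-MiniLM-L6-v2 by default

cached_queries = ["what is python", "what is dotnet"]
cache_embeddings = [utils.generate_embedding(q) for q in cached_queries]

query_embedding = utils.generate_embedding("explain the python language")
idx = utils.find_best_match(query_embedding, cache_embeddings, threshold=0.8)

if idx != -1:
    print(f"Close enough to reuse: {cached_queries[idx]}")
else:
    print("No cached query is similar enough; fall back to the LLM.")
```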
src/cag/generation_model.py ADDED
@@ -0,0 +1,62 @@
+import requests
+import numpy as np
+import os
+import time
+import warnings
+from dotenv import load_dotenv
+from src.cag.cache_manager import CacheManager
+from src.cag.embedding_utils import EmbeddingUtils
+
+# Suppress PyTorch Warnings
+warnings.filterwarnings("ignore", message="Tried to instantiate class '__path__._path'")
+
+load_dotenv()
+
+class LLMIntegration:
+    def __init__(self, cache_size=100, similarity_threshold=0.8):
+        """Initialize the LLM Integration with API Key, Cache, and Embedding Utilities."""
+        self.cache_manager = CacheManager(max_cache_size=cache_size)
+        self.embedding_utils = EmbeddingUtils()
+        self.similarity_threshold = similarity_threshold
+
+    def generate_response(self, query, response):
+        """Generate a response with cache checking and similarity matching."""
+        query_key = self.cache_manager.normalize_key(query)
+
+        # Check for cache match
+        cached_response = self.cache_manager.get_from_cache(query_key)
+        if cached_response:
+            return f"Cache Hit! {cached_response}"
+
+        # Generate query embedding
+        query_embedding = self.embedding_utils.generate_embedding(query)
+
+        # Check for approximate match
+        best_match_key = self._find_best_match(query_embedding)
+        if best_match_key:
+            cached_response = self.cache_manager.get_from_cache(best_match_key)
+            return f"Cache Hit! {cached_response}"
+
+        # If no cache match, query the API
+        response = response
+
+        # ✅ Only cache successful responses
+        if response:
+            self.cache_manager.add_to_cache(query_key, response, embedding=query_embedding)
+            return f"Cache Miss! {response}"
+        else:
+            return "**Error: Could not generate a response.**"
+
+    def _find_best_match(self, query_embedding):
+        """Find the best match in the cache using similarity checking."""
+        best_match_key = None
+        highest_similarity = 0
+
+        for key in self.cache_manager.cache:
+            cached_embedding = self.cache_manager.get_embedding(key)
+            if cached_embedding is not None:
+                similarity = self.embedding_utils.calculate_similarity(query_embedding, cached_embedding)
+                if similarity > highest_similarity and similarity >= self.similarity_threshold:
+                    best_match_key = key
+                    highest_similarity = similarity
+        return best_match_key
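Note that `generate_response` does not call an LLM itself: it takes the query and an already generated answer, and only returns (and caches) that answer when neither an exact nor an approximate cache hit is found. A usage sketch:

```python
from src.cag.generation_model import LLMIntegration

llm_system = LLMIntegration(cache_size=100, similarity_threshold=0.8)

# First call: nothing cached, so the supplied answer is stored and returned.
print(llm_system.generate_response("what is python", "Python is a programming language."))
# Cache Miss! Python is a programming language.

# Second call: the normalized query matches, so the cached answer is returned
# and the second argument is ignored.
print(llm_system.generate_response("  WHAT IS PYTHON  ", "ignored"))
# Cache Hit! Python is a programming language.
```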
src/cag/main.py ADDED
@@ -0,0 +1,161 @@
+import sys
+import os
+import time
+import streamlit as st
+import plotly.express as px
+from datetime import datetime
+from dotenv import load_dotenv
+
+# Import the LLM Integration Model
+from src.cag.generation_model import LLMIntegration
+
+# Load environment variables and secrets
+load_dotenv()
+
+class CAGLLM:
+
+    def __init__(self, query, response):
+        self.query = query
+        self.response = response
+
+    def process_cag_llm(self):
+
+        # Initialize LLM Integration with API Key
+        llm_system = LLMIntegration()
+
+        # Cache statistics and tracking initialization
+        if "cache_hits" not in st.session_state:
+            st.session_state.cache_hits = 0
+            st.session_state.cache_misses = 0
+            st.session_state.response_times = []
+            st.session_state.query_timestamps = []
+            st.session_state.history = []
+
+        # st.set_page_config(
+        #     page_title="CAG Chatbot",
+        #     layout="wide",
+        #     page_icon="🧀",
+        #     initial_sidebar_state="expanded"
+        # )
+
+        # CSS for Styling Graph
+        st.markdown(
+            """
+            <style>
+            body { font-family: 'Arial', sans-serif; }
+            .stTextInput, .stButton { border-radius: 8px; }
+            .stProgress > div > div { border-radius: 20px; }
+            .custom-link { color: #1f77b4; text-decoration: none; font-weight: bold; transition: color 0.3s ease-in-out; }
+            .custom-link:hover { color: #ff4b4b; }
+            .fixed-graph-container { max-height: 300px !important; overflow-y: auto; }
+            </style>
+            """,
+            unsafe_allow_html=True
+        )
+
+        # Page Title and Description
+        st.title("💡 Cache Augmented Generation (CAG) Chatbot")
+        st.write("**A chatbot with enhanced responses powered by smart caching.**")
+
+        # Layout Columns: Configurator | Chat | Statistics
+        col1, col2, col3 = st.columns([1.2, 2, 1.2])
+
+        # 🛠️ **Configurator Section (Left Panel)**
+        with col1:
+            st.header("⚙️ Configurator")
+            cache_size = st.slider("🗄️ Cache Size", min_value=50, max_value=500, value=100)
+            similarity_threshold = st.slider("📈 Similarity Threshold", min_value=0.5, max_value=1.0, value=0.8)
+            clear_cache = st.button("🧹 Clear Cache")
+
+            if clear_cache:
+                llm_system.cache_manager.clear_cache()
+                st.session_state.cache_hits = 0
+                st.session_state.cache_misses = 0
+                st.session_state.response_times = []
+                st.session_state.query_timestamps = []
+                st.session_state.history = []
+                st.success("✅ Cache cleared successfully!")
+
+            # 📦 **Cache Content Section**
+            with st.expander("📦 **View Cache Content**"):
+                if llm_system.cache_manager.cache:
+                    for key, value in llm_system.cache_manager.cache.items():
+                        st.write(f"**Query:** {key}")
+                        st.write(f"**Response:** {value['response']}")
+                        st.write(f"**Timestamp:** {datetime.fromtimestamp(value['timestamp']).strftime('%Y-%m-%d %H:%M:%S')}")
+                        st.write("---")
+                else:
+                    st.write("🗑️ Cache is currently empty.")
+
+        # 💬 **Chat Interaction Section (Middle Panel)**
+        with col2:
+            st.header("💬 Chat with CAG")
+            query = self.query
+            if self.query:
+                start_time = time.time()
+
+                # Step 1: Check Cache
+                st.info("⏳ Checking Cache...")
+                cached_response = llm_system.cache_manager.get_from_cache(llm_system.cache_manager.normalize_key(query))
+
+                if cached_response:
+                    # Step 2: If Cache Hit, Return
+                    st.success("✅ Cache Hit! Returning cached response.")
+                    response = cached_response
+                    st.session_state.cache_hits += 1
+                else:
+                    # Step 3: If Cache Miss, Query LLM
+                    st.warning("❌ Cache Miss. Fetching from LLM...")
+                    response = llm_system.generate_response(query, self.response)
+                    st.session_state.cache_misses += 1
+
+                # Response Time and Save Data
+                response_time = time.time() - start_time
+                st.session_state.response_times.append(response_time)
+                st.session_state.query_timestamps.append(datetime.now().strftime('%H:%M:%S'))
+                st.session_state.history.append({"query": query, "response": response, "time": response_time})
+
+                # 🎯 Chat Response
+                st.success(f"**🗨️ {response}**")
+                st.info(f"⏱️ **Response Time:** {response_time:.2f} seconds")
+
+            # 📜 **Query History Section**
+            with st.expander("🕰️ **Query History**"):
+                for entry in st.session_state.history[-10:]:
+                    st.write(f"**Query:** {entry['query']}")
+                    st.write(f"**Response:** {entry['response']}")
+                    st.write(f"⏱️ **Time Taken:** {entry['time']:.2f} seconds")
+                    st.write("---")
+
+        # 📊 **Cache Statistics Section (Right Panel)**
+        with col3:
+            st.header("📊 Cache Statistics")
+
+            # Real-Time Metrics
+            col1_stat, col2_stat, col3_stat = st.columns(3)
+            col1_stat.metric("✅ Hits", st.session_state.cache_hits)
+            col2_stat.metric("❌ Misses", st.session_state.cache_misses)
+            col3_stat.metric("📦 Cache Size", len(llm_system.cache_manager.cache))
+
+            # Cache Hit/Miss Ratio
+            total_queries = st.session_state.cache_hits + st.session_state.cache_misses
+            hit_ratio = (st.session_state.cache_hits / total_queries) * 100 if total_queries > 0 else 0
+            miss_ratio = (st.session_state.cache_misses / total_queries) * 100 if total_queries > 0 else 0
+
+            st.progress(hit_ratio / 100, text=f"✅ Cache Hit Ratio: {hit_ratio:.2f}%")
+            st.progress(miss_ratio / 100, text=f"❌ Cache Miss Ratio: {miss_ratio:.2f}%")
+
+            # 📈 **Response Time Graph**
+            if st.session_state.response_times:
+                st.markdown('<div class="fixed-graph-container">', unsafe_allow_html=True)
+                fig = px.line(
+                    x=st.session_state.query_timestamps,
+                    y=st.session_state.response_times,
+                    title="📈 Response Time Trend",
+                    labels={"x": "Timestamp", "y": "Response Time (s)"}
+                )
+                st.plotly_chart(fig, use_container_width=True)
+                st.markdown('</div>', unsafe_allow_html=True)
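One thing to note: `llm_system = LLMIntegration()` is created with default settings before the Configurator sliders are read, so the `cache_size` and `similarity_threshold` values from the UI are not passed through (and, because `CacheManager` is a singleton, a later size change would not take effect anyway). A sketch of wiring the sliders in, if that is the intent:

```python
# Sketch: read the configurator values first, then build the integration with them.
cache_size = st.slider("🗄️ Cache Size", min_value=50, max_value=500, value=100)
similarity_threshold = st.slider("📈 Similarity Threshold", min_value=0.5, max_value=1.0, value=0.8)
llm_system = LLMIntegration(cache_size=cache_size, similarity_threshold=similarity_threshold)
```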
src/streamlitui/loadui.py CHANGED
@@ -34,7 +34,15 @@ class LoadStreamlitUI:



-        st.session_state["chat_with_history"] = st.sidebar.toggle("Chat With History")
+        if st.sidebar.toggle("Chat With History"):
+            st.session_state["chat_with_history"] = False
+        else:
+            st.session_state["chat_with_history"] = True
+
+        if st.sidebar.toggle("LLM Caching"):
+            st.session_state["Cache_Seed"] = True
+        else:
+            st.session_state["Cache_Seed"] = False

         if self.user_controls['selected_usecase'] == "With LLamaIndex Tool":
             st.subheader("🏝️ Trip Advisor Specialist using wikipedia")
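The inversion is intentional: the stored flag is passed as `clear_history=` to `initiate_chat` (see basicexample.py below), so turning "Chat With History" on must store False, i.e. do not clear the history. The same logic reads more directly as a negation; a sketch with the same widgets and session-state keys:

```python
# Same behavior: the stored values feed clear_history= and the cache_seed branch in groqllm.py.
st.session_state["chat_with_history"] = not st.sidebar.toggle("Chat With History")
st.session_state["Cache_Seed"] = st.sidebar.toggle("LLM Caching")
```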
src/usecases/basicexample.py CHANGED
@@ -24,7 +24,7 @@ class BasicExample:
         asyncio.set_event_loop(self.loop)

     async def initiate_chat(self):
-        await self.user_proxy.a_initiate_chat(self.assistant, max_turns=2, message=self.problem, clear_history=st.session_state["chat_with_history"])
+        await self.user_proxy.a_initiate_chat(self.assistant, max_turns=4, message=self.problem, clear_history=st.session_state["chat_with_history"])

     def run(self):
         self.loop.run_until_complete(self.initiate_chat())
src/usecases/cag_chat.py ADDED
@@ -0,0 +1,25 @@
+import os
+from autogen import AssistantAgent, UserProxyAgent
+import streamlit as st
+
+
+class CAGLLMChat:
+    def __init__(self, llm_config, problem):
+        self.llm_config = llm_config
+        self.problem = problem
+
+    def start_chat(self):
+        llm_config = self.llm_config
+        problem = self.problem
+        assistant = AssistantAgent("assistant", llm_config=llm_config, code_execution_config=False, human_input_mode='NEVER')
+        user_proxy = UserProxyAgent("user_proxy", code_execution_config=False, human_input_mode='NEVER')
+
+        # Start the chat
+        response = user_proxy.initiate_chat(
+            assistant,
+            message=problem,
+            max_turns=1,
+            clear_history=st.session_state["chat_with_history"]
+        )
+        return response
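`initiate_chat` returns a pyautogen `ChatResult`, and app.py hands that whole object to `CAGLLM`, so the cached and displayed value is the full result rather than just the answer text. If only the final answer is wanted, it can be pulled out first; a sketch, assuming pyautogen 0.2.x's `ChatResult.summary` and running inside the Streamlit app (which supplies `chat_with_history` and the real `llm_config`):

```python
import streamlit as st
from src.usecases.cag_chat import CAGLLMChat
from src.cag.main import CAGLLM

# llm_config normally comes from GroqLLM.groq_llm_config(); read it from session state here.
llm_config = st.session_state["llm_config"]

result = CAGLLMChat(llm_config=llm_config, problem="what is python").start_chat()
answer_text = result.summary  # just the final answer text (ChatResult.summary)
CAGLLM("what is python", answer_text).process_cag_llm()
```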