# CAG Chatbot — Streamlit dashboard for Cache Augmented Generation.
import sys
import os
import time
import streamlit as st
import plotly.express as px
from datetime import datetime
from dotenv import load_dotenv
# Import the LLM Integration Model
from src.cag.generation_model import LLMIntegration
# Load environment variables and secrets
load_dotenv()
class CAGLLM:
    """Streamlit UI for a Cache Augmented Generation (CAG) chatbot.

    Wraps a single user query (plus the context payload forwarded to the LLM
    on a cache miss) and renders a three-column dashboard:
    configurator | chat | cache statistics.
    """

    def __init__(self, query, response):
        # query: the user's chat input.
        # response: payload passed through to
        # LLMIntegration.generate_response on a cache miss — presumably
        # retrieval context; verify against the caller.
        self.query = query
        self.response = response

    # ------------------------------------------------------------------
    # Internal helpers — one per dashboard concern.
    # ------------------------------------------------------------------

    def _init_session_state(self):
        """One-time initialization of per-session cache statistics."""
        if "cache_hits" not in st.session_state:
            st.session_state.cache_hits = 0
            st.session_state.cache_misses = 0
            st.session_state.response_times = []
            st.session_state.query_timestamps = []
            st.session_state.history = []

    def _inject_css(self):
        """Inject page-level CSS for widgets and the fixed-height graph."""
        st.markdown(
            """
            <style>
            body { font-family: 'Arial', sans-serif; }
            .stTextInput, .stButton { border-radius: 8px; }
            .stProgress > div > div { border-radius: 20px; }
            .custom-link { color: #1f77b4; text-decoration: none; font-weight: bold; transition: color 0.3s ease-in-out; }
            .custom-link:hover { color: #ff4b4b; }
            .fixed-graph-container { max-height: 300px !important; overflow-y: auto; }
            </style>
            """,
            unsafe_allow_html=True,
        )

    def _render_configurator(self, llm_system):
        """Left panel: cache tuning sliders, clear-cache button, cache viewer."""
        st.header("⚙️ Configurator")

        # NOTE(review): these slider values are display-only today — they are
        # never forwarded to llm_system.cache_manager. Confirm whether they
        # should configure the cache, or drop them.
        st.slider("🗄️ Cache Size", min_value=50, max_value=500, value=100)
        st.slider("🔍 Similarity Threshold", min_value=0.5, max_value=1.0, value=0.8)

        if st.button("🧹 Clear Cache"):
            # Clearing the cache also resets all session statistics so the
            # dashboard reflects a fresh start.
            llm_system.cache_manager.clear_cache()
            st.session_state.cache_hits = 0
            st.session_state.cache_misses = 0
            st.session_state.response_times = []
            st.session_state.query_timestamps = []
            st.session_state.history = []
            st.success("✅ Cache cleared successfully!")

        # Read-only dump of the cache contents for debugging/inspection.
        with st.expander("📦 **View Cache Content**"):
            if llm_system.cache_manager.cache:
                for key, value in llm_system.cache_manager.cache.items():
                    st.write(f"**Query:** {key}")
                    st.write(f"**Response:** {value['response']}")
                    st.write(f"**Timestamp:** {datetime.fromtimestamp(value['timestamp']).strftime('%Y-%m-%d %H:%M:%S')}")
                    st.write("---")
            else:
                st.write("🗑️ Cache is currently empty.")

    def _render_chat(self, llm_system):
        """Middle panel: answer ``self.query``, preferring a cached response."""
        st.header("💬 Chat with CAG")

        query = self.query
        if query:
            start_time = time.time()

            # Step 1: look the normalized query up in the cache.
            st.info("⏳ Checking Cache...")
            cached_response = llm_system.cache_manager.get_from_cache(
                llm_system.cache_manager.normalize_key(query)
            )
            if cached_response:
                # Step 2: cache hit — reuse the stored answer.
                st.success("✅ Cache Hit! Returning cached response.")
                response = cached_response
                st.session_state.cache_hits += 1
            else:
                # Step 3: cache miss — ask the LLM. NOTE(review): nothing here
                # writes the new answer back to the cache; presumably
                # generate_response does that internally — confirm, otherwise
                # misses are never cached.
                st.warning("❌ Cache Miss. Fetching from LLM...")
                response = llm_system.generate_response(query, self.response)
                st.session_state.cache_misses += 1

            # Record timing/history for the statistics panel.
            response_time = time.time() - start_time
            st.session_state.response_times.append(response_time)
            st.session_state.query_timestamps.append(datetime.now().strftime('%H:%M:%S'))
            st.session_state.history.append(
                {"query": query, "response": response, "time": response_time}
            )

            st.success(f"**🗨️ {response}**")
            st.info(f"⏱️ **Response Time:** {response_time:.2f} seconds")

        # Last 10 exchanges of this session.
        with st.expander("🕰️ **Query History**"):
            for entry in st.session_state.history[-10:]:
                st.write(f"**Query:** {entry['query']}")
                st.write(f"**Response:** {entry['response']}")
                st.write(f"⏱️ **Time Taken:** {entry['time']:.2f} seconds")
                st.write("---")

    def _render_statistics(self, llm_system):
        """Right panel: hit/miss metrics, ratios, and response-time trend."""
        st.header("📊 Cache Statistics")

        # Real-time metrics. Named *_stat to avoid shadowing the page columns.
        hit_stat, miss_stat, size_stat = st.columns(3)
        hit_stat.metric("✅ Hits", st.session_state.cache_hits)
        miss_stat.metric("❌ Misses", st.session_state.cache_misses)
        size_stat.metric("📦 Cache Size", len(llm_system.cache_manager.cache))

        # Hit/miss ratios, guarded against division by zero before any query.
        total_queries = st.session_state.cache_hits + st.session_state.cache_misses
        hit_ratio = (st.session_state.cache_hits / total_queries) * 100 if total_queries > 0 else 0
        miss_ratio = (st.session_state.cache_misses / total_queries) * 100 if total_queries > 0 else 0
        st.progress(hit_ratio / 100, text=f"✅ Cache Hit Ratio: {hit_ratio:.2f}%")
        st.progress(miss_ratio / 100, text=f"❌ Cache Miss Ratio: {miss_ratio:.2f}%")

        # Response-time trend, rendered only once at least one query has run.
        if st.session_state.response_times:
            st.markdown('<div class="fixed-graph-container">', unsafe_allow_html=True)
            fig = px.line(
                x=st.session_state.query_timestamps,
                y=st.session_state.response_times,
                title="📈 Response Time Trend",
                labels={"x": "Timestamp", "y": "Response Time (s)"},
            )
            st.plotly_chart(fig, use_container_width=True)
            st.markdown('</div>', unsafe_allow_html=True)

    # ------------------------------------------------------------------
    # Public entry point.
    # ------------------------------------------------------------------

    def process_cag_llm(self):
        """Render the full CAG dashboard and process ``self.query``.

        Initializes the LLM integration and per-session statistics, then
        draws the three dashboard panels side by side.
        """
        # LLMIntegration reads its API key from the environment (load_dotenv
        # runs at module import).
        llm_system = LLMIntegration()

        self._init_session_state()
        self._inject_css()

        st.title("💡 Cache Augmented Generation (CAG) Chatbot")
        st.write("**A chatbot with enhanced responses powered by smart caching.**")

        # Layout: Configurator | Chat | Statistics.
        config_col, chat_col, stats_col = st.columns([1.2, 2, 1.2])
        with config_col:
            self._render_configurator(llm_system)
        with chat_col:
            self._render_chat(llm_system)
        with stats_col:
            self._render_statistics(llm_system)