# DataSprint / app.py
# Uploaded by sujana05 via huggingface_hub (commit b4ce589, verified)
import pandas as pd
import gradio as gr
import plotly.graph_objs as go
from forecasting.model import forecast_sales
from forecasting.anomaly import detect_anomalies
from forecasting.inventory import recommend_inventory
from llm.chat import (
get_llm_response, clear_chat_memory, get_chat_history, chat_instance,
get_knowledge_base_response, search_knowledge_base, get_vector_store_stats
)
from llm.prompts import (
FORECAST_EXPLANATION_TEMPLATE,
ANOMALY_EXPLANATION_TEMPLATE,
INVENTORY_RECOMMENDATION_TEMPLATE,
BUSINESS_INSIGHTS_TEMPLATE,
SYSTEM_PROMPT
)
from llm.retail_chain import RetailAnalysisChain, SalesComparisonChain, create_retail_workflow
from llm.vector_store import initialize_vector_store
import json
from datetime import datetime
# Custom JSON encoder to handle pandas Timestamp objects
class TimestampEncoder(json.JSONEncoder):
    """JSON encoder for pandas/datetime objects used throughout the app.

    Serializes pd.Timestamp / datetime as ISO-8601 strings, pd.Series as a
    list, pd.DataFrame as a list of record dicts, and scalar missing values
    (NaN / NaT) as null.
    """

    def default(self, obj):
        # Container types must be handled before pd.isna(): calling
        # pd.isna() on a Series/DataFrame returns an array, and using that
        # in an `if` raises "truth value is ambiguous" (ValueError) — the
        # original ordering crashed before ever reaching these branches.
        if isinstance(obj, pd.Series):
            return obj.tolist()
        if isinstance(obj, pd.DataFrame):
            return obj.to_dict('records')
        # Scalar missing values (NaN, NaT) become JSON null. NaT is also a
        # datetime subclass, so this check must precede the isoformat branch
        # (NaT.isoformat() would yield the string "NaT").
        try:
            if pd.isna(obj):
                return None
        except (TypeError, ValueError):
            pass  # non-scalar (e.g. ndarray): fall through to the base class
        if isinstance(obj, (pd.Timestamp, datetime)):
            return obj.isoformat()
        return super().default(obj)
# Load large dataset
DATA_PATH = 'data/sales_large.csv'
df = pd.read_csv(DATA_PATH)
# Ensure date column is properly parsed as datetime
df['date'] = pd.to_datetime(df['date'])
# Distinct dimension values; these drive the dropdown choices in the UI below.
stores = df['store'].unique().tolist()
products = df['product'].unique().tolist()
categories = df['category'].unique().tolist()
regions = df['region'].unique().tolist()
# Initialize vector store
# NOTE(review): all of this runs at import time, so importing the module does
# CSV and vector-store I/O — acceptable for a single-entry Gradio app.
vector_store = initialize_vector_store()
def plot_forecast(store, product):
    """Build the combined Plotly figure for one store/product pair.

    Overlays four layers: actual historical sales, the 14-day forecast with
    its shaded confidence band, any detected anomalies, and the recommended
    inventory curve.
    """
    forecast = forecast_sales(df, store, product, periods=14)  # Extended forecast period
    anomalies = detect_anomalies(df, store, product)
    inventory = recommend_inventory(forecast)
    history = df[(df['store'] == store) & (df['product'] == product)]

    fig = go.Figure()

    # Layer 1: historical sales.
    fig.add_trace(go.Scatter(
        x=history['date'],
        y=history['sales'],
        mode='lines+markers',
        name='Actual Sales',
        line=dict(color='blue', width=2),
        marker=dict(size=6),
    ))

    # Layer 2: forecast mean.
    fig.add_trace(go.Scatter(
        x=forecast['ds'],
        y=forecast['yhat'],
        mode='lines',
        name='Forecast',
        line=dict(color='green', width=2, dash='dash'),
    ))

    # Layer 3: confidence band — upper bound forward, lower bound reversed,
    # forming a closed polygon that Plotly fills.
    band_dates = forecast['ds'].tolist()
    upper = forecast['yhat_upper'].tolist()
    lower = forecast['yhat_lower'].tolist()
    fig.add_trace(go.Scatter(
        x=band_dates + band_dates[::-1],
        y=upper + lower[::-1],
        fill='toself',
        fillcolor='rgba(0,255,0,0.2)',
        line=dict(color='rgba(255,255,255,0)'),
        name='Forecast Confidence',
        showlegend=False,
    ))

    # Layer 4: anomalies, only when any were flagged.
    flagged = anomalies[anomalies['anomaly']]
    if not flagged.empty:
        fig.add_trace(go.Scatter(
            x=flagged['date'],
            y=flagged['sales'],
            mode='markers',
            name='Anomalies',
            marker=dict(color='red', size=12, symbol='x'),
        ))

    # Layer 5: recommended inventory levels.
    fig.add_trace(go.Scatter(
        x=inventory['ds'],
        y=inventory['recommended_inventory'],
        mode='lines',
        name='Recommended Inventory',
        line=dict(color='orange', width=2, dash='dot'),
    ))

    fig.update_layout(
        title=f"Sales Forecast & Inventory Analysis: {product} at {store}",
        xaxis_title='Date',
        yaxis_title='Sales / Inventory Units',
        hovermode='x unified',
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
    )
    return fig
def enhanced_chat_with_llm(message: str, store: str, product: str, use_knowledge_base: bool = False) -> str:
    """Enhanced chat function using LangChain features and vector store.

    Routes the user's message by keyword matching to one of several handlers:
    knowledge-base lookup, forecast explanation, anomaly explanation,
    inventory recommendation, store comparison, or general business insights.
    Returns the LLM's text response; any exception is caught and returned as
    an error string so the Gradio UI never crashes.
    """
    try:
        # Check if user wants knowledge base information
        knowledge_keywords = ['best practice', 'guideline', 'method', 'strategy', 'kpi', 'metric', 'formula', 'calculation']
        if use_knowledge_base or any(keyword in message.lower() for keyword in knowledge_keywords):
            return get_knowledge_base_response(message)
        # Get current data context
        hist_data = df[(df['store'] == store) & (df['product'] == product)]
        forecast_data = forecast_sales(df, store, product, periods=14)
        anomalies_data = detect_anomalies(df, store, product)
        # Create context for LLM
        context = {
            "store": store,
            "product": product,
            "historical_sales": hist_data['sales'].tolist(),
            "forecast": forecast_data['yhat'].tail(14).tolist(),
            "anomalies": anomalies_data[anomalies_data['anomaly']].to_dict('records'),
            "category": hist_data['category'].iloc[0] if not hist_data.empty else "Unknown",
            "region": hist_data['region'].iloc[0] if not hist_data.empty else "Unknown"
        }
        # Route based on message content
        # Branch 1: forecast questions — fill the forecast template and keep
        # conversation memory via the conversation chain.
        if any(keyword in message.lower() for keyword in ['forecast', 'prediction', 'trend']):
            prompt = FORECAST_EXPLANATION_TEMPLATE.format(
                system_prompt=SYSTEM_PROMPT,
                product=product,
                store=store,
                forecast_data=json.dumps(forecast_data.tail(14).to_dict('records'), indent=2, cls=TimestampEncoder),
                historical_data=json.dumps(hist_data.to_dict('records'), indent=2, cls=TimestampEncoder)
            )
            response = chat_instance.conversation_chain.predict(input=prompt)
        # Branch 2: anomaly questions — explain only the most recent flagged row.
        elif any(keyword in message.lower() for keyword in ['anomaly', 'unusual', 'spike', 'dip']):
            anom_data = anomalies_data[anomalies_data['anomaly']]
            if not anom_data.empty:
                latest_anom = anom_data.iloc[-1]
                prompt = ANOMALY_EXPLANATION_TEMPLATE.format(
                    system_prompt=SYSTEM_PROMPT,
                    product=product,
                    store=store,
                    anomaly_data=json.dumps(latest_anom.to_dict(), indent=2, cls=TimestampEncoder),
                    date=latest_anom['date']
                )
                response = chat_instance.get_response(prompt, context=json.dumps(context, cls=TimestampEncoder))
            else:
                response = "No anomalies detected in the current data for this product and store."
        # Branch 3: inventory questions.
        elif any(keyword in message.lower() for keyword in ['inventory', 'stock', 'reorder']):
            current_inventory = 50  # Placeholder - you'd get this from your inventory system
            safety_stock = 10
            prompt = INVENTORY_RECOMMENDATION_TEMPLATE.format(
                system_prompt=SYSTEM_PROMPT,
                product=product,
                store=store,
                forecast=json.dumps(forecast_data.tail(14)['yhat'].tolist(), cls=TimestampEncoder),
                current_inventory=current_inventory,
                safety_stock=safety_stock
            )
            response = chat_instance.get_response(prompt, context=json.dumps(context, cls=TimestampEncoder))
        # Branch 4: comparisons — compare against the first *other* store.
        elif any(keyword in message.lower() for keyword in ['compare', 'vs', 'versus']):
            # Handle comparison requests
            if len(stores) > 1:
                store_b = [s for s in stores if s != store][0]
                hist_b = df[(df['store'] == store_b) & (df['product'] == product)]
                comparison_chain = SalesComparisonChain(chat_instance.llm)
                result = comparison_chain.run({
                    "store_a": store,
                    "store_b": store_b,
                    "product": product,
                    "sales_data_a": hist_data['sales'].tolist(),
                    "sales_data_b": hist_b['sales'].tolist()
                })
                response = result.get("comparison_analysis", "Comparison analysis completed.")
            else:
                response = "Need at least two stores for comparison."
        # Branch 5 (default): general business insights.
        else:
            # General business insights
            # NOTE(review): iloc[-1]/iloc[0] assume hist_data is non-empty; an
            # empty selection raises IndexError, which the outer except turns
            # into an error string — confirm that's the intended behavior.
            data_summary = {
                "total_sales": hist_data['sales'].sum(),
                "avg_sales": hist_data['sales'].mean(),
                "sales_trend": "increasing" if hist_data['sales'].iloc[-1] > hist_data['sales'].iloc[0] else "decreasing",
                "forecast_next_week": forecast_data['yhat'].iloc[-1],
                "category": context.get("category", "Unknown"),
                "region": context.get("region", "Unknown")
            }
            prompt = BUSINESS_INSIGHTS_TEMPLATE.format(
                system_prompt=SYSTEM_PROMPT,
                data_summary=json.dumps(data_summary, indent=2, cls=TimestampEncoder),
                user_question=message
            )
            # Use conversation chain to maintain memory
            response = chat_instance.conversation_chain.predict(input=prompt)
        return response
    except Exception as e:
        return f"I encountered an error while processing your request: {str(e)}"
def gradio_interface(store, product, message, use_kb=False, clear_memory=False):
    """Top-level Gradio callback: refresh the chart and answer the chat.

    Returns a (figure, response_text) pair for the two output widgets. When
    clear_memory is set, the chat memory is wiped and a confirmation message
    is returned instead of an LLM response.
    """
    fig = plot_forecast(store, product)
    if clear_memory:
        clear_chat_memory()
        return fig, "Chat memory cleared! Ask me anything about the data."
    if not message:
        greeting = ("Hello! I'm your retail analytics assistant. Ask me about "
                    "forecasts, anomalies, inventory, or any business insights!")
        return fig, greeting
    return fig, enhanced_chat_with_llm(message, store, product, use_kb)
def get_chat_history_display():
    """Render the stored chat messages as plain 'role: content' lines.

    Handles both LangChain-style messages (``.type``/``.content``) and
    OpenAI-style messages (``.role``/``.content``); anything else is shown
    via str(). Returns a placeholder string when no history exists.
    """
    messages = get_chat_history()
    if not messages:
        return "No chat history yet."
    rendered = []
    for entry in messages:
        has_content = hasattr(entry, 'content')
        if has_content and hasattr(entry, 'type'):
            rendered.append(f"{entry.type}: {entry.content}")
        elif has_content and hasattr(entry, 'role'):
            rendered.append(f"{entry.role}: {entry.content}")
        else:
            rendered.append(str(entry))
    return "\n".join(rendered)
def get_data_summary():
    """Summarize the loaded sales dataset as a pretty-printed JSON string."""
    sales = df['sales']
    dates = df['date']
    stats = {"Total Records": len(df)}
    # Dimension cardinalities come from the module-level lookup lists.
    for label, values in (("Stores", stores), ("Products", products),
                          ("Categories", categories), ("Regions", regions)):
        stats[label] = len(values)
    stats["Date Range"] = f"{dates.min()} to {dates.max()}"
    stats["Total Sales"] = sales.sum()
    stats["Average Sales"] = sales.mean()
    # TimestampEncoder handles the pandas scalar types in the values.
    return json.dumps(stats, indent=2, cls=TimestampEncoder)
def get_vector_store_info():
    """Expose vector-store statistics as a pretty-printed JSON string."""
    return json.dumps(get_vector_store_stats(), indent=2, cls=TimestampEncoder)
# Create enhanced Gradio interface.
# Layout: a narrow control column (selectors, chat box, utility buttons) next
# to a wider column holding the Plotly chart and the assistant's response.
with gr.Blocks(title="Retail Demand Forecasting Dashboard", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# πŸ›οΈ Retail Demand Forecasting Dashboard")
    gr.Markdown("### AI-Powered Sales Analytics with LangChain & Vector Store")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### πŸ“Š Data Selection")
            store_input = gr.Dropdown(
                choices=stores,
                label="πŸͺ Store",
                value=stores[0],
                info="Select a store to analyze"
            )
            product_input = gr.Dropdown(
                choices=products,
                label="πŸ“¦ Product",
                value=products[0],
                info="Select a product to analyze"
            )
            gr.Markdown("### πŸ’¬ AI Assistant")
            message_input = gr.Textbox(
                label="Ask me anything about the data...",
                placeholder="e.g., 'Explain the forecast', 'What anomalies do you see?', 'Best practices for inventory management'",
                lines=3
            )
            use_kb_checkbox = gr.Checkbox(
                label="πŸ” Use Knowledge Base",
                value=False,
                info="Check to search retail knowledge base for best practices and guidelines"
            )
            with gr.Row():
                submit_btn = gr.Button("πŸš€ Analyze", variant="primary")
                clear_btn = gr.Button("πŸ—‘οΈ Clear Memory", variant="secondary")
            gr.Markdown("### πŸ“ Chat History")
            history_btn = gr.Button("πŸ“‹ Show History")
            history_output = gr.Textbox(label="Conversation History", lines=5, interactive=False)
            gr.Markdown("### πŸ“ˆ System Info")
            with gr.Row():
                data_info_btn = gr.Button("πŸ“Š Data Summary")
                vector_info_btn = gr.Button("πŸ” Vector Store Info")
            info_output = gr.Textbox(label="System Information", lines=8, interactive=False)
        with gr.Column(scale=2):
            gr.Markdown("### πŸ“ˆ Analytics Dashboard")
            chart_output = gr.Plot(label="Forecast & Inventory Analysis")
            chat_output = gr.Textbox(
                label="πŸ€– AI Response",
                lines=10,
                interactive=False
            )
    # Event handlers
    submit_btn.click(
        fn=gradio_interface,
        inputs=[store_input, product_input, message_input, use_kb_checkbox],
        outputs=[chart_output, chat_output]
    )
    # BUG FIX: the previous handler was a zero-argument lambda that read
    # `store_input.value` / `product_input.value`. In Gradio, `.value` is a
    # component's *initial* (build-time) value, not the user's live selection,
    # so clearing memory always re-plotted the default store/product.
    # Declaring the dropdowns as `inputs` passes the current selections.
    clear_btn.click(
        fn=lambda s, p: gradio_interface(s, p, "", False, True),
        inputs=[store_input, product_input],
        outputs=[chart_output, chat_output]
    )
    history_btn.click(
        fn=get_chat_history_display,
        outputs=[history_output]
    )
    data_info_btn.click(
        fn=get_data_summary,
        outputs=[info_output]
    )
    vector_info_btn.click(
        fn=get_vector_store_info,
        outputs=[info_output]
    )
    # Auto-update chart when store/product changes
    store_input.change(
        fn=lambda s, p: (plot_forecast(s, p), "Select a store and product, then ask me anything!"),
        inputs=[store_input, product_input],
        outputs=[chart_output, chat_output]
    )
    product_input.change(
        fn=lambda s, p: (plot_forecast(s, p), "Select a store and product, then ask me anything!"),
        inputs=[store_input, product_input],
        outputs=[chart_output, chat_output]
    )
if __name__ == "__main__":
    # share=True exposes a public Gradio link; debug=True surfaces tracebacks.
    demo.launch(share=True, debug=True)