# NOTE: Hugging Face Spaces page residue ("Spaces: Sleeping") left by a scrape — not part of the app code.
| import pandas as pd | |
| import gradio as gr | |
| import plotly.graph_objs as go | |
| from forecasting.model import forecast_sales | |
| from forecasting.anomaly import detect_anomalies | |
| from forecasting.inventory import recommend_inventory | |
| from llm.chat import ( | |
| get_llm_response, clear_chat_memory, get_chat_history, chat_instance, | |
| get_knowledge_base_response, search_knowledge_base, get_vector_store_stats | |
| ) | |
| from llm.prompts import ( | |
| FORECAST_EXPLANATION_TEMPLATE, | |
| ANOMALY_EXPLANATION_TEMPLATE, | |
| INVENTORY_RECOMMENDATION_TEMPLATE, | |
| BUSINESS_INSIGHTS_TEMPLATE, | |
| SYSTEM_PROMPT | |
| ) | |
| from llm.retail_chain import RetailAnalysisChain, SalesComparisonChain, create_retail_workflow | |
| from llm.vector_store import initialize_vector_store | |
| import json | |
| from datetime import datetime | |
# Custom JSON encoder to handle pandas Timestamp objects
class TimestampEncoder(json.JSONEncoder):
    """JSON encoder for pandas/datetime objects.

    Serializes pd.Timestamp and datetime to ISO-8601 strings, pd.Series
    to a list, pd.DataFrame to a list of records, and scalar missing
    values (NaN/NaT) to JSON null.  Anything else falls through to the
    base encoder (which raises TypeError).
    """

    def default(self, obj):
        # Containers must be checked BEFORE pd.isna(): calling pd.isna()
        # on a Series/DataFrame returns an array whose truth value is
        # ambiguous, which made the original branches unreachable.
        if isinstance(obj, pd.Series):
            return obj.tolist()
        if isinstance(obj, pd.DataFrame):
            return obj.to_dict('records')
        # Scalar missing values (NaN, NaT) -> JSON null.  Guard against
        # array-like inputs where pd.isna() returns an array.
        try:
            if pd.isna(obj):
                return None
        except (TypeError, ValueError):
            pass
        if isinstance(obj, (pd.Timestamp, datetime)):
            return obj.isoformat()
        return super().default(obj)
# Load large dataset once at import time; every function below filters
# this shared frame instead of re-reading the CSV.
DATA_PATH = 'data/sales_large.csv'
df = pd.read_csv(DATA_PATH)
# Ensure date column is properly parsed as datetime
df['date'] = pd.to_datetime(df['date'])
# Distinct dimension values used to populate the dashboard dropdowns
# and summary statistics.
stores = df['store'].unique().tolist()
products = df['product'].unique().tolist()
categories = df['category'].unique().tolist()
regions = df['region'].unique().tolist()
# Initialize vector store (retail knowledge base used by the chat path)
vector_store = initialize_vector_store()
def plot_forecast(store, product):
    """Create enhanced Plotly chart with forecast, anomalies, and inventory."""
    forecast = forecast_sales(df, store, product, periods=14)  # Extended forecast period
    anomalies = detect_anomalies(df, store, product)
    inventory = recommend_inventory(forecast)

    fig = go.Figure()

    # Historical sales for the selected store/product
    selection = (df['store'] == store) & (df['product'] == product)
    history = df[selection]
    fig.add_trace(go.Scatter(
        x=history['date'], y=history['sales'],
        mode='lines+markers', name='Actual Sales',
        line=dict(color='blue', width=2), marker=dict(size=6),
    ))

    # Point forecast
    fig.add_trace(go.Scatter(
        x=forecast['ds'], y=forecast['yhat'],
        mode='lines', name='Forecast',
        line=dict(color='green', width=2, dash='dash'),
    ))

    # Confidence band: upper bound forward, lower bound reversed,
    # closed into a filled polygon.
    band_dates = forecast['ds'].tolist()
    upper = forecast['yhat_upper'].tolist()
    lower = forecast['yhat_lower'].tolist()
    fig.add_trace(go.Scatter(
        x=band_dates + band_dates[::-1],
        y=upper + lower[::-1],
        fill='toself', fillcolor='rgba(0,255,0,0.2)',
        line=dict(color='rgba(255,255,255,0)'),
        name='Forecast Confidence', showlegend=False,
    ))

    # Anomalous days, marked with red crosses
    flagged = anomalies[anomalies['anomaly']]
    if not flagged.empty:
        fig.add_trace(go.Scatter(
            x=flagged['date'], y=flagged['sales'],
            mode='markers', name='Anomalies',
            marker=dict(color='red', size=12, symbol='x'),
        ))

    # Recommended inventory overlay
    fig.add_trace(go.Scatter(
        x=inventory['ds'], y=inventory['recommended_inventory'],
        mode='lines', name='Recommended Inventory',
        line=dict(color='orange', width=2, dash='dot'),
    ))

    fig.update_layout(
        title=f"Sales Forecast & Inventory Analysis: {product} at {store}",
        xaxis_title='Date',
        yaxis_title='Sales / Inventory Units',
        hovermode='x unified',
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
    )
    return fig
def enhanced_chat_with_llm(message, store, product, use_knowledge_base=False):
    """Enhanced chat function using LangChain features and vector store.

    Routes the user's message by keyword to the matching prompt template
    (forecast, anomaly, inventory, store comparison, or general business
    insights), feeds it the relevant data slices for the selected
    store/product, and returns the LLM's text response.

    Args:
        message: Free-form user question.
        store: Currently selected store.
        product: Currently selected product.
        use_knowledge_base: Force a knowledge-base lookup regardless of
            keyword matching.

    Returns:
        The response string.  Exceptions are converted into a friendly
        error message rather than raised, so the Gradio UI never breaks.
    """
    try:
        lowered = message.lower()  # hoisted: original recomputed per branch
        # Knowledge-base questions bypass the data context entirely.
        knowledge_keywords = ['best practice', 'guideline', 'method', 'strategy', 'kpi', 'metric', 'formula', 'calculation']
        if use_knowledge_base or any(keyword in lowered for keyword in knowledge_keywords):
            return get_knowledge_base_response(message)

        # Current data slices for the selected store/product.
        hist_data = df[(df['store'] == store) & (df['product'] == product)]
        forecast_data = forecast_sales(df, store, product, periods=14)
        anomalies_data = detect_anomalies(df, store, product)

        # Shared JSON-serializable context handed to context-aware prompts.
        context = {
            "store": store,
            "product": product,
            "historical_sales": hist_data['sales'].tolist(),
            "forecast": forecast_data['yhat'].tail(14).tolist(),
            "anomalies": anomalies_data[anomalies_data['anomaly']].to_dict('records'),
            "category": hist_data['category'].iloc[0] if not hist_data.empty else "Unknown",
            "region": hist_data['region'].iloc[0] if not hist_data.empty else "Unknown"
        }

        # Route based on message content.
        if any(keyword in lowered for keyword in ['forecast', 'prediction', 'trend']):
            prompt = FORECAST_EXPLANATION_TEMPLATE.format(
                system_prompt=SYSTEM_PROMPT,
                product=product,
                store=store,
                forecast_data=json.dumps(forecast_data.tail(14).to_dict('records'), indent=2, cls=TimestampEncoder),
                historical_data=json.dumps(hist_data.to_dict('records'), indent=2, cls=TimestampEncoder)
            )
            response = chat_instance.conversation_chain.predict(input=prompt)
        elif any(keyword in lowered for keyword in ['anomaly', 'unusual', 'spike', 'dip']):
            anom_data = anomalies_data[anomalies_data['anomaly']]
            if not anom_data.empty:
                # Explain only the most recent anomaly.
                latest_anom = anom_data.iloc[-1]
                prompt = ANOMALY_EXPLANATION_TEMPLATE.format(
                    system_prompt=SYSTEM_PROMPT,
                    product=product,
                    store=store,
                    anomaly_data=json.dumps(latest_anom.to_dict(), indent=2, cls=TimestampEncoder),
                    date=latest_anom['date']
                )
                response = chat_instance.get_response(prompt, context=json.dumps(context, cls=TimestampEncoder))
            else:
                response = "No anomalies detected in the current data for this product and store."
        elif any(keyword in lowered for keyword in ['inventory', 'stock', 'reorder']):
            current_inventory = 50  # Placeholder - you'd get this from your inventory system
            safety_stock = 10
            prompt = INVENTORY_RECOMMENDATION_TEMPLATE.format(
                system_prompt=SYSTEM_PROMPT,
                product=product,
                store=store,
                forecast=json.dumps(forecast_data.tail(14)['yhat'].tolist(), cls=TimestampEncoder),
                current_inventory=current_inventory,
                safety_stock=safety_stock
            )
            response = chat_instance.get_response(prompt, context=json.dumps(context, cls=TimestampEncoder))
        elif any(keyword in lowered for keyword in ['compare', 'vs', 'versus']):
            # Compare the selected store against the first other store.
            if len(stores) > 1:
                store_b = [s for s in stores if s != store][0]
                hist_b = df[(df['store'] == store_b) & (df['product'] == product)]
                comparison_chain = SalesComparisonChain(chat_instance.llm)
                result = comparison_chain.run({
                    "store_a": store,
                    "store_b": store_b,
                    "product": product,
                    "sales_data_a": hist_data['sales'].tolist(),
                    "sales_data_b": hist_b['sales'].tolist()
                })
                response = result.get("comparison_analysis", "Comparison analysis completed.")
            else:
                response = "Need at least two stores for comparison."
        else:
            # General business insights.
            # BUG FIX: the original indexed iloc[-1]/iloc[0] unconditionally
            # here, raising IndexError for a store/product pair with no rows
            # (every other use of hist_data is guarded with .empty).
            if hist_data.empty:
                sales_trend = "unknown"
            elif hist_data['sales'].iloc[-1] > hist_data['sales'].iloc[0]:
                sales_trend = "increasing"
            else:
                sales_trend = "decreasing"
            data_summary = {
                "total_sales": hist_data['sales'].sum(),
                "avg_sales": hist_data['sales'].mean(),
                "sales_trend": sales_trend,
                "forecast_next_week": forecast_data['yhat'].iloc[-1],
                "category": context.get("category", "Unknown"),
                "region": context.get("region", "Unknown")
            }
            prompt = BUSINESS_INSIGHTS_TEMPLATE.format(
                system_prompt=SYSTEM_PROMPT,
                data_summary=json.dumps(data_summary, indent=2, cls=TimestampEncoder),
                user_question=message
            )
            # Use conversation chain to maintain memory across turns.
            response = chat_instance.conversation_chain.predict(input=prompt)
        return response
    except Exception as e:
        # Surface a readable message instead of breaking the UI.
        return f"I encountered an error while processing your request: {str(e)}"
def gradio_interface(store, product, message, use_kb=False, clear_memory=False):
    """Enhanced Gradio interface with memory management and knowledge base."""
    # Memory reset short-circuits the chat entirely.
    if clear_memory:
        clear_chat_memory()
        return plot_forecast(store, product), "Chat memory cleared! Ask me anything about the data."

    fig = plot_forecast(store, product)
    if message:
        reply = enhanced_chat_with_llm(message, store, product, use_kb)
    else:
        # No question yet: greet the user instead of calling the LLM.
        reply = "Hello! I'm your retail analytics assistant. Ask me about forecasts, anomalies, inventory, or any business insights!"
    return fig, reply
def get_chat_history_display():
    """Display chat history in a readable format."""
    history = get_chat_history()
    if not history:
        return "No chat history yet."

    lines = []
    for entry in history:
        # LangChain messages expose either .type or .role depending on version.
        if hasattr(entry, 'type') and hasattr(entry, 'content'):
            lines.append(f"{entry.type}: {entry.content}")
        elif hasattr(entry, 'role') and hasattr(entry, 'content'):
            lines.append(f"{entry.role}: {entry.content}")
        else:
            lines.append(str(entry))
    return "\n".join(lines)
def get_data_summary():
    """Get summary statistics of the dataset."""
    date_range = f"{df['date'].min()} to {df['date'].max()}"
    stats = {
        "Total Records": len(df),
        "Stores": len(stores),
        "Products": len(products),
        "Categories": len(categories),
        "Regions": len(regions),
        "Date Range": date_range,
        "Total Sales": df['sales'].sum(),
        "Average Sales": df['sales'].mean(),
    }
    # TimestampEncoder handles the pandas scalar types in the values.
    return json.dumps(stats, indent=2, cls=TimestampEncoder)
def get_vector_store_info():
    """Get vector store information."""
    # Pretty-print the stats dict for the System Information textbox.
    return json.dumps(get_vector_store_stats(), indent=2, cls=TimestampEncoder)
# Create enhanced Gradio interface: left column holds selectors/chat
# controls, right column holds the chart and the AI response.
with gr.Blocks(title="Retail Demand Forecasting Dashboard", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# ποΈ Retail Demand Forecasting Dashboard")
    gr.Markdown("### AI-Powered Sales Analytics with LangChain & Vector Store")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### π Data Selection")
            store_input = gr.Dropdown(
                choices=stores,
                label="πͺ Store",
                value=stores[0],
                info="Select a store to analyze"
            )
            product_input = gr.Dropdown(
                choices=products,
                label="π¦ Product",
                value=products[0],
                info="Select a product to analyze"
            )
            gr.Markdown("### π¬ AI Assistant")
            message_input = gr.Textbox(
                label="Ask me anything about the data...",
                placeholder="e.g., 'Explain the forecast', 'What anomalies do you see?', 'Best practices for inventory management'",
                lines=3
            )
            use_kb_checkbox = gr.Checkbox(
                label="π Use Knowledge Base",
                value=False,
                info="Check to search retail knowledge base for best practices and guidelines"
            )
            with gr.Row():
                submit_btn = gr.Button("π Analyze", variant="primary")
                clear_btn = gr.Button("ποΈ Clear Memory", variant="secondary")
            gr.Markdown("### π Chat History")
            history_btn = gr.Button("π Show History")
            history_output = gr.Textbox(label="Conversation History", lines=5, interactive=False)
            gr.Markdown("### π System Info")
            with gr.Row():
                data_info_btn = gr.Button("π Data Summary")
                vector_info_btn = gr.Button("π Vector Store Info")
            info_output = gr.Textbox(label="System Information", lines=8, interactive=False)
        with gr.Column(scale=2):
            gr.Markdown("### π Analytics Dashboard")
            chart_output = gr.Plot(label="Forecast & Inventory Analysis")
            chat_output = gr.Textbox(
                label="π€ AI Response",
                lines=10,
                interactive=False
            )

    # Event handlers
    submit_btn.click(
        fn=gradio_interface,
        inputs=[store_input, product_input, message_input, use_kb_checkbox],
        outputs=[chart_output, chat_output]
    )
    # BUG FIX: the original lambda read store_input.value / product_input.value,
    # which hold the components' *initial* values, not the user's current
    # selection at click time.  Current values must be delivered via `inputs=`.
    clear_btn.click(
        fn=lambda s, p: gradio_interface(s, p, "", False, True),
        inputs=[store_input, product_input],
        outputs=[chart_output, chat_output]
    )
    history_btn.click(
        fn=get_chat_history_display,
        outputs=[history_output]
    )
    data_info_btn.click(
        fn=get_data_summary,
        outputs=[info_output]
    )
    vector_info_btn.click(
        fn=get_vector_store_info,
        outputs=[info_output]
    )
    # Auto-update chart when store/product changes
    store_input.change(
        fn=lambda s, p: (plot_forecast(s, p), "Select a store and product, then ask me anything!"),
        inputs=[store_input, product_input],
        outputs=[chart_output, chat_output]
    )
    product_input.change(
        fn=lambda s, p: (plot_forecast(s, p), "Select a store and product, then ask me anything!"),
        inputs=[store_input, product_input],
        outputs=[chart_output, chat_output]
    )
if __name__ == "__main__":
    # share=True requests a public Gradio share link; debug=True enables
    # Gradio's debug mode for local development.
    demo.launch(share=True, debug=True)