# Source: akwel_performance/src/langgraph_workflow.py
# Last commit: be0b057 "Update src/langgraph_workflow.py" by ArkenB (verified)
# src/langgraph_workflow.py
import base64
import html
import os
import re
import time
from pathlib import Path
from typing import TypedDict, List, Dict

import pandas as pd
from langgraph.graph import StateGraph, END

# Import your custom modules
from src.data_loader import load_data
from src.preprocessor import preprocess_data, get_feedback_distribution, get_instructor_rating_distribution, get_average_scores
from src.chart_generator import (
    plot_feedback_distribution_per_subject,
    plot_instructor_rating_distribution_per_subject,
    plot_avg_scores_per_subject,
    plot_avg_scores_per_department,
    plot_correlation_heatmap,
    plot_radar_chart_subject_department,  # Ensure this is imported
)
from src.analysis import calculate_correlations
from src.report_generator_llm import generate_analysis_text
# Define the state
class WorkflowState(TypedDict):
file_path: str
raw_df: pd.DataFrame | None
processed_df: pd.DataFrame | None
feedback_distribution: dict
instructor_rating_distribution: dict
average_scores: dict
correlations: dict
charts_output_dir: str # Directory where charts are saved
chart_filepaths: Dict[str, str] # Map logical name to filepath
analysis_text: str # Raw text analysis from LLM
report_path: str # Path to the final assembled HTML report
error_message: str | None
# Define Node Functions
def load_data_node(state: WorkflowState) -> WorkflowState:
    """Load the raw dataset named by ``state["file_path"]``.

    When the loader returns a DataFrame it is stored as ``raw_df`` and any
    earlier ``error_message`` is cleared; when it returns ``None`` only an
    error message is recorded and ``raw_df`` is left untouched.
    """
    print("---NODE: LOAD DATA---")
    loaded = load_data(state["file_path"])
    updates = (
        {"error_message": "Failed to load data."}
        if loaded is None
        else {"raw_df": loaded, "error_message": None}
    )
    return {**state, **updates}
def preprocess_data_node(state: WorkflowState) -> WorkflowState:
    """Clean the raw data and derive the summary statistics used downstream.

    Reads ``raw_df``. The key may be missing entirely when ``load_data_node``
    failed, because that node never writes it, so it is read with ``.get``.
    On success it stores ``processed_df`` plus the feedback distribution,
    instructor-rating distribution and average-score results. Otherwise only
    ``error_message`` is set.
    """
    print("---NODE: PREPROCESS DATA---")
    raw_df = state.get("raw_df")
    if raw_df is None:
        return {**state, "error_message": "Preprocessing skipped: No raw data."}
    # Work on a copy so preprocessing cannot mutate the raw_df held in state.
    processed_df = preprocess_data(raw_df.copy())
    # Dict literals evaluate left to right, so the helpers run in the original order.
    return {
        **state,
        "processed_df": processed_df,
        "feedback_distribution": get_feedback_distribution(processed_df),
        "instructor_rating_distribution": get_instructor_rating_distribution(processed_df),
        "average_scores": get_average_scores(processed_df),
    }
# NOTE: An earlier base64-returning ``generate_charts_node`` was removed from here.
# It was dead code for three reasons:
#   - the PNG-saving definition further down rebinds the same name, and that
#     version is the one create_workflow() wires into the graph;
#   - it wrote an undeclared ``charts_b64`` state key;
#   - it called the plot helpers without the output-directory/filename
#     arguments they are now given.
def analyze_data_node(state: WorkflowState) -> WorkflowState:
    """Compute correlations on the preprocessed data and store them as ``correlations``.

    ``processed_df`` is read with ``.get`` because an upstream failure may leave
    the key unset. In that case only ``error_message`` is recorded.
    """
    print("---NODE: ANALYZE DATA (CORRELATIONS)---")
    processed_df = state.get("processed_df")
    if processed_df is None:
        return {**state, "error_message": "Analysis skipped: Missing processed data."}
    return {**state, "correlations": calculate_correlations(processed_df)}
def generate_charts_node(state: WorkflowState) -> WorkflowState:
print("---NODE: GENERATE CHARTS (Save as PNG)---")
if state["processed_df"] is None or state["average_scores"] is None:
return {**state, "error_message": "Chart generation skipped: Missing processed data or average scores."}
# Create a unique directory for this run's charts
timestamp = time.strftime("%Y%m%d_%H%M%S")
charts_output_dir = os.path.join("outputs", f"charts_{timestamp}")
os.makedirs(charts_output_dir, exist_ok=True)
print(f"Charts will be saved in: {charts_output_dir}")
chart_filepaths = {}
df = state["processed_df"]
feedback_dist = state["feedback_distribution"]
instructor_rating_dist = state["instructor_rating_distribution"]
avg_scores_data = state["average_scores"]
# Define base filenames (keys will be used in template)
chart_defs = {
"avg_scores_subject": lambda: plot_avg_scores_per_subject(avg_scores_data['avg_scores_subject'], charts_output_dir, "avg_scores_subject"),
"avg_scores_dept": lambda: plot_avg_scores_per_department(avg_scores_data['avg_scores_dept'], charts_output_dir, "avg_scores_dept"),
"corr_heatmap_overall": lambda: plot_correlation_heatmap(df, charts_output_dir, "corr_heatmap_overall", "Overall"),
}
# Dynamically add subject/dept charts
for subject in df['Subject'].unique():
safe_subj = re.sub(r'\W+', '', subject.lower().replace(" ", "_")) # Sanitize filename
chart_defs[f"feedback_dist_{safe_subj}"] = lambda s=subject, fn=f"feedback_dist_{safe_subj}": plot_feedback_distribution_per_subject(feedback_dist, s, charts_output_dir, fn)
chart_defs[f"instructor_rating_dist_{safe_subj}"] = lambda s=subject, fn=f"instructor_rating_dist_{safe_subj}": plot_instructor_rating_distribution_per_subject(instructor_rating_dist, s, charts_output_dir, fn)
for dept in df['Department'].unique():
safe_dept = re.sub(r'\W+', '', dept.lower().replace(" ", "_"))
chart_defs[f"radar_chart_{safe_dept}"] = lambda d=dept, fn=f"radar_chart_{safe_dept}": plot_radar_chart_subject_department(avg_scores_data['avg_scores_subject_dept'], d, charts_output_dir, fn)
# Generate charts and collect filepaths
for key, func in chart_defs.items():
try:
filepath = func()
if filepath:
chart_filepaths[key] = filepath
except Exception as e:
print(f"Chart Gen Node: Error generating chart '{key}': {e}")
return {**state, "chart_filepaths": chart_filepaths, "charts_output_dir": charts_output_dir}
# analyze_data_node: No change needed
def generate_analysis_text_node(state: WorkflowState) -> WorkflowState:
    """Ask the LLM for the narrative analysis and store it as ``analysis_text``.

    Requires ``average_scores``, ``feedback_distribution``,
    ``instructor_rating_distribution`` and ``correlations``. A missing key and
    a ``None`` value are both treated as "unavailable"; empty containers are
    accepted.

    ``generate_analysis_text`` signals failure by returning text prefixed with
    ``"Error generating analysis:"``. That text is moved into ``error_message``
    and ``analysis_text`` becomes empty. A non-string result is treated the
    same way.
    """
    print("---NODE: GENERATE ANALYSIS TEXT (LLM)---")
    required = ("average_scores", "feedback_distribution", "instructor_rating_distribution", "correlations")
    if any(state.get(key) is None for key in required):
        return {**state, "error_message": "Analysis text generation skipped: Missing data components."}

    average_scores = state["average_scores"]
    # The average-score tables are passed through as-is; the LLM helper
    # formats them for the prompt itself.
    processed_data_for_llm = {
        "avg_scores_subject": average_scores["avg_scores_subject"],
        "avg_scores_dept": average_scores["avg_scores_dept"],
        "feedback_distribution": state["feedback_distribution"],
        "instructor_rating_distribution": state["instructor_rating_distribution"],
    }
    analysis_text = generate_analysis_text(
        processed_data=processed_data_for_llm,
        correlations=state["correlations"],
    )

    if not isinstance(analysis_text, str):
        return {**state, "analysis_text": "", "error_message": "Error generating analysis: LLM returned no text."}
    if analysis_text.startswith("Error generating analysis:"):
        return {**state, "analysis_text": "", "error_message": analysis_text}  # Propagate error
    return {**state, "analysis_text": analysis_text}
def assemble_html_report_node(state: WorkflowState) -> WorkflowState:
"""Assembles the final HTML report with base64-encoded images from LLM text and chart paths."""
print("---NODE: ASSEMBLE HTML REPORT (with Base64 Images)---")
if not state.get("analysis_text") or not state.get("chart_filepaths"):
return {**state, "error_message": "HTML assembly skipped: Missing analysis text or chart paths."}
# Determine output paths
timestamp_suffix = os.path.basename(state.get("charts_output_dir", f"run_{time.strftime('%Y%m%d_%H%M%S')}")).replace('charts_', '')
report_filename = f"evaluation_report_{timestamp_suffix}.html"
report_filepath = os.path.join("outputs", report_filename)
os.makedirs("outputs", exist_ok=True) # Ensure main 'outputs' directory exists
# --- 1. Parse LLM Text (Same as before) ---
analysis_sections = {}
current_section_key = "preamble" # Default key
current_section_title = "Introduction" # Default title
current_content = []
lines = state["analysis_text"].replace('\r\n', '\n').split('\n')
for line in lines:
match = re.match(r'^##\s+(.+)', line.strip()) # Matches "## Heading Title"
if match:
if current_content: # Save previous section
analysis_sections[current_section_key] = {
"title": current_section_title,
"content": "\n".join(current_content).strip()
}
current_section_title = match.group(1).strip()
current_section_key = current_section_title.lower().replace(" ", "_").replace("&", "and")
current_content = []
else:
if line.strip():
current_content.append(f"<p>{line.strip()}</p>") # Wrap non-heading lines in <p>
if current_content: # Save the last section
analysis_sections[current_section_key] = {
"title": current_section_title,
"content": "\n".join(current_content).strip()
}
if not analysis_sections and state["analysis_text"]: # Fallback if no "##" headings
analysis_sections["full_analysis"] = {"title": "Full Analysis", "content": f"<p>{state['analysis_text']}</p>"}
# --- 2. Load chart images and convert to base64 ---
chart_base64_data = {}
for key, full_chart_path_str in state["chart_filepaths"].items():
try:
# Read the image file and convert to base64
with open(full_chart_path_str, "rb") as img_file:
img_data = img_file.read()
# Convert binary data to base64 string and make it ready for HTML embedding
base64_encoded = base64.b64encode(img_data).decode('utf-8')
# Store the base64 data with appropriate data URI prefix
chart_base64_data[key] = f"data:image/png;base64,{base64_encoded}"
except Exception as e:
print(f"Warning: Failed to convert image to base64 for {key} ({full_chart_path_str}): {e}")
# Skip this image - it won't be included in the report
# --- 3. Build HTML Content ---
html_parts = []
# --- HTML Head and Styles ---
html_parts.append(f"""<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Training Evaluation Report - {timestamp_suffix}</title>
<style>
body {{ font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; margin: 0; padding: 0; line-height: 1.6; color: #333; background-color: #f4f7f6; }}
.report-container {{ max-width: 1000px; margin: 20px auto; background-color: #fff; padding: 25px 40px; border-radius: 8px; box-shadow: 0 2px 10px rgba(0,0,0,0.05); }}
h1 {{ text-align: center; color: #2c3e50; border-bottom: 3px solid #3498db; padding-bottom: 0.5em; margin-bottom: 1em; font-size: 2.2em; }}
h2 {{ color: #2c3e50; border-bottom: 2px solid #e0e0e0; padding-bottom: 0.4em; margin-top: 2em; margin-bottom: 1em; font-size: 1.8em; }}
h3 {{ color: #34495e; font-size: 1.4em; margin-top: 1.5em; margin-bottom: 0.8em; }}
p {{ margin-bottom: 1em; text-align: justify; }}
.text-section {{ margin-bottom: 30px; }}
.chart-card {{
border: 1px solid #ddd; border-radius: 6px; padding: 15px; margin-bottom: 20px;
background-color: #ffffff; box-shadow: 0 1px 3px rgba(0,0,0,0.04);
page-break-inside: avoid; display: flex; flex-direction: column; align-items: center;
}}
.chart-card img {{ max-width: 95%; height: auto; display: block; margin-top: 10px; border: 1px solid #eee; }}
.chart-title-text {{ font-weight: bold; margin-bottom: 10px; color: #555; font-size: 1.1em; text-align: center;}}
.chart-tags-section {{ margin-top: 40px; padding-top: 20px; border-top: 2px solid #3498db; }}
.chart-tags-container {{ display: flex; flex-wrap: wrap; gap: 20px; justify-content: center; }}
.tag-group {{ margin-bottom:20px; width: 100%;}}
.tag-group h3 {{ text-align:center; }}
.executive-summary p, .key_insights_and_recommendations p {{ font-size: 1.05em; }}
</style>
</head>
<body>
<div class="report-container">
<h1>Training Evaluation Report</h1>
""")
# --- Add LLM Analysis Sections ---
# Define the order and specific chart integration
section_order_and_config = {
"executive_summary": {"title": "Executive Summary"},
"overall_performance_analysis": {
"title": "Overall Performance Analysis",
"charts": ["avg_scores_subject", "avg_scores_dept"]
},
"subject_performance_analysis": {"title": "Subject Performance Analysis"}, # Will list subject charts under 'All Charts'
"department_performance_analysis": {"title": "Department Performance Analysis"}, # Will list dept charts under 'All Charts'
"correlation_insights": {
"title": "Correlation Insights",
"charts": ["corr_heatmap_overall"]
},
"key_insights_and_recommendations": {"title": "Key Insights and Recommendations"}
}
# Chart titles for display (more user-friendly than keys)
chart_display_titles = {
'avg_scores_subject': "Average Scores per Subject",
'avg_scores_dept': "Average Scores per Department",
'corr_heatmap_overall': "Overall Correlation: Feedback vs. Instructor Rating"
# Dynamic titles for subject/dept charts will be generated
}
for section_key, config in section_order_and_config.items():
section_data = analysis_sections.get(section_key)
if section_data:
html_parts.append(f"<div class='text-section {section_key.replace('_', '-')}'>")
html_parts.append(f"<h2>{section_data.get('title', config.get('title', 'Analysis Section'))}</h2>")
html_parts.append(section_data.get('content', "<p>No analysis provided for this section.</p>"))
# Embed specific charts for this section if configured
if "charts" in config:
html_parts.append("<div class='chart-tags-container'>") # Use flex container
for chart_key in config["charts"]:
if chart_key in chart_base64_data:
display_title = chart_display_titles.get(chart_key, chart_key.replace("_", " ").title())
html_parts.append("<div class='chart-card' style='flex: 1 1 45%;'>") # Flex basis
html_parts.append(f"<div class='chart-title-text'>{display_title}</div>")
html_parts.append(f"<img src='{chart_base64_data[chart_key]}' alt='{display_title}'>")
html_parts.append("</div>")
html_parts.append("</div>")
html_parts.append("</div>")
# --- Add "All Charts" Section with Tags/Categories ---
html_parts.append("<div class='chart-tags-section'>")
html_parts.append("<h2>All Generated Charts</h2>")
# Categorize charts
categorized_charts = {
"Overview & Averages": [],
"Subject Distributions": [],
"Department Radars": [],
"Correlations": []
}
for key, base64_data in chart_base64_data.items():
display_title = chart_display_titles.get(key)
if not display_title: # Generate dynamic titles
if key.startswith("avg_scores_"): display_title = key.replace("avg_scores_", "Average Scores ").replace("_", " ").title()
elif key.startswith("feedback_dist_"): display_title = f"Feedback Distribution: {key.replace('feedback_dist_', '').replace('_', ' ').title()}"
elif key.startswith("instructor_rating_dist_"): display_title = f"Instructor Rating: {key.replace('instructor_rating_dist_', '').replace('_', ' ').title()}"
elif key.startswith("radar_chart_"): display_title = f"Radar: {key.replace('radar_chart_', '').replace('_', ' ').title()} Dept."
elif key.startswith("corr_heatmap_"): display_title = f"Correlation: {key.replace('corr_heatmap_', '').replace('_', ' ').title()}"
else: display_title = key.replace("_", " ").title()
chart_display_titles[key] = display_title # Store for consistency
# Assign to categories
if key.startswith("avg_scores_"): categorized_charts["Overview & Averages"].append((display_title, base64_data))
elif key.startswith("feedback_dist_") or key.startswith("instructor_rating_dist_"): categorized_charts["Subject Distributions"].append((display_title, base64_data))
elif key.startswith("radar_chart_"): categorized_charts["Department Radars"].append((display_title, base64_data))
elif key.startswith("corr_heatmap_"): categorized_charts["Correlations"].append((display_title, base64_data))
else: # Fallback category if needed
if "Other Charts" not in categorized_charts: categorized_charts["Other Charts"] = []
categorized_charts["Other Charts"].append((display_title, base64_data))
for category, charts_in_category in categorized_charts.items():
if charts_in_category:
html_parts.append(f"<div class='tag-group'><h3>{category}</h3><div class='chart-tags-container'>")
for title, base64_data in charts_in_category:
html_parts.append("<div class='chart-card' style='flex: 1 1 30%; min-width: 300px;'>") # 3 cards per row approx
html_parts.append(f"<div class='chart-title-text'>{title}</div>")
html_parts.append(f"<img src='{base64_data}' alt='{title}'>")
html_parts.append("</div>")
html_parts.append("</div></div>")
html_parts.append("</div>") # End chart-tags-section
# --- HTML Footer ---
html_parts.append("""
</div> <!-- end report-container -->
</body>
</html>
""")
# --- Join and Save ---
final_html_content = "\n".join(html_parts)
try:
with open(report_filepath, "w", encoding="utf-8") as f:
f.write(final_html_content)
print(f"Assembled HTML report (with base64 images) saved to: {report_filepath}")
return {**state, "report_path": report_filepath, "error_message": None}
except Exception as e:
print(f"Error saving assembled HTML report: {e}")
return {**state, "report_path": None, "error_message": f"Failed to save HTML report: {e}"}
# --- Define Workflow Graph ---
def create_workflow():
    """Build and compile the linear report-generation LangGraph.

    The stages run strictly in sequence and then reach END:
    load_data -> preprocess_data -> generate_charts -> analyze_data
    -> generate_analysis_text -> assemble_html_report.
    No conditional routing is configured.
    """
    pipeline = [
        ("load_data", load_data_node),
        ("preprocess_data", preprocess_data_node),
        ("generate_charts", generate_charts_node),
        ("analyze_data", analyze_data_node),
        ("generate_analysis_text", generate_analysis_text_node),
        ("assemble_html_report", assemble_html_report_node),
    ]
    graph = StateGraph(WorkflowState)
    for stage_name, stage_fn in pipeline:
        graph.add_node(stage_name, stage_fn)
    graph.set_entry_point(pipeline[0][0])

    # Chain each stage to the next one; the final stage feeds END.
    stage_names = [stage_name for stage_name, _ in pipeline]
    for upstream, downstream in zip(stage_names, stage_names[1:] + [END]):
        graph.add_edge(upstream, downstream)
    # TODO: add conditional edges / error routing if needed later.
    return graph.compile()