Spaces:
Running
Running
| import json | |
| from io import BytesIO | |
| import plotly.graph_objects as go | |
| from reportlab.lib import colors | |
| from reportlab.lib.pagesizes import letter | |
| from reportlab.lib.styles import getSampleStyleSheet | |
| from reportlab.platypus import Image, Paragraph, Preformatted, SimpleDocTemplate, Spacer, Table, TableStyle | |
| def _figure_to_image(figure, width=640, height=300): | |
| """Convert a Plotly figure to a reportlab Image. Returns None if conversion fails.""" | |
| try: | |
| image_bytes = figure.to_image(format="png", width=width, height=height, scale=2) | |
| return Image(BytesIO(image_bytes), width=width * 0.75, height=height * 0.75) | |
| except Exception: | |
| return None | |
def _metrics_table(rows):
    """Wrap ``rows`` in a compact, left-aligned reportlab Table.

    The first row is styled as a header: light grey fill with dark bold text.
    All cells are left-aligned, vertically centred, set in 9pt and outlined by
    a thin light-grey grid.

    Args:
        rows: List of table rows; ``rows[0]`` is the header row.

    Returns:
        The styled ``Table``.
    """
    header = ((0, 0), (-1, 0))
    everything = ((0, 0), (-1, -1))
    # Command order is kept stable: later commands win where cell ranges overlap.
    style_commands = [
        ("BACKGROUND", *header, colors.HexColor("#f3f4f6")),
        ("TEXTCOLOR", *header, colors.HexColor("#111827")),
        ("ALIGN", *everything, "LEFT"),
        ("FONTNAME", *header, "Helvetica-Bold"),
        ("FONTSIZE", *everything, 9),
        ("GRID", *everything, 0.5, colors.HexColor("#d1d5db")),
        ("VALIGN", *everything, "MIDDLE"),
    ]
    styled = Table(rows, hAlign="LEFT")
    styled.setStyle(TableStyle(style_commands))
    return styled
def _single_model_section(story, model_name, model_payload, styles):
    """Append one model's single-text prediction report to ``story``.

    Appends, in order:

    - a heading with ``model_name`` and the coloured verdict;
    - a metrics table (confidence, both class probabilities, processing time);
    - a probability bar chart, when the figure can be rendered;
    - a token-importance table, when ``token_importance`` is present;
    - a JSON dump of ``technical_details``, when present.

    Args:
        story: List of reportlab flowables; mutated in place.
        model_name: Heading text. It is interpolated into Paragraph markup, so
            it must not contain unescaped markup characters.
        model_payload: Dict with these keys:
            ``prediction`` (1 is reported as hate speech);
            ``confidence``;
            ``probabilities``, indexed as [not hate, hate];
            ``processing_time`` in seconds;
            optional ``token_importance``, rows with ``Token``/``Importance``
            keys;
            optional ``technical_details``.
        styles: reportlab stylesheet from ``getSampleStyleSheet()``.
    """
    is_hate = model_payload["prediction"] == 1
    prediction = "HATE SPEECH DETECTED" if is_hate else "NOT HATE SPEECH"
    prediction_color = "#c62828" if is_hate else "#2e7d32"
    story.append(Paragraph(f"<b>{model_name}</b>", styles["Heading3"]))
    story.append(Paragraph(f"<font color='{prediction_color}'><b>{prediction}</b></font>", styles["Normal"]))
    story.append(Spacer(1, 6))

    probabilities = model_payload["probabilities"]
    rows = [
        ["Metric", "Value"],
        ["Confidence", f"{model_payload['confidence']:.1%}"],
        ["Not Hate Speech", f"{probabilities[0]:.1%}"],
        ["Hate Speech", f"{probabilities[1]:.1%}"],
        ["Processing Time", f"{model_payload['processing_time']:.3f}s"],
    ]
    story.append(_metrics_table(rows))
    story.append(Spacer(1, 8))

    probability_chart = go.Figure(
        data=[
            go.Bar(
                x=["Not Hate Speech", "Hate Speech"],
                y=probabilities,
                marker_color=["#66bb6a", "#ef5350"],
                text=[f"{value:.1%}" for value in probabilities],
                textposition="auto",
            )
        ]
    )
    probability_chart.update_layout(yaxis_range=[0, 1], height=300, showlegend=False, margin=dict(l=20, r=20, t=20, b=20))
    chart_image = _figure_to_image(probability_chart)
    # The helper returns None when rendering fails; the chart is optional.
    if chart_image is not None:
        story.append(Paragraph("Probability Distribution", styles["Italic"]))
        story.append(chart_image)
        story.append(Spacer(1, 8))

    token_rows = model_payload.get("token_importance") or []
    if token_rows:
        story.append(Paragraph("Top Important Tokens", styles["Italic"]))
        token_table_rows = [["Token", "Importance"]] + [
            [token_data["Token"], f"{token_data['Importance']:.4f}"] for token_data in token_rows
        ]
        story.append(_metrics_table(token_table_rows))
        story.append(Spacer(1, 8))

    technical_details = model_payload.get("technical_details")
    if technical_details:
        story.append(Paragraph("Technical Details", styles["Italic"]))
        # The dump is for human reading only. The details may contain values
        # json cannot encode (e.g. numpy scalars; TODO confirm against the
        # producer), so stringify them rather than abort the whole export.
        details_text = json.dumps(technical_details, indent=2, default=str)
        story.append(Preformatted(details_text, styles["Code"]))

    story.append(Spacer(1, 12))
def _batch_model_section(story, model_name, model_payload, styles, colorscale):
    """Append one model's batch-evaluation report to ``story``.

    Appends a heading, a table of quality and resource metrics, and, when
    ``confusion_matrix`` is present and the figure renders, a confusion-matrix
    heatmap. The heatmap rows are true labels and its columns are predicted
    labels, both ordered [not hate, hate].

    Args:
        story: List of reportlab flowables; mutated in place.
        model_name: Heading text, interpolated into Paragraph markup.
        model_payload: Dict with the metric keys listed in ``metric_specs``
            and an optional ``confusion_matrix`` (2x2 given the axis labels;
            TODO confirm).
        styles: reportlab stylesheet from ``getSampleStyleSheet()``.
        colorscale: Plotly colorscale name for the heatmap.
    """
    # (row label, payload key, value template), in display order.
    metric_specs = (
        ("F1 Score", "f1_score", "{:.4f}"),
        ("Precision", "precision", "{:.4f}"),
        ("Accuracy", "accuracy", "{:.4f}"),
        ("Recall", "recall", "{:.4f}"),
        ("Avg CPU", "cpu_usage", "{:.2f}%"),
        ("Peak CPU", "peak_cpu_usage", "{:.2f}%"),
        ("Avg Memory", "memory_usage", "{:.2f} MB"),
        ("Peak Memory", "peak_memory_usage", "{:.2f} MB"),
        ("Total Runtime", "runtime", "{:.2f}s"),
        ("Avg Time/Sample", "avg_time_per_sample", "{:.3f}s"),
    )

    story.append(Paragraph(f"<b>{model_name}</b>", styles["Heading3"]))
    story.append(Spacer(1, 4))

    table_rows = [["Metric", "Value"]]
    table_rows.extend([label, template.format(model_payload[key])] for label, key, template in metric_specs)
    story.append(_metrics_table(table_rows))
    story.append(Spacer(1, 8))

    matrix = model_payload.get("confusion_matrix")
    if matrix is None:
        return

    heatmap = go.Heatmap(
        z=matrix,
        x=["Pred Not Hate", "Pred Hate"],
        y=["True Not Hate", "True Hate"],
        colorscale=colorscale,
        text=matrix,
        texttemplate="%{text}",
        textfont={"size": 14},
        showscale=False,
    )
    figure = go.Figure(data=heatmap)
    figure.update_layout(height=300, margin={"l": 20, "r": 20, "t": 20, "b": 20})

    image = _figure_to_image(figure)
    if image:
        story.append(Paragraph("Confusion Matrix", styles["Italic"]))
        story.append(image)
        story.append(Spacer(1, 12))
def _escape_paragraph_text(value):
    """Return ``value`` as text that is safe to embed in reportlab Paragraph markup.

    Paragraph text is parsed as an XML-like mini-language. Raw ``&``, ``<``
    or ``>`` in user-supplied strings would break the parse or silently drop
    content, so they are escaped. Newlines become ``<br/>`` because Paragraph
    otherwise collapses them into spaces.
    """
    from xml.sax.saxutils import escape

    return escape(str(value)).replace("\n", "<br/>")


def generate_results_pdf(payload):
    """Generate a PDF from the current Streamlit results payload.

    Args:
        payload: Dict describing the results. ``payload["mode"]`` (default
            ``"single"``) selects the layout:

            - ``"single"`` renders ``input_text``, optional ``rationale`` and
              per-model sections;
            - ``"batch"`` renders ``filename``, ``rows`` and per-model
              sections;
            - any other mode renders a placeholder paragraph.

            ``payload["models"]`` may hold ``"base"`` and/or ``"enhanced"``
            entries.

    Returns:
        The PDF document as ``bytes``.
    """
    output_buffer = BytesIO()
    document = SimpleDocTemplate(
        output_buffer,
        pagesize=letter,
        leftMargin=36,
        rightMargin=36,
        topMargin=36,
        bottomMargin=36,
        title="Hate Speech Detection Results",
    )
    styles = getSampleStyleSheet()
    story = [
        Paragraph("Hate Speech Detection Results", styles["Title"]),
        Spacer(1, 8),
    ]

    # (payload key, section heading, heatmap colorscale used in batch mode),
    # in the order the sections appear in the PDF.
    model_sections = (
        ("base", "Base Bert Ensemble Results", "Blues"),
        ("enhanced", "Enhanced Bert Ensemble Results", "Greens"),
    )
    # Tolerate an explicit None as well as a missing key.
    models = payload.get("models") or {}

    mode = payload.get("mode", "single")
    if mode == "single":
        story.append(Paragraph("Single Text Analysis", styles["Heading2"]))
        story.append(Spacer(1, 6))
        input_text = _escape_paragraph_text(payload.get("input_text", ""))
        story.append(Paragraph(f"<b>Input Text:</b> {input_text}", styles["Normal"]))
        rationale_text = payload.get("rationale")
        if rationale_text:
            story.append(Spacer(1, 4))
            rationale_markup = _escape_paragraph_text(rationale_text)
            story.append(Paragraph(f"<b>Rationale:</b> {rationale_markup}", styles["Normal"]))
        story.append(Spacer(1, 10))
        for key, heading, _colorscale in model_sections:
            if key in models:
                _single_model_section(story, heading, models[key], styles)
    elif mode == "batch":
        story.append(Paragraph("Batch File Analysis", styles["Heading2"]))
        story.append(Spacer(1, 6))
        filename = _escape_paragraph_text(payload.get("filename", "Unknown"))
        row_count = _escape_paragraph_text(payload.get("rows", 0))
        story.append(Paragraph(f"<b>Filename:</b> {filename}", styles["Normal"]))
        story.append(Paragraph(f"<b>Rows:</b> {row_count}", styles["Normal"]))
        story.append(Spacer(1, 10))
        for key, heading, colorscale in model_sections:
            if key in models:
                _batch_model_section(story, heading, models[key], styles, colorscale)
    else:
        story.append(Paragraph("No exportable results found.", styles["Normal"]))

    document.build(story)
    return output_buffer.getvalue()