""" Tests for DealFlow AI custom tools. Runs without LLM/CrewAI — tests tool logic in isolation. """ from __future__ import annotations import json import os import sys import tempfile from pathlib import Path import pytest sys.path.insert(0, str(Path(__file__).parent.parent)) # ─── PDFExtractorTool ───────────────────────────────────────────────────────── class TestPDFExtractorTool: def test_rejects_nonexistent_file(self): from src.tools.pdf_extractor import PDFExtractorTool tool = PDFExtractorTool() result = json.loads(tool._run("/nonexistent/file.pdf")) assert "error" in result def test_rejects_non_pdf(self, tmp_path): from src.tools.pdf_extractor import PDFExtractorTool txt_file = tmp_path / "deck.txt" txt_file.write_text("not a pdf") tool = PDFExtractorTool() result = json.loads(tool._run(str(txt_file))) assert "error" in result def test_extracts_real_pdf(self, tmp_path): """Create a minimal PDF using reportlab if available, else skip.""" try: from reportlab.pdfgen import canvas except ImportError: pytest.skip("reportlab not installed") pdf_path = tmp_path / "test_deck.pdf" c = canvas.Canvas(str(pdf_path)) c.drawString(100, 750, "Acme AI - Pitch Deck") c.drawString(100, 720, "Market Size: $10B TAM") c.save() from src.tools.pdf_extractor import PDFExtractorTool tool = PDFExtractorTool() result = json.loads(tool._run(str(pdf_path))) assert "total_pages" in result assert result["total_pages"] >= 1 assert "full_text" in result # ─── ChartGeneratorTool ─────────────────────────────────────────────────────── class TestChartGeneratorTool: def test_generates_bar_chart(self, tmp_path): from src.tools.chart_generator import ChartGeneratorTool tool = ChartGeneratorTool(output_dir=str(tmp_path)) input_data = json.dumps({ "chart_type": "bar", "title": "Test Revenue", "labels": ["Y1", "Y2", "Y3"], "values": [100000, 500000, 2000000], "output_dir": str(tmp_path), }) result = tool._run(input_data) assert "Chart saved to:" in result png_files = list(tmp_path.glob("*.png")) assert len(png_files) == 1 def test_generates_line_chart(self, tmp_path): from src.tools.chart_generator import ChartGeneratorTool tool = ChartGeneratorTool(output_dir=str(tmp_path)) input_data = json.dumps({ "chart_type": "line", "title": "MRR Growth", "labels": ["Jan", "Feb", "Mar", "Apr"], "values": [10000, 15000, 22000, 35000], "output_dir": str(tmp_path), }) result = tool._run(input_data) assert "Chart saved to:" in result def test_rejects_mismatched_lengths(self, tmp_path): from src.tools.chart_generator import ChartGeneratorTool tool = ChartGeneratorTool(output_dir=str(tmp_path)) input_data = json.dumps({ "chart_type": "bar", "title": "Bad", "labels": ["A", "B"], "values": [1, 2, 3], }) result = tool._run(input_data) assert "Error" in result def test_rejects_invalid_json(self, tmp_path): from src.tools.chart_generator import ChartGeneratorTool tool = ChartGeneratorTool(output_dir=str(tmp_path)) result = tool._run("not valid json") assert "Error" in result # ─── MemoWriterTool ─────────────────────────────────────────────────────────── class TestMemoWriterTool: def test_saves_memo(self, tmp_path): from src.tools.memo_writer import MemoWriterTool tool = MemoWriterTool(output_dir=str(tmp_path)) memo_text = "# Investment Memo\n\n## Recommendation: INVEST\n" input_data = json.dumps({ "company_name": "Acme AI", "memo_content": memo_text, "output_dir": str(tmp_path), }) result = tool._run(input_data) assert "Investment memo saved to:" in result saved_files = list(tmp_path.glob("memo_*.md")) assert len(saved_files) == 1 content = saved_files[0].read_text() assert "Investment Memo" in content def test_rejects_empty_content(self, tmp_path): from src.tools.memo_writer import MemoWriterTool tool = MemoWriterTool(output_dir=str(tmp_path)) input_data = json.dumps({ "company_name": "Acme", "memo_content": "", }) result = tool._run(input_data) assert "Error" in result def test_rejects_invalid_json(self, tmp_path): from src.tools.memo_writer import MemoWriterTool tool = MemoWriterTool(output_dir=str(tmp_path)) result = tool._run("{bad json") assert "Error" in result