File size: 5,317 Bytes
8dcf472
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
"""
Tests for DealFlow AI custom tools.
Runs without LLM/CrewAI — tests tool logic in isolation.
"""
from __future__ import annotations

import json
import os
import sys
import tempfile
from pathlib import Path

import pytest

sys.path.insert(0, str(Path(__file__).parent.parent))


# ─── PDFExtractorTool ─────────────────────────────────────────────────────────

class TestPDFExtractorTool:
    """Exercise PDFExtractorTool on a missing path, a non-PDF file and a real PDF.

    ``PDFExtractorTool._run`` is called with a file path and its return value
    is parsed with ``json.loads``; failures are expected to surface as an
    ``"error"`` key in the decoded object.
    """

    @staticmethod
    def _make_tool():
        """Return a fresh PDFExtractorTool.

        The import is kept inside the function, matching the rest of this file,
        so an import problem in ``src`` is reported per test rather than when
        pytest collects the module.
        """
        from src.tools.pdf_extractor import PDFExtractorTool
        return PDFExtractorTool()

    def test_rejects_nonexistent_file(self, tmp_path):
        # Build the missing path under tmp_path so it is guaranteed not to exist
        # on any platform, rather than relying on "/nonexistent" being absent.
        missing = tmp_path / "nonexistent" / "file.pdf"
        result = json.loads(self._make_tool()._run(str(missing)))
        assert "error" in result

    def test_rejects_non_pdf(self, tmp_path):
        txt_file = tmp_path / "deck.txt"
        txt_file.write_text("not a pdf")
        result = json.loads(self._make_tool()._run(str(txt_file)))
        assert "error" in result

    def test_extracts_real_pdf(self, tmp_path):
        """Create a minimal PDF using reportlab if available, else skip."""
        # importorskip is pytest's built-in form of "import or skip the test".
        canvas = pytest.importorskip("reportlab.pdfgen.canvas")

        pdf_path = tmp_path / "test_deck.pdf"
        c = canvas.Canvas(str(pdf_path))
        c.drawString(100, 750, "Acme AI - Pitch Deck")
        c.drawString(100, 720, "Market Size: $10B TAM")
        c.save()

        result = json.loads(self._make_tool()._run(str(pdf_path)))
        assert "total_pages" in result
        assert result["total_pages"] >= 1
        assert "full_text" in result


# ─── ChartGeneratorTool ───────────────────────────────────────────────────────

class TestChartGeneratorTool:
    """Exercise ChartGeneratorTool: bar/line rendering and input validation.

    ``ChartGeneratorTool._run`` takes a JSON string; on success its message
    contains ``"Chart saved to:"`` and on bad input it contains ``"Error"``.
    """

    @staticmethod
    def _make_tool(output_dir):
        """Return a ChartGeneratorTool configured to write into *output_dir*.

        The import is kept inside the function, matching the rest of this file,
        so an import problem in ``src`` is reported per test rather than when
        pytest collects the module.
        """
        from src.tools.chart_generator import ChartGeneratorTool
        return ChartGeneratorTool(output_dir=str(output_dir))

    def test_generates_bar_chart(self, tmp_path):
        tool = self._make_tool(tmp_path)
        input_data = json.dumps({
            "chart_type": "bar",
            "title": "Test Revenue",
            "labels": ["Y1", "Y2", "Y3"],
            "values": [100000, 500000, 2000000],
            "output_dir": str(tmp_path),
        })
        result = tool._run(input_data)
        assert "Chart saved to:" in result
        png_files = list(tmp_path.glob("*.png"))
        assert len(png_files) == 1

    def test_generates_line_chart(self, tmp_path):
        tool = self._make_tool(tmp_path)
        input_data = json.dumps({
            "chart_type": "line",
            "title": "MRR Growth",
            "labels": ["Jan", "Feb", "Mar", "Apr"],
            "values": [10000, 15000, 22000, 35000],
            "output_dir": str(tmp_path),
        })
        result = tool._run(input_data)
        assert "Chart saved to:" in result
        # Like the bar-chart test, verify an image was actually written rather
        # than trusting the success message alone.
        png_files = list(tmp_path.glob("*.png"))
        assert len(png_files) == 1

    def test_rejects_mismatched_lengths(self, tmp_path):
        tool = self._make_tool(tmp_path)
        # Two labels but three values: the tool should refuse to plot.
        input_data = json.dumps({
            "chart_type": "bar",
            "title": "Bad",
            "labels": ["A", "B"],
            "values": [1, 2, 3],
        })
        result = tool._run(input_data)
        assert "Error" in result

    def test_rejects_invalid_json(self, tmp_path):
        tool = self._make_tool(tmp_path)
        result = tool._run("not valid json")
        assert "Error" in result


# ─── MemoWriterTool ───────────────────────────────────────────────────────────

class TestMemoWriterTool:
    """MemoWriterTool: persisting a markdown memo and rejecting bad input.

    ``MemoWriterTool._run`` receives a JSON string and returns a status
    message; these tests look for ``"Error"`` in that message on bad input.
    """

    @staticmethod
    def _writer(target_dir):
        """Build a MemoWriterTool that writes into *target_dir*."""
        from src.tools.memo_writer import MemoWriterTool
        return MemoWriterTool(output_dir=str(target_dir))

    def test_saves_memo(self, tmp_path):
        writer = self._writer(tmp_path)
        body = "# Investment Memo\n\n## Recommendation: INVEST\n"
        payload = json.dumps(
            dict(
                company_name="Acme AI",
                memo_content=body,
                output_dir=str(tmp_path),
            )
        )
        message = writer._run(payload)
        assert "Investment memo saved to:" in message

        # Exactly one memo_*.md file should appear in the output directory.
        memos = [path for path in tmp_path.glob("memo_*.md")]
        assert len(memos) == 1
        assert "Investment Memo" in memos[0].read_text()

    def test_rejects_empty_content(self, tmp_path):
        writer = self._writer(tmp_path)
        payload = json.dumps(dict(company_name="Acme", memo_content=""))
        assert "Error" in writer._run(payload)

    def test_rejects_invalid_json(self, tmp_path):
        writer = self._writer(tmp_path)
        assert "Error" in writer._run("{bad json")