Spaces:
Running
Running
File size: 5,317 Bytes
"""
Tests for DealFlow AI custom tools.
Runs without an LLM or CrewAI — tests the tool logic in isolation.
"""
from __future__ import annotations
import json
import os
import sys
import tempfile
from pathlib import Path
import pytest
# Put the directory two levels above this file at the front of sys.path so
# the lazily imported ``src.tools.*`` modules used by the tests resolve no
# matter which working directory pytest is launched from.
sys.path.insert(0, os.fspath(Path(__file__).parent.parent))
# ─── PDFExtractorTool ────────────────────────────────────────────────────────
class TestPDFExtractorTool:
    """Exercise ``PDFExtractorTool._run`` on missing, non-PDF and real PDF inputs.

    Every test decodes the tool's return value with ``json.loads`` and
    inspects the resulting dict.
    """

    def test_rejects_nonexistent_file(self, tmp_path):
        """A path that does not exist yields a JSON object with an ``error`` key."""
        from src.tools.pdf_extractor import PDFExtractorTool

        # Build the missing path inside tmp_path so it is guaranteed not to
        # exist, independent of the host OS or filesystem layout.
        missing = tmp_path / "does_not_exist.pdf"
        assert not missing.exists()

        tool = PDFExtractorTool()
        result = json.loads(tool._run(str(missing)))
        assert "error" in result

    def test_rejects_non_pdf(self, tmp_path):
        """An existing plain-text file (``deck.txt``) is rejected with an ``error`` key."""
        from src.tools.pdf_extractor import PDFExtractorTool

        txt_file = tmp_path / "deck.txt"
        txt_file.write_text("not a pdf", encoding="utf-8")

        tool = PDFExtractorTool()
        result = json.loads(tool._run(str(txt_file)))
        assert "error" in result

    def test_extracts_real_pdf(self, tmp_path):
        """Extract a one-page PDF generated with reportlab; skip if reportlab is absent."""
        # importorskip returns the imported submodule, equivalent to
        # ``from reportlab.pdfgen import canvas``, or skips the test.
        canvas = pytest.importorskip("reportlab.pdfgen.canvas")

        pdf_path = tmp_path / "test_deck.pdf"
        pdf = canvas.Canvas(str(pdf_path))
        pdf.drawString(100, 750, "Acme AI - Pitch Deck")
        pdf.drawString(100, 720, "Market Size: $10B TAM")
        pdf.save()

        from src.tools.pdf_extractor import PDFExtractorTool

        tool = PDFExtractorTool()
        result = json.loads(tool._run(str(pdf_path)))

        # Surface the tool's own error message if extraction failed.
        assert "error" not in result, result.get("error")
        assert "total_pages" in result
        assert result["total_pages"] >= 1
        assert "full_text" in result
        # The text drawn above must actually come back out, not just the keys.
        assert "Acme AI" in result["full_text"]
# ─── ChartGeneratorTool ──────────────────────────────────────────────────────
class TestChartGeneratorTool:
    """Exercise ``ChartGeneratorTool._run`` with JSON-encoded chart specs.

    The tool's return value is checked as a status string: success contains
    ``"Chart saved to:"`` and failures contain ``"Error"``.
    """

    @staticmethod
    def _make_tool(output_dir):
        """Lazily import the tool and build one that writes into *output_dir*."""
        from src.tools.chart_generator import ChartGeneratorTool

        return ChartGeneratorTool(output_dir=str(output_dir))

    def test_generates_bar_chart(self, tmp_path):
        """A valid bar-chart spec reports success and writes exactly one PNG."""
        tool = self._make_tool(tmp_path)
        input_data = json.dumps({
            "chart_type": "bar",
            "title": "Test Revenue",
            "labels": ["Y1", "Y2", "Y3"],
            "values": [100000, 500000, 2000000],
            "output_dir": str(tmp_path),
        })
        result = tool._run(input_data)
        assert "Chart saved to:" in result
        assert len(list(tmp_path.glob("*.png"))) == 1

    def test_generates_line_chart(self, tmp_path):
        """A valid line-chart spec reports success and writes exactly one PNG."""
        tool = self._make_tool(tmp_path)
        input_data = json.dumps({
            "chart_type": "line",
            "title": "MRR Growth",
            "labels": ["Jan", "Feb", "Mar", "Apr"],
            "values": [10000, 15000, 22000, 35000],
            "output_dir": str(tmp_path),
        })
        result = tool._run(input_data)
        assert "Chart saved to:" in result
        # Same on-disk check as the bar-chart test: the message alone does not
        # prove a file was produced.
        assert len(list(tmp_path.glob("*.png"))) == 1

    def test_rejects_mismatched_lengths(self, tmp_path):
        """Labels and values of different lengths produce an error and no PNG."""
        tool = self._make_tool(tmp_path)
        input_data = json.dumps({
            "chart_type": "bar",
            "title": "Bad",
            "labels": ["A", "B"],
            "values": [1, 2, 3],
        })
        result = tool._run(input_data)
        assert "Error" in result
        assert not list(tmp_path.glob("*.png"))

    def test_rejects_invalid_json(self, tmp_path):
        """Input that is not valid JSON produces an error and no PNG."""
        tool = self._make_tool(tmp_path)
        result = tool._run("not valid json")
        assert "Error" in result
        assert not list(tmp_path.glob("*.png"))
# ─── MemoWriterTool ──────────────────────────────────────────────────────────
class TestMemoWriterTool:
    """Exercise ``MemoWriterTool._run`` with JSON-encoded memo payloads.

    Success is signalled by ``"Investment memo saved to:"`` in the returned
    string, and the memo is expected on disk as ``memo_*.md``. Failures
    contain ``"Error"``.
    """

    @staticmethod
    def _make_tool(output_dir):
        """Lazily import the tool and build one that writes into *output_dir*."""
        from src.tools.memo_writer import MemoWriterTool

        return MemoWriterTool(output_dir=str(output_dir))

    def test_saves_memo(self, tmp_path):
        """A valid payload writes exactly one ``memo_*.md`` file containing the memo."""
        tool = self._make_tool(tmp_path)
        memo_text = "# Investment Memo\n\n## Recommendation: INVEST\n"
        input_data = json.dumps({
            "company_name": "Acme AI",
            "memo_content": memo_text,
            "output_dir": str(tmp_path),
        })
        result = tool._run(input_data)
        assert "Investment memo saved to:" in result

        saved_files = list(tmp_path.glob("memo_*.md"))
        assert len(saved_files) == 1
        # Read with an explicit encoding: the platform default (e.g. cp1252 on
        # Windows) would make this check environment-dependent.
        content = saved_files[0].read_text(encoding="utf-8")
        assert "Investment Memo" in content

    def test_rejects_empty_content(self, tmp_path):
        """An empty ``memo_content`` produces an error and no memo file."""
        tool = self._make_tool(tmp_path)
        input_data = json.dumps({
            "company_name": "Acme",
            "memo_content": "",
        })
        result = tool._run(input_data)
        assert "Error" in result
        assert not list(tmp_path.glob("memo_*.md"))

    def test_rejects_invalid_json(self, tmp_path):
        """Input that is not valid JSON produces an error and no memo file."""
        tool = self._make_tool(tmp_path)
        result = tool._run("{bad json")
        assert "Error" in result
        assert not list(tmp_path.glob("memo_*.md"))
|