dealflow-ai / tests /test_tools.py
PeterBot22's picture
feat: DealFlow AI MVP — 3-agent CrewAI due diligence system on HF Spaces
8dcf472 verified
"""
Tests for DealFlow AI custom tools.
Runs without LLM/CrewAI — tests tool logic in isolation.
"""
from __future__ import annotations
import json
import os
import sys
import tempfile
from pathlib import Path
import pytest
# Prepend the directory above this tests/ folder to sys.path so the
# `src.tools.*` imports used inside the tests below resolve when pytest runs.
sys.path.insert(0, str(Path(__file__).parent.parent))
# ─── PDFExtractorTool ─────────────────────────────────────────────────────────
class TestPDFExtractorTool:
    """Tests for ``PDFExtractorTool``: input validation and basic text extraction."""

    def test_rejects_nonexistent_file(self):
        """A path that does not exist yields a JSON payload containing ``error``."""
        from src.tools.pdf_extractor import PDFExtractorTool

        tool = PDFExtractorTool()
        result = json.loads(tool._run("/nonexistent/file.pdf"))
        assert "error" in result

    def test_rejects_non_pdf(self, tmp_path):
        """An existing file without a PDF payload yields a JSON ``error``."""
        from src.tools.pdf_extractor import PDFExtractorTool

        txt_file = tmp_path / "deck.txt"
        txt_file.write_text("not a pdf")
        tool = PDFExtractorTool()
        result = json.loads(tool._run(str(txt_file)))
        assert "error" in result

    def test_extracts_real_pdf(self, tmp_path):
        """Create a minimal PDF using reportlab if available, else skip."""
        # importorskip returns the requested (sub)module, or skips the test
        # cleanly when reportlab is not installed.
        canvas = pytest.importorskip("reportlab.pdfgen.canvas")
        from src.tools.pdf_extractor import PDFExtractorTool

        pdf_path = tmp_path / "test_deck.pdf"
        c = canvas.Canvas(str(pdf_path))
        c.drawString(100, 750, "Acme AI - Pitch Deck")
        c.drawString(100, 720, "Market Size: $10B TAM")
        c.save()

        tool = PDFExtractorTool()
        result = json.loads(tool._run(str(pdf_path)))
        assert "total_pages" in result
        assert result["total_pages"] >= 1
        assert "full_text" in result
        # Key presence alone would pass on an empty extraction; make sure the
        # text drawn onto the page actually came back out.
        assert "Acme AI" in result["full_text"]
# ─── ChartGeneratorTool ───────────────────────────────────────────────────────
class TestChartGeneratorTool:
    """Tests for ``ChartGeneratorTool``: chart rendering and input validation."""

    def test_generates_bar_chart(self, tmp_path):
        """A valid bar-chart spec reports success and writes exactly one PNG."""
        from src.tools.chart_generator import ChartGeneratorTool

        tool = ChartGeneratorTool(output_dir=str(tmp_path))
        input_data = json.dumps({
            "chart_type": "bar",
            "title": "Test Revenue",
            "labels": ["Y1", "Y2", "Y3"],
            "values": [100000, 500000, 2000000],
            "output_dir": str(tmp_path),
        })
        result = tool._run(input_data)
        assert "Chart saved to:" in result
        png_files = list(tmp_path.glob("*.png"))
        assert len(png_files) == 1

    def test_generates_line_chart(self, tmp_path):
        """A valid line-chart spec reports success and writes exactly one PNG."""
        from src.tools.chart_generator import ChartGeneratorTool

        tool = ChartGeneratorTool(output_dir=str(tmp_path))
        input_data = json.dumps({
            "chart_type": "line",
            "title": "MRR Growth",
            "labels": ["Jan", "Feb", "Mar", "Apr"],
            "values": [10000, 15000, 22000, 35000],
            "output_dir": str(tmp_path),
        })
        result = tool._run(input_data)
        assert "Chart saved to:" in result
        # Same file-level check as the bar-chart test: the success message
        # alone does not prove a chart was written.
        png_files = list(tmp_path.glob("*.png"))
        assert len(png_files) == 1

    def test_rejects_mismatched_lengths(self, tmp_path):
        """Labels and values of different lengths produce an error, not a chart."""
        from src.tools.chart_generator import ChartGeneratorTool

        tool = ChartGeneratorTool(output_dir=str(tmp_path))
        input_data = json.dumps({
            "chart_type": "bar",
            "title": "Bad",
            "labels": ["A", "B"],
            "values": [1, 2, 3],
        })
        result = tool._run(input_data)
        assert "Error" in result
        # A rejected spec must not leave a chart file behind.
        assert not list(tmp_path.glob("*.png"))

    def test_rejects_invalid_json(self, tmp_path):
        """Input that is not JSON produces an error and writes no chart."""
        from src.tools.chart_generator import ChartGeneratorTool

        tool = ChartGeneratorTool(output_dir=str(tmp_path))
        result = tool._run("not valid json")
        assert "Error" in result
        assert not list(tmp_path.glob("*.png"))
# ─── MemoWriterTool ───────────────────────────────────────────────────────────
class TestMemoWriterTool:
    """Tests for ``MemoWriterTool``: saving memos and rejecting bad input."""

    def test_saves_memo(self, tmp_path):
        """A valid request writes one ``memo_*.md`` file containing the memo text."""
        from src.tools.memo_writer import MemoWriterTool

        tool = MemoWriterTool(output_dir=str(tmp_path))
        memo_text = "# Investment Memo\n\n## Recommendation: INVEST\n"
        input_data = json.dumps({
            "company_name": "Acme AI",
            "memo_content": memo_text,
            "output_dir": str(tmp_path),
        })
        result = tool._run(input_data)
        assert "Investment memo saved to:" in result
        saved_files = list(tmp_path.glob("memo_*.md"))
        assert len(saved_files) == 1
        content = saved_files[0].read_text(encoding="utf-8")
        assert "Investment Memo" in content
        # The body, not just the title, should survive the round trip.
        assert "Recommendation: INVEST" in content

    def test_rejects_empty_content(self, tmp_path):
        """An empty memo body produces an error and writes no file."""
        from src.tools.memo_writer import MemoWriterTool

        tool = MemoWriterTool(output_dir=str(tmp_path))
        input_data = json.dumps({
            "company_name": "Acme",
            "memo_content": "",
        })
        result = tool._run(input_data)
        assert "Error" in result
        # A rejected request must not leave an (empty) memo behind.
        assert not list(tmp_path.glob("memo_*.md"))

    def test_rejects_invalid_json(self, tmp_path):
        """Malformed JSON input produces an error and writes no file."""
        from src.tools.memo_writer import MemoWriterTool

        tool = MemoWriterTool(output_dir=str(tmp_path))
        result = tool._run("{bad json")
        assert "Error" in result
        assert not list(tmp_path.glob("memo_*.md"))