Spaces:
Running
Running
"""
Tests for DealFlow AI custom tools.
Runs without LLM/CrewAI — tests tool logic in isolation.
"""
| from __future__ import annotations | |
| import json | |
| import os | |
| import sys | |
| import tempfile | |
| from pathlib import Path | |
| import pytest | |
# Make the directory two levels above this file importable so the lazy
# `src.tools.*` imports inside the tests resolve regardless of where pytest is
# launched from (presumably the repo root containing `src/` — confirm against
# the project layout). resolve() guards against a relative __file__, for which
# `.parent.parent` would otherwise collapse to "." (the current directory).
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
# ─── PDFExtractorTool ─────────────────────────────────────────────────────────
class TestPDFExtractorTool:
    """Exercise ``PDFExtractorTool._run`` directly, with no agent/LLM layer.

    ``_run`` returns a JSON string; each test decodes it with ``json.loads``
    and inspects the resulting object.
    """

    @staticmethod
    def _make_tool():
        """Import and build the tool inside the test, as before, so an import
        failure in the tool module only affects this class's tests."""
        from src.tools.pdf_extractor import PDFExtractorTool

        return PDFExtractorTool()

    def test_rejects_nonexistent_file(self, tmp_path):
        # A name inside the per-test tmp dir is guaranteed not to exist and is
        # portable, unlike a hard-coded absolute path.
        missing = tmp_path / "missing.pdf"
        result = json.loads(self._make_tool()._run(str(missing)))
        assert "error" in result

    def test_rejects_non_pdf(self, tmp_path):
        txt_file = tmp_path / "deck.txt"
        txt_file.write_text("not a pdf")
        result = json.loads(self._make_tool()._run(str(txt_file)))
        assert "error" in result

    def test_extracts_real_pdf(self, tmp_path):
        """Create a minimal PDF using reportlab if available, else skip."""
        canvas = pytest.importorskip("reportlab.pdfgen.canvas")

        pdf_path = tmp_path / "test_deck.pdf"
        c = canvas.Canvas(str(pdf_path))
        c.drawString(100, 750, "Acme AI - Pitch Deck")
        c.drawString(100, 720, "Market Size: $10B TAM")
        c.save()

        result = json.loads(self._make_tool()._run(str(pdf_path)))
        assert "total_pages" in result
        assert result["total_pages"] >= 1
        assert "full_text" in result
        # Key presence alone would also pass for an extractor that recovered
        # no text; check that the drawn content actually came back.
        assert "Acme AI" in result["full_text"]
# ─── ChartGeneratorTool ───────────────────────────────────────────────────────
class TestChartGeneratorTool:
    """Exercise ``ChartGeneratorTool._run``, which takes a JSON string and
    returns a status string ("Chart saved to: …" on success, containing
    "Error" on failure)."""

    @staticmethod
    def _make_tool(tmp_path):
        """Import and build the tool with its output directory set to the
        per-test ``tmp_path``."""
        from src.tools.chart_generator import ChartGeneratorTool

        return ChartGeneratorTool(output_dir=str(tmp_path))

    def test_generates_bar_chart(self, tmp_path):
        input_data = json.dumps({
            "chart_type": "bar",
            "title": "Test Revenue",
            "labels": ["Y1", "Y2", "Y3"],
            "values": [100000, 500000, 2000000],
            "output_dir": str(tmp_path),
        })
        result = self._make_tool(tmp_path)._run(input_data)
        assert "Chart saved to:" in result
        assert len(list(tmp_path.glob("*.png"))) == 1

    def test_generates_line_chart(self, tmp_path):
        input_data = json.dumps({
            "chart_type": "line",
            "title": "MRR Growth",
            "labels": ["Jan", "Feb", "Mar", "Apr"],
            "values": [10000, 15000, 22000, 35000],
            "output_dir": str(tmp_path),
        })
        result = self._make_tool(tmp_path)._run(input_data)
        assert "Chart saved to:" in result
        # Like the bar-chart test: the success message alone does not prove
        # that an image file was written.
        assert len(list(tmp_path.glob("*.png"))) == 1

    def test_rejects_mismatched_lengths(self, tmp_path):
        input_data = json.dumps({
            "chart_type": "bar",
            "title": "Bad",
            "labels": ["A", "B"],
            "values": [1, 2, 3],
        })
        result = self._make_tool(tmp_path)._run(input_data)
        assert "Error" in result
        # A rejected request must not leave a partial chart behind.
        assert not list(tmp_path.glob("*.png"))

    def test_rejects_invalid_json(self, tmp_path):
        result = self._make_tool(tmp_path)._run("not valid json")
        assert "Error" in result
        assert not list(tmp_path.glob("*.png"))
# ─── MemoWriterTool ───────────────────────────────────────────────────────────
class TestMemoWriterTool:
    """Exercise ``MemoWriterTool._run``, which takes a JSON string and returns
    a status string ("Investment memo saved to: …" on success, containing
    "Error" on failure)."""

    @staticmethod
    def _make_tool(tmp_path):
        """Import and build the tool with its output directory set to the
        per-test ``tmp_path``."""
        from src.tools.memo_writer import MemoWriterTool

        return MemoWriterTool(output_dir=str(tmp_path))

    def test_saves_memo(self, tmp_path):
        memo_text = "# Investment Memo\n\n## Recommendation: INVEST\n"
        input_data = json.dumps({
            "company_name": "Acme AI",
            "memo_content": memo_text,
            "output_dir": str(tmp_path),
        })
        result = self._make_tool(tmp_path)._run(input_data)
        assert "Investment memo saved to:" in result
        saved_files = list(tmp_path.glob("memo_*.md"))
        assert len(saved_files) == 1
        # Explicit encoding keeps the read independent of the platform's
        # default locale encoding.
        content = saved_files[0].read_text(encoding="utf-8")
        assert "Investment Memo" in content

    def test_rejects_empty_content(self, tmp_path):
        input_data = json.dumps({
            "company_name": "Acme",
            "memo_content": "",
        })
        result = self._make_tool(tmp_path)._run(input_data)
        assert "Error" in result
        # A rejected memo must not be written to disk.
        assert not list(tmp_path.glob("memo_*.md"))

    def test_rejects_invalid_json(self, tmp_path):
        result = self._make_tool(tmp_path)._run("{bad json")
        assert "Error" in result
        assert not list(tmp_path.glob("memo_*.md"))