research-lens / tests /test_api.py
thundarstrom's picture
fix: resolve all bugs and add comprehensive test suite
1d293d8
"""
test_api.py
===========
End-to-end API tests for every ResearchLens endpoint using FastAPI TestClient.
Creates synthetic PDFs, uploads them, and exercises all features.
"""
import os
import shutil
import fitz
import pytest
from fastapi.testclient import TestClient
# ─── Fixtures ────────────────────────────────────────────────────────────────
@pytest.fixture(scope="module")
def client():
    """Yield a module-scoped TestClient bound to the FastAPI app.

    The ``unified_indices`` and ``paper_results`` entries of the server's
    ``GLOBAL_STATE`` are emptied before the client is created and again
    after the client context has been closed.
    """
    from src.server import GLOBAL_STATE, app

    def _reset_state():
        # Empty the two state containers in place (same dict objects are kept).
        for key in ("unified_indices", "paper_results"):
            GLOBAL_STATE[key].clear()

    # Begin from an empty state.
    _reset_state()

    with TestClient(app) as test_client:
        yield test_client

    # Teardown once every test in the module has run.
    _reset_state()
@pytest.fixture(scope="module")
def synthetic_pdfs():
    """Create two synthetic academic PDFs and yield their filenames.

    Each PDF is written to the current working directory. It contains a
    title line, a "By <author>" line, and the paper body inserted one
    newline-separated line at a time; title and author are also stored in
    the document metadata.

    Teardown removes the generated files and the ``data/indices`` directory.
    It runs from a ``finally`` block, so files that were already written are
    also removed when PDF generation fails part-way through.

    Yields:
        list[str]: Filenames of the generated PDFs, in creation order.
    """
    papers = [
        {
            "filename": "test_api_paper_1.pdf",
            "title": "Deep Learning for Image Classification",
            "author": "Alice Researcher",
            "content": (
                "Abstract\n"
                "This paper proposes a novel convolutional neural network architecture for image classification. "
                "We achieve 95.2% accuracy on CIFAR-10 using a residual attention mechanism. "
                "Our method outperforms existing baselines by 3.1 percentage points.\n\n"
                "1. Introduction\n"
                "Image classification is a fundamental task in computer vision. "
                "Recent advances in deep learning have significantly improved accuracy. "
                "However, most methods require extensive computational resources.\n\n"
                "2. Methodology\n"
                "We propose ResAttNet, which combines residual connections with self-attention layers. "
                "The model uses 12 convolutional blocks followed by 4 attention heads. "
                "Training uses AdamW optimizer with cosine annealing schedule.\n\n"
                "3. Results\n"
                "On CIFAR-10, our model achieves 95.2% test accuracy. "
                "On ImageNet, we reach 78.4% top-1 accuracy. "
                "The model trains in 8 hours on a single A100 GPU.\n\n"
                "4. Conclusion\n"
                "ResAttNet demonstrates that attention mechanisms can improve CNN performance. "
                "Future work includes extending to video classification tasks."
            ),
        },
        {
            "filename": "test_api_paper_2.pdf",
            "title": "Transformer Models for Natural Language Processing",
            "author": "Bob Scientist",
            "content": (
                "Abstract\n"
                "We present a comprehensive study of transformer architectures for NLP tasks. "
                "Our fine-tuned model achieves state-of-the-art results on GLUE benchmark with 89.7% average score.\n\n"
                "1. Introduction\n"
                "Natural language processing has been revolutionized by transformer models. "
                "Self-attention allows capturing long-range dependencies in text. "
                "However, computational cost scales quadratically with sequence length.\n\n"
                "2. Methodology\n"
                "We fine-tune a 350M parameter transformer on GLUE tasks. "
                "We use mixed-precision training and gradient accumulation. "
                "The learning rate follows a linear warmup schedule.\n\n"
                "3. Results\n"
                "Our model achieves 89.7% average score on GLUE. "
                "On SQuAD 2.0, we reach 88.3% F1 score. "
                "Inference takes 15ms per query on V100.\n\n"
                "4. Discussion\n"
                "Transformers are computationally expensive but highly effective. "
                "Our results confirm that scale matters for NLP performance.\n\n"
                "5. Conclusion\n"
                "We recommend transformer architectures for production NLP systems. "
                "Future work includes efficient attention mechanisms to reduce compute costs."
            ),
        },
    ]

    filenames = []
    try:
        for paper in papers:
            doc = fitz.open()
            try:
                # Start on a fresh page; another page is added only when the
                # text cursor moves past y=750.
                page = doc.new_page()
                page.insert_text((50, 50), paper["title"], fontsize=18)
                page.insert_text((50, 80), f"By {paper['author']}", fontsize=12)

                y = 120
                for line in paper["content"].split("\n"):
                    if y > 750:
                        page = doc.new_page()
                        y = 50
                    page.insert_text((50, y), line, fontsize=11)
                    y += 16

                doc.set_metadata({"title": paper["title"], "author": paper["author"]})
                # Register the name before saving so that a partially written
                # file is still picked up by the cleanup below.
                filenames.append(paper["filename"])
                doc.save(paper["filename"])
            finally:
                # Release the document even if an insert/save call raised.
                doc.close()

        yield filenames
    finally:
        # Cleanup: runs after the module's tests and also on setup failure.
        for f in filenames:
            if os.path.exists(f):
                os.remove(f)
        # NOTE(review): this deletes the whole data/indices directory, not
        # only entries created by this test run - confirm that is intended.
        # ignore_errors=True already covers the "directory missing" case.
        shutil.rmtree("data/indices", ignore_errors=True)
# ─── Tests ───────────────────────────────────────────────────────────────────
class TestFrontendServing:
    """Checks for the root route that serves the frontend."""

    def test_serve_frontend(self, client):
        """GET / must answer 200 and mention "ResearchLens" in the body."""
        response = client.get("/")

        assert response.status_code == 200
        assert "ResearchLens" in response.text
class TestUpload:
    """Test PDF upload functionality."""

    def test_upload_pdf(self, client, synthetic_pdfs):
        """Upload each synthetic PDF and check the success payload.

        The response must be 200 with ``success`` set to True, a
        ``paper_id`` key, and a non-empty ``title``.
        """
        for pdf_path in synthetic_pdfs:
            with open(pdf_path, "rb") as f:
                res = client.post(
                    "/api/upload",
                    files={"file": (pdf_path, f, "application/pdf")},
                )

            # Report the raw body: calling res.json() here would raise a
            # decode error on a non-JSON error response and hide the real
            # status-code failure.
            assert res.status_code == 200, f"Upload failed: {res.text}"

            data = res.json()
            assert data["success"] is True
            assert "paper_id" in data
            assert "title" in data
            assert len(data["title"]) > 0

    def test_upload_non_pdf(self, client):
        """Should reject non-PDF files with a 400 whose detail mentions "PDF"."""
        res = client.post(
            "/api/upload",
            files={"file": ("test.txt", b"This is not a PDF", "text/plain")},
        )
        assert res.status_code == 400
        assert "PDF" in res.json()["detail"]
class TestPapers:
    """Checks for the paper listing endpoint."""

    def test_list_papers(self, client, synthetic_pdfs):
        """After uploading the synthetic PDFs, the listing holds at least two
        entries, each exposing ``paper_id``, ``title`` and ``year``."""
        # Upload first so the listing has something to return.
        for path in synthetic_pdfs:
            with open(path, "rb") as handle:
                upload = {"file": (path, handle, "application/pdf")}
                client.post("/api/upload", files=upload)

        listing = client.get("/api/papers")
        assert listing.status_code == 200

        papers = listing.json()["papers"]
        assert len(papers) >= 2
        for entry in papers:
            for key in ("paper_id", "title", "year"):
                assert key in entry
class TestChat:
    """Test the chat/QA functionality."""

    def test_chat(self, client, synthetic_pdfs):
        """Upload the synthetic papers, ask a question, and expect an answer
        longer than 10 characters."""
        # Ensure papers are uploaded
        for pdf_path in synthetic_pdfs:
            with open(pdf_path, "rb") as f:
                client.post("/api/upload", files={"file": (pdf_path, f, "application/pdf")})

        res = client.post(
            "/api/chat",
            json={"query": "What accuracy was achieved on CIFAR-10?", "history": []},
        )
        assert res.status_code == 200
        data = res.json()
        assert "answer" in data
        assert len(data["answer"]) > 10  # Should have a substantive answer

    def test_chat_no_papers(self, client):
        """When no papers are loaded, the answer should mention "upload"."""
        from src.server import GLOBAL_STATE

        # Snapshot the state, then empty it for the duration of this test.
        saved_indices = dict(GLOBAL_STATE["unified_indices"])
        saved_results = dict(GLOBAL_STATE["paper_results"])
        GLOBAL_STATE["unified_indices"].clear()
        GLOBAL_STATE["paper_results"].clear()
        try:
            res = client.post(
                "/api/chat",
                json={"query": "What is the methodology?", "history": []},
            )
            assert res.status_code == 200
            assert "upload" in res.json()["answer"].lower()
        finally:
            # Restore even when an assertion fails, so the module-scoped
            # state is not left empty for the tests that run afterwards.
            GLOBAL_STATE["unified_indices"].update(saved_indices)
            GLOBAL_STATE["paper_results"].update(saved_results)
class TestSummarize:
    """Checks for the paper summarization endpoint."""

    def test_summarize(self, client, synthetic_pdfs):
        """Summarize the first listed paper and check all six summary fields
        are present and non-empty when rendered as strings."""
        # Upload the papers so that a paper_id can be looked up.
        for path in synthetic_pdfs:
            with open(path, "rb") as handle:
                upload = {"file": (path, handle, "application/pdf")}
                client.post("/api/upload", files=upload)

        listing = client.get("/api/papers").json()
        first_id = listing["papers"][0]["paper_id"]

        response = client.post("/api/summarize", json={"paper_id": first_id})
        assert response.status_code == 200

        summary = response.json()
        expected_fields = (
            "title",
            "contribution",
            "methodology",
            "results",
            "datasets",
            "limitations",
        )
        for field in expected_fields:
            assert field in summary, f"Missing field: {field}"
            assert len(str(summary[field])) > 0

    def test_summarize_not_found(self, client):
        """An unknown paper_id must produce a 404."""
        payload = {"paper_id": "nonexistent_paper_id"}
        response = client.post("/api/summarize", json=payload)
        assert response.status_code == 404
class TestIntelligence:
    """Test all cross-paper intelligence features."""

    def _ensure_papers_loaded(self, client, synthetic_pdfs):
        """Upload every synthetic PDF through ``/api/upload``."""
        for pdf_path in synthetic_pdfs:
            with open(pdf_path, "rb") as f:
                client.post("/api/upload", files={"file": (pdf_path, f, "application/pdf")})

    def test_compare(self, client, synthetic_pdfs):
        """``compare`` returns a table whose rows carry dimension and values."""
        self._ensure_papers_loaded(client, synthetic_pdfs)
        res = client.post("/api/intelligence", json={"action": "compare"})
        assert res.status_code == 200
        data = res.json()
        assert data["type"] == "table"
        assert isinstance(data["data"], list)
        # Row shape is only checked when at least one row came back.
        if data["data"]:
            assert "dimension" in data["data"][0]
            assert "values" in data["data"][0]

    def test_contradictions(self, client, synthetic_pdfs):
        """``contradictions`` returns a list payload of type "contradictions"."""
        self._ensure_papers_loaded(client, synthetic_pdfs)
        res = client.post("/api/intelligence", json={"action": "contradictions"})
        assert res.status_code == 200
        data = res.json()
        assert data["type"] == "contradictions"
        assert isinstance(data["data"], list)

    def test_review(self, client, synthetic_pdfs):
        """``review`` returns a text payload longer than 50 characters."""
        self._ensure_papers_loaded(client, synthetic_pdfs)
        res = client.post("/api/intelligence", json={"action": "review"})
        assert res.status_code == 200
        data = res.json()
        assert data["type"] == "text"
        assert len(data["data"]) > 50

    def test_hypotheses(self, client, synthetic_pdfs):
        """``hypotheses`` returns a text payload longer than 50 characters."""
        self._ensure_papers_loaded(client, synthetic_pdfs)
        res = client.post("/api/intelligence", json={"action": "hypotheses"})
        assert res.status_code == 200
        data = res.json()
        assert data["type"] == "text"
        assert len(data["data"]) > 50

    def test_intelligence_no_papers(self, client):
        """Should return 400 when no papers are loaded."""
        from src.server import GLOBAL_STATE

        # Snapshot the state, then empty it for the duration of this test.
        saved_indices = dict(GLOBAL_STATE["unified_indices"])
        saved_results = dict(GLOBAL_STATE["paper_results"])
        GLOBAL_STATE["unified_indices"].clear()
        GLOBAL_STATE["paper_results"].clear()
        try:
            res = client.post("/api/intelligence", json={"action": "compare"})
            assert res.status_code == 400
        finally:
            # Restore even when the assertion fails, so the module-scoped
            # state is not left empty for the tests that run afterwards.
            GLOBAL_STATE["unified_indices"].update(saved_indices)
            GLOBAL_STATE["paper_results"].update(saved_results)

    def test_unknown_action(self, client, synthetic_pdfs):
        """An unrecognised action must produce a 400."""
        self._ensure_papers_loaded(client, synthetic_pdfs)
        res = client.post("/api/intelligence", json={"action": "invalid_action"})
        assert res.status_code == 400
class TestClearMemory:
    """Checks for the memory-clearing endpoint."""

    def test_clear_memory(self, client, synthetic_pdfs):
        """Upload papers, call ``/api/clear``, then expect an empty listing."""
        # Load the papers so there is something to clear.
        for path in synthetic_pdfs:
            with open(path, "rb") as handle:
                upload = {"file": (path, handle, "application/pdf")}
                client.post("/api/upload", files=upload)

        # Clear the server-side memory.
        clear_response = client.post("/api/clear")
        assert clear_response.status_code == 200
        assert clear_response.json()["success"] is True

        # The paper listing must now be empty.
        remaining = client.get("/api/papers").json()["papers"]
        assert len(remaining) == 0