Spaces:
Running
Running
File size: 6,679 Bytes
a5f93e1 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 | """
Mission 29: Knowledge Chunks Tests
Tests file upload creates chunks, cascade delete, and chunk retrieval.
"""
import io
import pytest
import pytest_asyncio
from typing import Optional
from uuid import UUID, uuid4
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import select
from app.models.models import (
Workspace, WorkspaceMember, SystemModuleConfig, KnowledgeChunk,
)
from app.core.modules import module_cache
from app.services.knowledge_chunker import chunk_text, retrieve_relevant_chunks
# ── Helpers ─────────────────────────────────────────────────────────
async def _signup_and_login(client: AsyncClient, email: str) -> str:
    """Register a user with a fixed test password and return an access token.

    Posts to the signup endpoint, then logs in with the same credentials
    using a form-encoded request.

    Args:
        client: HTTP client bound to the application under test.
        email: Email address used both for signup and as the login username.

    Returns:
        The ``access_token`` value found under ``data`` in the login response.

    Raises:
        AssertionError: If the login request does not return a 2xx status.
    """
    password = "securepassword123"

    # The signup response is not inspected here; if signup did not produce a
    # usable account, the login check below reports the failure.
    await client.post(
        "/api/v1/auth/signup",
        json={
            "email": email,
            "password": password,
            "full_name": "KB User",
        },
    )

    login = await client.post(
        "/api/v1/auth/login",
        data={"username": email, "password": password},
        headers={"content-type": "application/x-www-form-urlencoded"},
    )
    # Fail with the server's response instead of an opaque KeyError on the
    # JSON lookup below.
    assert 200 <= login.status_code < 300, (
        f"login failed for {email}: {login.status_code} {login.text}"
    )
    return login.json()["data"]["access_token"]
def _auth(token: str, workspace_id: Optional[str] = None) -> dict:
headers = {"Authorization": f"Bearer {token}"}
if workspace_id:
headers["X-Workspace-ID"] = workspace_id
return headers
@pytest.fixture(autouse=True)
def clear_module_cache():
    """Empty the module cache before and after every test in this module."""

    def _reset() -> None:
        # Look the cache up on each call rather than holding a reference.
        module_cache._cache.clear()

    _reset()
    yield
    _reset()
@pytest_asyncio.fixture
async def kb_setup(async_client: AsyncClient, db_session: AsyncSession):
    """Provide a logged-in user who owns a fresh workspace with modules enabled.

    Signs up and logs in ``kbchunk@example.com``, adds a free-tier workspace,
    an ``owner`` membership for that user, and enabled config rows for the
    ``knowledge_files`` and ``prompt_studio`` modules, then flushes the session.

    Returns:
        A dict with ``"token"`` (the access token) and ``"workspace"`` (the
        ``Workspace`` instance that was added).
    """
    access_token = await _signup_and_login(async_client, "kbchunk@example.com")

    workspace = Workspace(id=uuid4(), name="KB Chunk WS", subscription_tier="free")
    db_session.add(workspace)

    # Resolve the id of the user that was just created via the "me" endpoint.
    profile = await async_client.get("/api/v1/auth/me", headers=_auth(access_token))
    owner_id = UUID(profile.json()["data"]["id"])
    db_session.add(
        WorkspaceMember(user_id=owner_id, workspace_id=workspace.id, role="owner")
    )

    for module_name in ["knowledge_files", "prompt_studio"]:
        db_session.add(SystemModuleConfig(module_name=module_name, is_enabled=True))

    await db_session.flush()
    return {"token": access_token, "workspace": workspace}
# ── Unit Tests: chunk_text ─────────────────────────────────────────
def test_chunk_text_short():
    """A short input comes back as exactly one chunk equal to the input."""
    chunks = chunk_text("Hello world")

    assert len(chunks) == 1
    assert chunks[0] == "Hello world"
def test_chunk_text_empty():
    """Empty and whitespace-only inputs both yield an empty list."""
    for blank in ("", " "):
        assert chunk_text(blank) == []
def test_chunk_text_long():
    """A 3000-character input is split into more than one chunk."""
    source = "A" * 3000
    pieces = chunk_text(source)

    assert len(pieces) > 1

    # Chunks may overlap, so the summed length can exceed the input length;
    # this check only requires that it is not smaller than the input.
    combined_length = sum(len(piece) for piece in pieces)
    assert combined_length >= 3000
def test_chunk_text_paragraph_boundaries():
    """Two long paragraphs separated by a blank line give at least two chunks.

    Intended to cover the chunker's preference for paragraph boundaries; the
    assertion itself only checks the number of chunks produced.
    """
    first_block = "First paragraph. " * 30
    second_block = "Second paragraph. " * 30
    pieces = chunk_text(first_block + "\n\n" + second_block)

    assert len(pieces) >= 2
# ── Integration Tests ──────────────────────────────────────────────
@pytest.mark.asyncio
async def test_upload_creates_chunks(async_client: AsyncClient, kb_setup, db_session: AsyncSession):
    """Uploading a text file reports and persists chunks for the workspace.

    Checks the response payload (success flag, extraction flag, positive
    chunk count) and then that the stored chunk rows for the file match the
    reported count and all belong to the fixture workspace.
    """
    request_headers = _auth(kb_setup["token"], str(kb_setup["workspace"].id))
    payload = b"This is a test knowledge file with important pricing information.\n" * 20

    response = await async_client.post(
        "/api/v1/knowledge/files",
        files={"file": ("test.txt", io.BytesIO(payload), "text/plain")},
        headers=request_headers,
    )
    assert response.status_code == 200

    body = response.json()
    assert body["success"] is True
    assert body["data"]["extracted"] is True
    assert body["data"]["chunk_count"] > 0

    # Cross-check the reported count against the rows stored for this file.
    uploaded_id = UUID(body["data"]["id"])
    query = select(KnowledgeChunk).where(KnowledgeChunk.knowledge_file_id == uploaded_id)
    stored_chunks = (await db_session.execute(query)).scalars().all()

    assert len(stored_chunks) == body["data"]["chunk_count"]
    assert all(
        chunk.workspace_id == kb_setup["workspace"].id for chunk in stored_chunks
    )
@pytest.mark.asyncio
async def test_delete_cascades_chunks(async_client: AsyncClient, kb_setup, db_session: AsyncSession):
    """Deleting a knowledge file through the API also removes its chunks.

    Uploads a small text file, confirms chunk rows exist for it, deletes the
    file via the API, then confirms no chunk rows remain for that file.
    """
    setup = kb_setup
    headers = _auth(setup["token"], str(setup["workspace"].id))
    file_content = b"Some content for cascade test.\n" * 10

    # Upload
    upload_res = await async_client.post(
        "/api/v1/knowledge/files",
        files={"file": ("cascade.txt", io.BytesIO(file_content), "text/plain")},
        headers=headers,
    )
    # Fail here with the server's response rather than with a KeyError on the
    # JSON lookup below; the upload test expects 200 from this same endpoint.
    assert upload_res.status_code == 200, upload_res.text
    file_id = upload_res.json()["data"]["id"]

    # Build the lookup once; it is executed before and after the delete.
    chunks_for_file = select(KnowledgeChunk).where(
        KnowledgeChunk.knowledge_file_id == UUID(file_id)
    )

    # Verify chunks exist
    result = await db_session.execute(chunks_for_file)
    assert len(result.scalars().all()) > 0

    # Delete file
    del_res = await async_client.delete(f"/api/v1/knowledge/files/{file_id}", headers=headers)
    assert del_res.status_code == 200

    # Verify chunks deleted
    result = await db_session.execute(chunks_for_file)
    assert len(result.scalars().all()) == 0
@pytest.mark.asyncio
async def test_chunk_retrieval_keyword(db_session: AsyncSession):
    """A pricing question retrieves at least one chunk mentioning pricing.

    Stores four chunks under a freshly generated workspace id and file id,
    then queries ``retrieve_relevant_chunks`` for that workspace.
    """
    workspace_id = uuid4()
    knowledge_file_id = uuid4()

    # Only the first text mentions pricing.
    sample_texts = [
        "Our premium pricing starts at $500 per month for the basic plan.",
        "Customer support is available 24/7 via phone and email.",
        "The enterprise plan includes advanced analytics and reporting.",
        "Our team has over 10 years of experience in marketing automation.",
    ]
    for position, content in enumerate(sample_texts):
        chunk = KnowledgeChunk(
            workspace_id=workspace_id,
            knowledge_file_id=knowledge_file_id,
            chunk_index=position,
            content_text=content,
        )
        db_session.add(chunk)
    await db_session.flush()

    # Search for pricing-related content
    matches = await retrieve_relevant_chunks(db_session, workspace_id, "What is the pricing?")

    assert len(matches) > 0
    assert any("pricing" in match.lower() for match in matches)
|