File size: 3,953 Bytes
2fe2727
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
"""HF storage unit tests without filesystem persistence."""

from __future__ import annotations

import os
import unittest
from datetime import datetime, timezone
from unittest.mock import patch

# Seed configuration defaults before the ``server.storage`` imports below run;
# ``setdefault`` leaves any value already present in the environment untouched.
# NOTE(review): presumably the server config reads these at import time - confirm.
for _env_name, _env_default in (
    ("STORAGE_MODE", "HF"),
    ("HF_TOKEN", "hf_test_token"),
    ("HF_REPO_ID", "test-user/docvault-storage"),
    ("HF_REPO_TYPE", "dataset"),
    ("SECRET_KEY", "test-secret"),
):
    os.environ.setdefault(_env_name, _env_default)

from server.storage import factory
from server.storage.hf import HuggingFaceStorageManager


class FakeCommit:
    """Minimal commit record appended to ``FakeHfApi.commits`` by the fake API.

    Args:
        commit_id: Identifier stored on ``commit_id``.
        title: Text stored on both ``title`` and ``message``.

    Attributes set besides the arguments:
        created_at: Timezone-aware UTC timestamp taken at construction time.
        authors: A fresh ``["test"]`` list per instance.
    """

    def __init__(self, commit_id: str, title: str):
        self.commit_id = commit_id
        self.title = title
        # The fake does not distinguish a commit's title from its message.
        self.message = title
        self.created_at = datetime.now(timezone.utc)
        self.authors = ["test"]


class FakeHfApi:
    """In-memory stand-in for ``HfApi`` that never touches the network or disk.

    State kept on the instance:
        token: Whatever was passed to the constructor (unused otherwise).
        files: Mapping of repo path -> stored content.
        commits: ``FakeCommit`` objects in the order they were recorded.
        repo_created: Flipped to ``True`` by ``create_repo``.

    The ``repo_id`` / ``repo_type`` arguments are accepted by every method but
    only ``repo_info`` uses ``repo_id`` (it echoes it back).
    """

    def __init__(self, token=None):
        self.token = token
        self.repo_created = False
        self.files = {}
        self.commits = []

    def _record_commit(self, message):
        """Append a ``FakeCommit`` with the next sequential ``commit-N`` id."""
        sequence_number = len(self.commits) + 1
        self.commits.append(FakeCommit(f"commit-{sequence_number}", message))

    def repo_info(self, repo_id, repo_type):
        """Return ``{"repo_id": repo_id}``; raise ``RuntimeError`` until the repo is created."""
        if self.repo_created:
            return {"repo_id": repo_id}
        raise RuntimeError("repo missing")

    def create_repo(self, repo_id, repo_type, private=True, exist_ok=True):
        """Mark the repo as existing; every argument is ignored."""
        self.repo_created = True

    def list_repo_files(self, repo_id, repo_type):
        """Return all stored paths as a sorted list."""
        return sorted(self.files)

    def create_commit(self, repo_id, repo_type, operations, commit_message):
        """Apply add/delete/copy operations to ``files`` and record one commit.

        Operations are recognised by class name only; any other operation
        type is silently skipped.
        """
        for op in operations:
            kind = op.__class__.__name__
            if kind == "CommitOperationAdd":
                payload = bytes(op.path_or_fileobj)
                self.files[op.path_in_repo] = payload
                continue
            if kind == "CommitOperationDelete":
                # Deleting a path that is not stored is a no-op.
                self.files.pop(op.path_in_repo, None)
                continue
            if kind == "CommitOperationCopy":
                # A missing source path raises KeyError.
                source_content = self.files[op.src_path_in_repo]
                self.files[op.path_in_repo] = source_content
        self._record_commit(commit_message)

    def delete_file(self, path_in_repo, repo_id, repo_type, commit_message):
        """Drop ``path_in_repo`` if present and record a commit either way."""
        self.files.pop(path_in_repo, None)
        self._record_commit(commit_message)

    def list_repo_commits(self, repo_id, repo_type):
        """Return a new list of commits, most recently recorded first."""
        return self.commits[::-1]


class HfStorageManagerTestCase(unittest.TestCase):
    """Exercise ``HuggingFaceStorageManager`` against an in-memory ``FakeHfApi``.

    ``server.storage.hf.HfApi`` is patched so that it returns the fake, and the
    factory's cached ``_storage_instance`` is reset around every test.
    """

    def setUp(self):
        self.api = FakeHfApi()
        self.hf_patcher = patch("server.storage.hf.HfApi", return_value=self.api)
        self.hf_patcher.start()
        # Register the undo step right after start(): cleanups still run when
        # a later line of setUp raises, whereas tearDown would be skipped and
        # the patch would leak into subsequent tests.
        self.addCleanup(self.hf_patcher.stop)
        factory._storage_instance = None
        # Cleanups run last-in-first-out, so the cached instance is cleared
        # before the patch is removed.
        self.addCleanup(setattr, factory, "_storage_instance", None)
        self.storage = HuggingFaceStorageManager()

    def test_create_folder_uses_gitkeep_marker(self):
        # Creating a folder should store a ``.gitkeep`` entry under it.
        result = self.storage.create_folder("user1", "Projects")
        self.assertTrue(result["success"])
        self.assertIn("user1/Projects/.gitkeep", self.api.files)

    def test_rename_file(self):
        # After a rename only the new path may remain in the fake repo.
        self.storage.upload_file("user1", "Projects", "a.txt", b"hello")
        result = self.storage.rename_file("user1", "Projects/a.txt", "b.txt")
        self.assertTrue(result["success"])
        self.assertIn("user1/Projects/b.txt", self.api.files)
        self.assertNotIn("user1/Projects/a.txt", self.api.files)

    def test_delete_folder_removes_nested_files(self):
        # Deleting a folder must also remove files stored in its subfolders.
        self.storage.upload_file("user1", "Projects/Sub", "a.txt", b"hello")
        result = self.storage.delete_folder("user1", "Projects")
        self.assertTrue(result["success"])
        self.assertFalse(any(path.startswith("user1/Projects/") for path in self.api.files))

    def test_local_mode_is_rejected_by_factory(self):
        # Reset again here: setUp built a manager after its own reset.
        factory._storage_instance = None
        with patch("server.storage.factory.config.STORAGE_MODE", "LOCAL"), \
                self.assertRaises(RuntimeError):
            factory.get_storage()


# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()