File size: 7,347 Bytes
0b86477
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
"""
Test script for MedSAM HuggingFace Space
Run this after deploying your Space to verify it works
"""

import requests
import json
import base64
import numpy as np
from PIL import Image
from io import BytesIO
import sys

# UPDATE THIS after deploying your Space!
# Prediction endpoint of the deployed Space; the test POSTs {"data": [...]} here.
# While the "YOUR_USERNAME" placeholder is still present, the __main__ block
# prints an error and exits with status 1. check_space_status() derives the
# Space homepage URL by stripping "/api/predict" from this value.
SPACE_URL = "https://YOUR_USERNAME-medsam-inference.hf.space/api/predict"


def _parse_space_output(result):
    """Extract the payload from a Gradio-style ``{"data": [payload, ...]}`` reply.

    Returns the payload as a dict, or ``None`` when *result* does not have the
    expected shape. Raises ``json.JSONDecodeError`` if the payload is a string
    that is not valid JSON.
    """
    if not isinstance(result, dict) or not result.get("data"):
        return None
    payload = result["data"][0]
    # A Textbox-style output arrives as a JSON string; a JSON-style output
    # arrives already decoded. Accept both.
    if isinstance(payload, str):
        payload = json.loads(payload)
    return payload if isinstance(payload, dict) else None


def _save_overlay(image, mask, output_path):
    """Save *image* with the pixels selected by boolean *mask* tinted red.

    Masked pixels become 60% original colour + 40% pure red; all other pixels
    are unchanged. *mask* must have shape ``(image.height, image.width)``.
    """
    # convert("RGB") yields an (H, W, 3) uint8 array for modes such as L, P, LA
    # and RGBA. The raw array of an RGBA/LA image has 4/2 channels and would
    # not broadcast against the 3-channel colour. A P image's raw array holds
    # palette indices, not intensities.
    rgb = np.array(image.convert("RGB"))
    red = np.array([255, 0, 0])
    blended = rgb.copy()
    blended[mask] = (0.6 * rgb[mask] + 0.4 * red).astype(np.uint8)
    Image.fromarray(blended).save(output_path)


def test_space_with_image(image_path: str, x: int, y: int):
    """
    Send one image and one foreground point to the MedSAM Space and check the reply.

    On success the highest-scoring mask is written to the current directory as
    ``test_result_visualization.png`` (red overlay on the input image) and
    ``test_result_mask.png`` (binary mask).

    Args:
        image_path: Path to test image
        x: X coordinate for segmentation point
        y: Y coordinate for segmentation point

    Returns:
        True if the Space returned at least one mask matching the image size
        and both output files were saved; False otherwise (the reason is printed).
    """
    print(f"🧪 Testing MedSAM Space: {SPACE_URL}")
    print(f"   Image: {image_path}")
    print(f"   Point: ({x}, {y})")
    print()

    try:
        # 1. Load and encode image
        print("📸 Loading image...")
        with open(image_path, "rb") as f:
            image_bytes = f.read()

        image = Image.open(BytesIO(image_bytes))
        print(f"   Size: {image.size}")
        print(f"   Mode: {image.mode}")

        # The file's bytes are sent unchanged, so the data URI must carry the
        # file's real MIME type; a PNG labelled image/jpeg misdescribes the
        # payload. Fall back to image/jpeg when PIL knows no MIME type.
        mime_type = Image.MIME.get(image.format or "", "image/jpeg")
        img_base64 = base64.b64encode(image_bytes).decode()
        print(f"   Base64 size: {len(img_base64)} chars")
        print()

        # 2. Prepare points JSON
        print("📍 Preparing points...")
        points_json = json.dumps({
            "coords": [[x, y]],
            "labels": [1],  # 1 = foreground
            "multimask_output": True
        })
        print(f"   Points JSON: {points_json}")
        print()

        # 3. Call API
        print("🚀 Calling Space API...")
        response = requests.post(
            SPACE_URL,
            json={
                "data": [
                    f"data:{mime_type};base64,{img_base64}",
                    points_json,
                ]
            },
            timeout=120,
        )

        print(f"   Status code: {response.status_code}")

        if response.status_code != 200:
            print(f"❌ Error: {response.status_code}")
            print(f"   Response: {response.text}")
            return False

        print()

        # 4. Parse result
        print("📊 Parsing result...")
        result = response.json()

        output = _parse_space_output(result)
        if output is None:
            print("❌ Error: Unexpected response format")
            print(f"   Response: {json.dumps(result, indent=2)}")
            return False

        if not output.get("success", False):
            print(f"❌ Error: {output.get('error', 'Unknown error')}")
            return False

        # Convert every mask once. zip() pairs masks with scores and drops
        # any surplus on the longer side.
        pairs = [
            (np.array(mask_entry["mask_data"], dtype=bool), score)
            for mask_entry, score in zip(output.get("masks") or [], output.get("scores") or [])
        ]
        if not pairs:
            print("❌ Error: Space reported success but returned no masks/scores")
            return False

        print("✅ Success!")
        print(f"   Number of masks: {output.get('num_masks', len(pairs))}")
        print(f"   Scores: {output['scores']}")
        print()

        # 5. Process masks
        print("🎭 Processing masks...")
        for i, (mask, score) in enumerate(pairs, start=1):
            covered = int(mask.sum())
            coverage = 100 * covered / mask.size if mask.size else 0.0
            print(f"   Mask {i}:")
            print(f"      Shape: {mask.shape}")
            print(f"      Score: {score:.4f}")
            print(f"      Pixels: {covered} / {mask.size}")
            print(f"      Coverage: {coverage:.2f}%")

        # 6. Get best mask (first one wins on ties, as with np.argmax)
        best_idx = int(np.argmax([score for _, score in pairs]))
        best_mask, best_score = pairs[best_idx]

        print()
        print(f"🏆 Best mask: #{best_idx+1} (score: {best_score:.4f})")
        print()

        # The overlay indexes the image with the mask, so their shapes must agree.
        expected_shape = (image.height, image.width)
        if best_mask.shape != expected_shape:
            print(f"❌ Error: mask shape {best_mask.shape} does not match "
                  f"image shape {expected_shape}")
            return False

        # 7. Save visualization
        print("💾 Saving visualization...")

        output_path = "test_result_visualization.png"
        _save_overlay(image, best_mask, output_path)
        print(f"   Saved: {output_path}")

        # Save mask only (0 = background, 255 = foreground)
        mask_path = "test_result_mask.png"
        Image.fromarray(best_mask.astype(np.uint8) * 255).save(mask_path)
        print(f"   Saved: {mask_path}")
        print()

        print("=" * 60)
        print("✅ TEST PASSED! Your Space is working correctly!")
        print("=" * 60)

        return True

    except requests.exceptions.Timeout:
        print("❌ Error: Request timeout (>120 seconds)")
        print("   The Space might be sleeping or overloaded")
        print("   Try again in 30 seconds")
        return False

    except requests.exceptions.RequestException as e:
        print(f"❌ Error: Request failed: {e}")
        return False

    except Exception as e:
        # Top-level boundary of the test: report anything unexpected with a
        # traceback and mark the test as failed.
        print(f"❌ Error: {e}")
        import traceback
        traceback.print_exc()
        return False


def check_space_status(space_url: str):
    """Check whether the Space's homepage is reachable.

    The homepage URL is derived by removing "/api/predict" from *space_url*.

    Args:
        space_url: The Space's prediction endpoint URL.

    Returns:
        True if the homepage answered HTTP 200 within 10 seconds, False
        otherwise (the reason is printed).
    """
    print(f"🔍 Checking Space status: {space_url}")

    # Try to access the Space homepage
    homepage_url = space_url.replace("/api/predict", "")
    try:
        response = requests.get(homepage_url, timeout=10)
    except requests.exceptions.RequestException as e:
        print(f"❌ Cannot reach Space: {e}")
        print("   Make sure you've deployed the Space and updated SPACE_URL")
        return False

    if response.status_code == 200:
        print("✅ Space is online!")
        return True

    print(f"⚠️  Space returned status {response.status_code}")
    return False


def main() -> int:
    """Command-line entry point; returns the process exit status (0 = test passed).

    Usage: python test_space.py <image_path> [x] [y]
    x and y default to 200 and 150.
    """
    print("=" * 60)
    print("MedSAM HuggingFace Space Test")
    print("=" * 60)
    print()

    # Check if SPACE_URL is updated
    if "YOUR_USERNAME" in SPACE_URL:
        print("❌ Error: Please update SPACE_URL in this script!")
        print("   Replace YOUR_USERNAME with your HuggingFace username")
        print()
        print("   Example:")
        print('   SPACE_URL = "https://johndoe-medsam-inference.hf.space/api/predict"')
        return 1

    # Validate the command line before any network traffic, so usage mistakes
    # are reported immediately rather than after the status probe.
    if len(sys.argv) < 2:
        print("Usage: python test_space.py <image_path> [x] [y]")
        print()
        print("Example:")
        print("   python test_space.py test_image.jpg 200 150")
        print()
        return 1

    image_path = sys.argv[1]
    try:
        x = int(sys.argv[2]) if len(sys.argv) > 2 else 200
        y = int(sys.argv[3]) if len(sys.argv) > 3 else 150
    except ValueError:
        print("❌ Error: x and y must be integers")
        print()
        print("Usage: python test_space.py <image_path> [x] [y]")
        return 1

    # Check Space status. The result is informational only: the prediction
    # test below runs regardless.
    check_space_status(SPACE_URL)
    print()

    # Run test
    success = test_space_with_image(image_path, x, y)
    return 0 if success else 1


if __name__ == "__main__":
    sys.exit(main())