File size: 4,404 Bytes
63be620
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
#!/usr/bin/env python3
"""
Flask web server to visualize inference results for 2-subject cases.
Port: 7023
"""

import os
import json
from pathlib import Path
from flask import Flask, render_template, send_from_directory
import base64

app = Flask(__name__)

# Paths
DATASET_FILE = "/archive/vaibhav.agrawal/a-bev-of-the-latents/datasetv7_superhard_eval/cuboids_segmentation.jsonl"
DATASET_ROOT = "/archive/vaibhav.agrawal/a-bev-of-the-latents/datasetv7_superhard_eval"
RESULTS_DIR = "/archive/vaibhav.agrawal/a-bev-of-the-latents/VAL/results/omini_seg_baseline_r2_epoch-0_checkpoint-20000"

def load_2_subject_cases():
    """Load all 2-subject cases from the dataset."""
    cases = []
    with open(DATASET_FILE, 'r') as f:
        for idx, line in enumerate(f):
            data = json.loads(line)
            if len(data['subjects']) == 2:
                cases.append({
                    'dataset_index': idx,
                    'subjects': data['subjects'],
                    'prompt': data['prompt'],
                    'target': data['target'],
                    'cv': data['cv']
                })
    return cases

# Load cases on startup
TWO_SUBJECT_CASES = load_2_subject_cases()
print(f"Loaded {len(TWO_SUBJECT_CASES)} 2-subject cases")

def get_image_path(case, image_type):
    """Get the path for a specific image type."""
    if image_type == 'ground_truth':
        return os.path.join(DATASET_ROOT, case['target'])
    elif image_type == 'segmentation':
        return os.path.join(DATASET_ROOT, case['cv'])
    elif image_type == 'generated':
        # Find the generated image in results
        viz_dir = os.path.join(RESULTS_DIR, 'generated_images')
        # Pattern: sample_{sample_idx:04d}_idx_{dataset_index}_seed_{seed}.jpg
        # We need to find the file that matches the dataset_index
        if os.path.exists(viz_dir):
            for filename in os.listdir(viz_dir):
                if f"_idx_{case['dataset_index']}_" in filename:
                    return os.path.join(viz_dir, filename)
    return None

@app.route('/')
def index():
    """Main page showing the first 2-subject case."""
    return show_case(0)

@app.route('/case/<int:case_idx>')
def show_case(case_idx):
    """Display a specific case."""
    if case_idx < 0 or case_idx >= len(TWO_SUBJECT_CASES):
        return "Case not found", 404
    
    case = TWO_SUBJECT_CASES[case_idx]
    
    # Get image paths
    gt_path = get_image_path(case, 'ground_truth')
    seg_path = get_image_path(case, 'segmentation')
    gen_path = get_image_path(case, 'generated')
    
    # Check if files exist
    gt_exists = os.path.exists(gt_path) if gt_path else False
    seg_exists = os.path.exists(seg_path) if seg_path else False
    gen_exists = os.path.exists(gen_path) if gen_path else False
    
    return render_template('viewer.html',
                         case_idx=case_idx,
                         total_cases=len(TWO_SUBJECT_CASES),
                         subjects=', '.join(case['subjects']),
                         prompt=case['prompt'].replace('PLACEHOLDER', ', '.join(case['subjects'])),
                         dataset_index=case['dataset_index'],
                         gt_exists=gt_exists,
                         seg_exists=seg_exists,
                         gen_exists=gen_exists,
                         prev_idx=case_idx - 1 if case_idx > 0 else None,
                         next_idx=case_idx + 1 if case_idx < len(TWO_SUBJECT_CASES) - 1 else None)

@app.route('/image/<int:case_idx>/<image_type>')
def serve_image(case_idx, image_type):
    """Serve the requested image."""
    if case_idx < 0 or case_idx >= len(TWO_SUBJECT_CASES):
        return "Case not found", 404
    
    case = TWO_SUBJECT_CASES[case_idx]
    image_path = get_image_path(case, image_type)
    
    if image_path and os.path.exists(image_path):
        directory = os.path.dirname(image_path)
        filename = os.path.basename(image_path)
        return send_from_directory(directory, filename)
    else:
        return "Image not found", 404

if __name__ == '__main__':
    # Create templates directory if it doesn't exist
    os.makedirs('templates', exist_ok=True)
    
    # Run server on all interfaces (0.0.0.0) for remote access
    print(f"Starting server on port 7023...")
    print(f"Access at: http://<your-host-ip>:7023")
    app.run(host='0.0.0.0', port=7023, debug=True)