File size: 11,771 Bytes
8eab558
623a404
8eab558
 
623a404
8eab558
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
623a404
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8eab558
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
623a404
 
 
 
 
 
8eab558
 
 
 
623a404
 
 
 
 
8eab558
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
import os
import sys
import zipfile
import pandas as pd
import numpy as np
from flask import Flask, request, redirect, url_for, send_from_directory, flash, render_template
from werkzeug.utils import secure_filename
from tqdm import tqdm
from sklearn.metrics import classification_report, precision_recall_fscore_support
from inference_utils import DiamondInference
from dotenv import load_dotenv

# Load local environment variables from .env
load_dotenv()

app = Flask(__name__)
# Read the session-signing key from the environment when available; the old
# hardcoded value remains the fallback so existing deployments keep working.
# NOTE(review): a publicly-known secret key lets anyone forge session cookies —
# set SECRET_KEY in production.
app.secret_key = os.getenv("SECRET_KEY", "supersecretkey")

# Hugging Face Hub Integration
HF_REPO_ID = os.getenv("HF_REPO_ID", "WebashalarForML/Diamcol")
HF_TOKEN = os.getenv("HF_TOKEN")

# Model Configuration
MODEL_ID = "322c4f4d"
MODEL_NAME = f"model_vit_robust_{MODEL_ID}.keras"

def download_model_from_hf():
    """Fetch the Keras model and encoder artifacts from the HF Hub if missing locally.

    Downloads are skipped for any file that already exists on disk, so the
    function is safe to call on every startup.
    """
    from huggingface_hub import hf_hub_download

    print("[INFO] Checking model files from Hugging Face...")

    # Main model weights file.
    if not os.path.exists(MODEL_NAME):
        print(f"[INFO] Downloading {MODEL_NAME}...")
        hf_hub_download(repo_id=HF_REPO_ID, filename=MODEL_NAME, token=HF_TOKEN, local_dir=".")

    # Encoder artifacts — names must match what inference_utils.py expects.
    artifact_names = [
        f"{stem}_{MODEL_ID}.pkl"
        for stem in ("hyperparameters", "cat_encoders", "num_scaler", "target_encoder", "norm_stats")
    ]
    os.makedirs("encoder", exist_ok=True)
    for name in artifact_names:
        if os.path.exists(os.path.join("encoder", name)):
            continue
        print(f"[INFO] Downloading {name}...")
        # Note: Assuming the structure on HF is encoder/filename
        hf_hub_download(repo_id=HF_REPO_ID, filename=f"encoder/{name}", token=HF_TOKEN, local_dir=".")

UPLOAD_FOLDER = 'uploads'
RESULTS_FOLDER = 'results'
# Extracted zip contents live under the upload folder so /flush clears them too.
EXTRACT_FOLDER = os.path.join(UPLOAD_FOLDER, 'extracted_images')

# exist_ok=True avoids the check-then-create race of a separate exists() test.
for folder in (UPLOAD_FOLDER, RESULTS_FOLDER, EXTRACT_FOLDER):
    os.makedirs(folder, exist_ok=True)

app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
app.config['MAX_CONTENT_LENGTH'] = 500 * 1024 * 1024  # 500MB max upload

# Global inference object (lazy loaded by get_inference_engine)
model_path = MODEL_NAME
encoder_dir = "encoder"
infer_engine = None

def get_inference_engine():
    """Return the process-wide DiamondInference engine, creating it on first call.

    On first use this lazily downloads model artifacts (best-effort), builds the
    engine, and runs a dummy warmup prediction so the first real request does
    not pay the TF graph-initialization cost.
    """
    global infer_engine
    if infer_engine is None:
        # Try downloading if missing (for Docker/HF Spaces environment)
        try:
            download_model_from_hf()
        except Exception as e:
            print(f"[WARNING] Could not download from HF: {e}. Expecting local files.")
        
        infer_engine = DiamondInference(model_path, encoder_dir, MODEL_ID)
        
        # Warmup prediction to initialize TF graph and prevent "stuck" feeling on first stone
        print("[INFO] Warming up Inference Engine...")
        try:
            # Create a dummy row and zero patches for warmup
            dummy_row = {"StoneType": "NATURAL", "Color": "D", "Brown": "N", "BlueUv": "N", "GrdType": "GIA", "Carat": 1.0, "Result": "D"}
            # We don't need a real image for warmup, just a pass through predict
            # We'll mock process_image to return zeros; the original method is
            # restored in the inner finally even if predict() raises.
            orig_process = infer_engine.process_image
            try:
                infer_engine.process_image = lambda path, tta_transform=None: np.zeros(infer_engine.hp["flat_patches_shape"], dtype=np.float32)
                infer_engine.predict(dummy_row, "warmup.jpg", use_tta=False)
            finally:
                infer_engine.process_image = orig_process
            print("[INFO] Warmup complete.")
        except Exception as e:
            # Warmup is best-effort; a failure here must not block serving.
            print(f"[WARNING] Warmup failed: {e}")
            
    return infer_engine

@app.route('/flush', methods=['POST'])
def flush_data():
    """Delete everything inside the uploads and results folders, then redirect home.

    Per-entry deletion failures are logged and skipped; only an unexpected
    top-level error is reported to the user via flash().
    """
    import shutil

    def _clear_folder(folder):
        # Remove every file/symlink/subdirectory inside *folder* (folder kept).
        for filename in os.listdir(folder):
            file_path = os.path.join(folder, filename)
            try:
                if os.path.isfile(file_path) or os.path.islink(file_path):
                    os.unlink(file_path)
                elif os.path.isdir(file_path):
                    shutil.rmtree(file_path)
            except Exception as e:
                print(f'Failed to delete {file_path}. Reason: {e}')

    try:
        _clear_folder(UPLOAD_FOLDER)
        # EXTRACT_FOLDER is a sub-dir of UPLOAD_FOLDER and may have just been
        # deleted; re-create it so later uploads have somewhere to extract to.
        os.makedirs(EXTRACT_FOLDER, exist_ok=True)
        _clear_folder(RESULTS_FOLDER)
        flash('All data flushed successfully.')
    except Exception as e:
        flash(f'Error during flushing: {e}')

    return redirect(url_for('index'))

@app.route('/')
def index():
    """Render the landing page with the upload form."""
    return render_template('index.html')

def _find_stone_image(all_files, l_code, sr_no, stone_id):
    """Return the first cached image path matching this stone, or None.

    Primary match: both L_Code and SrNo appear in the basename. Fallback:
    Stone_Id appears in the basename.
    """
    # Guard: '' is a substring of every name, so empty codes would otherwise
    # silently match the first arbitrary image in the archive.
    if l_code and sr_no:
        for full_path in all_files:
            fname = os.path.basename(full_path)
            if l_code in fname and sr_no in fname:
                return full_path
    if stone_id and stone_id != 'nan':
        for full_path in all_files:
            # BUG FIX: original called os.basename (nonexistent), raising
            # AttributeError whenever this Stone_Id fallback was reached.
            if stone_id in os.path.basename(full_path):
                return full_path
    return None


def _build_metrics(y_true, y_pred):
    """Compute the metrics dict (accuracy, per-class P/R/F1, confusion matrix).

    Returns None when no ground truth was collected.
    """
    if not y_true:
        return None
    from sklearn.metrics import confusion_matrix

    report_dict = classification_report(y_true, y_pred, output_dict=True, zero_division=0)
    labels = sorted(set(y_true) | set(y_pred))
    cm = confusion_matrix(y_true, y_pred, labels=labels)

    # Per-class rows only; the aggregate entries are surfaced separately below.
    class_metrics = []
    for label, scores in report_dict.items():
        if label in ('accuracy', 'macro avg', 'weighted avg'):
            continue
        class_metrics.append({
            'label': label,
            'precision': round(scores['precision'], 4),
            'recall': round(scores['recall'], 4),
            'f1': round(scores['f1-score'], 4),
            'support': scores['support'],
        })

    return {
        'accuracy': round(report_dict['accuracy'], 4),
        'class_metrics': class_metrics,
        'weighted_avg': report_dict['weighted avg'],
        'macro_avg': report_dict['macro avg'],
        'precision': round(report_dict['weighted avg']['precision'], 4),
        'recall': round(report_dict['weighted avg']['recall'], 4),
        'f1': round(report_dict['weighted avg']['f1-score'], 4),
        'macro_f1': round(report_dict['macro avg']['f1-score'], 4),
        'macro_precision': round(report_dict['macro avg']['precision'], 4),
        'macro_recall': round(report_dict['macro avg']['recall'], 4),
        'confusion_matrix': {
            'labels': labels,
            'matrix': cm.tolist(),
        },
    }


@app.route('/upload', methods=['POST'])
def upload_files():
    """Handle the zip+excel upload, run per-stone inference, and render the report.

    Expects two multipart files: 'zip_file' (stone images) and 'excel_file'
    (stone metadata). Writes an annotated Excel report into RESULTS_FOLDER and
    renders report.html with predictions plus aggregate metrics whenever
    ground truth (FGrdCol) is present in the spreadsheet.
    """
    if 'zip_file' not in request.files or 'excel_file' not in request.files:
        flash('Both Zip and Excel files are required.')
        return redirect(request.url)

    zip_file = request.files['zip_file']
    excel_file = request.files['excel_file']

    if zip_file.filename == '' or excel_file.filename == '':
        flash('No selected file')
        return redirect(request.url)

    # Save and extract the image archive.
    zip_path = os.path.join(app.config['UPLOAD_FOLDER'], secure_filename(zip_file.filename))
    zip_file.save(zip_path)
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(EXTRACT_FOLDER)

    # Save and load the metadata spreadsheet.
    excel_path = os.path.join(app.config['UPLOAD_FOLDER'], secure_filename(excel_file.filename))
    excel_file.save(excel_path)
    df = pd.read_excel(excel_path)

    # Inference Logic
    engine = get_inference_engine()

    # Pre-cache all image paths so each stone is a list scan, not an os.walk.
    all_extracted_files = []
    for root, dirs, files in os.walk(EXTRACT_FOLDER):
        for f in files:
            if f.lower().endswith(('.jpg', '.jpeg', '.png')):
                all_extracted_files.append(os.path.join(root, f))

    print(f"[INFO] Found {len(all_extracted_files)} images in extraction folder.")

    # Identifying ground truth for metrics
    y_true = []
    y_pred = []

    print(f"[INFO] Initializing Inference Pipeline for {len(df)} stones...")
    sys.stdout.flush()

    # Progress bar with direct stdout for Gunicorn visibility
    pbar = tqdm(df.iterrows(), total=len(df), desc="Inference Progress", file=sys.stdout)

    for index, row in pbar:
        l_code = str(row.get('L_Code', '')).split('.')[0]
        sr_no = str(row.get('SrNo', '')).split('.')[0]
        stone_id = str(row.get('Stone_Id', ''))

        # Log currently processing stone for "aliveness" verification
        if index % 5 == 0:
            print(f"[PROC] Stone {index+1}/{len(df)}: {l_code}")
            sys.stdout.flush()

        img_path = _find_stone_image(all_extracted_files, l_code, sr_no, stone_id)

        if img_path:
            prediction = engine.predict(row, img_path)
            # Store filename relative to EXTRACT_FOLDER for web serving
            web_path = os.path.relpath(img_path, start=EXTRACT_FOLDER)
            df.at[index, 'Predicted_FGrdCol'] = prediction
            df.at[index, 'Image_Path'] = web_path

            # If ground truth exists, collect it
            if 'FGrdCol' in row and pd.notna(row['FGrdCol']):
                y_true.append(str(row['FGrdCol']))
                y_pred.append(str(prediction))
        else:
            df.at[index, 'Predicted_FGrdCol'] = "Image Not Found"
            df.at[index, 'Image_Path'] = "N/A"

    # Calculate metrics if ground truth is available.
    metrics = _build_metrics(y_true, y_pred)

    # Model parameters (features used for prediction)
    model_features = ["StoneType", "Color", "Brown", "BlueUv", "GrdType", "Carat", "Result"]

    # Identify "out of box" features - only if they actually contain data
    potential_oob = ['FancyYellow', 'Type2A', 'YellowUv']
    out_of_box_cols = []
    for col in potential_oob:
        if col in df.columns:
            # Check if there is at least one non-null/non-empty value
            if df[col].dropna().astype(str).str.strip().replace(['nan', 'None', ''], pd.NA).notna().any():
                out_of_box_cols.append(col)

    output_filename = f"report_{secure_filename(excel_file.filename)}"
    output_path = os.path.join(RESULTS_FOLDER, output_filename)
    df.to_excel(output_path, index=False)

    return render_template('report.html',
                           report_data=df.to_dict(orient='records'),
                           report_file=output_filename,
                           out_of_box_cols=out_of_box_cols,
                           model_features=model_features,
                           metrics=metrics)

@app.route('/download/<filename>')
def download_file(filename):
    """Serve a previously generated Excel report from the results folder."""
    return send_from_directory(RESULTS_FOLDER, filename)

@app.route('/image/<path:filename>')
def serve_image(filename):
    """Serve an extracted stone image (path relative to EXTRACT_FOLDER)."""
    return send_from_directory(EXTRACT_FOLDER, filename)

if __name__ == '__main__':
    # Development entry point only — the debug server/reloader must not be
    # used in production (run under Gunicorn or similar instead).
    app.run(debug=True)