# BarkID / app.py
# Author: ThatHungarian — "Update app.py", commit aefe2aa (verified)
import torch
import torch.nn as nn
from torchvision import transforms, models
from PIL import Image
import io
import os
from flask import Flask, request, jsonify, render_template_string
from flask_cors import CORS
import logging

# Module-level logger for startup and request diagnostics.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Flask app with CORS enabled so browser clients on other origins can POST.
app = Flask(__name__)
CORS(app)
# Reject request bodies larger than 16 MB (uploaded photos).
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024

# Checkpoint file expected next to this script.
MODEL_PATH = 'Model-79.85.pth'
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Square input side length for the classifier.
IMAGE_SIZE = 224
PORT = int(os.environ.get("PORT", 7860)) # HF Spaces uses 7860
def load_model():
    """Load the MobileNetV3-Large checkpoint and its class-index mapping.

    Returns:
        (model, idx_to_class): the model in eval mode on DEVICE, and a dict
        mapping class index (int) -> breed name (str). On any failure the
        function returns (None, {}) so the app can still start and report
        its degraded state via '/' and '/health'.
    """
    if not os.path.exists(MODEL_PATH):
        logger.error(f"ERROR: Model file '{MODEL_PATH}' not found!")
        return None, {}
    try:
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # checkpoints from a trusted source.
        checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
        # Backbone created without pretrained weights; the checkpoint
        # supplies every parameter.
        model = models.mobilenet_v3_large(weights=None)
        # MobileNetV3-Large's feature extractor outputs 960 channels.
        in_features = 960
        # Custom classifier head: 960 -> 1280 -> 640 -> 120 breed logits.
        model.classifier = nn.Sequential(
            nn.Linear(in_features, 1280),
            nn.Hardswish(inplace=True),
            nn.Dropout(0.3, inplace=True),
            nn.Linear(1280, 640),
            nn.Hardswish(inplace=True),
            nn.Dropout(0.2, inplace=True),
            nn.Linear(640, 120)
        )
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(DEVICE)
        model.eval()
        # The checkpoint may store the mapping with string keys; normalize
        # them to ints so positional lookups work.
        idx_to_class = checkpoint.get('idx_to_class', {})
        class_names = {int(k): v for k, v in idx_to_class.items()}
        # Older checkpoints used 'best_acc'; newer ones use 'best_top1'.
        val_acc = checkpoint.get('best_top1', checkpoint.get('best_acc', 'N/A'))
        if isinstance(val_acc, (int, float)):
            logger.info(f"Model loaded! Classes: {len(class_names)}, Best Val Acc: {val_acc:.2f}%")
        else:
            logger.info(f"Model loaded! Classes: {len(class_names)}")
        return model, class_names
    except Exception:
        # logger.exception records the message plus the full traceback —
        # replaces the previous manual `import traceback; format_exc()`.
        logger.exception("Error loading model")
        return None, {}
# Load once at import time; IDX_TO_CLASS maps class index -> breed name.
model, IDX_TO_CLASS = load_model()
# Dense list of breed names ordered by class index (empty if loading failed).
CLASSES = [IDX_TO_CLASS[i] for i in range(len(IDX_TO_CLASS))] if IDX_TO_CLASS else []
# Inference preprocessing: resize to the model's square input size and
# normalize with the standard ImageNet channel statistics.
transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
def predict_with_tta(image, model, device, image_size=224):
    """Predict the class of a PIL image using horizontal-flip TTA.

    Runs the model on the preprocessed image and on its horizontal mirror,
    averages the two softmax distributions, and returns the top prediction.

    Args:
        image: PIL.Image in RGB mode.
        model: classifier returning raw logits of shape (1, num_classes).
        device: torch.device to run inference on.
        image_size: square side length the image is resized to. Defaults to
            224, matching the module-level IMAGE_SIZE, so existing callers
            are unaffected.

    Returns:
        (predicted_index, confidence): int class index and float probability.
    """
    model.eval()
    # Same pipeline as the module-level `transform`, built locally so the
    # chosen image_size is honored.
    preprocess = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    orig_tensor = preprocess(image).unsqueeze(0).to(device)
    # Horizontal flip commutes with per-channel normalization, so flipping
    # the preprocessed tensor (dim 3 = width) is equivalent to flipping the
    # PIL image first — and avoids constructing and running a second full
    # transform pipeline on every call, as the previous version did.
    flip_tensor = torch.flip(orig_tensor, dims=[3])
    with torch.no_grad():
        probs_orig = torch.nn.functional.softmax(model(orig_tensor), dim=1)
        probs_flip = torch.nn.functional.softmax(model(flip_tensor), dim=1)
        probs = (probs_orig + probs_flip) / 2
        confidence, predicted = torch.max(probs, 1)
    return predicted.item(), float(confidence.item())
# Single-page UI served at '/'. Rendered with Jinja via
# render_template_string; `model_loaded` toggles the status badge at the
# top. The client-side JS POSTs the chosen image to /predict and renders
# the top-1 prediction plus the runner-up breeds from `top3`.
HTML_TEMPLATE = '''
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0, user-scalable=no">
<meta name="mobile-web-app-capable" content="yes">
<title>Dog Breed Classifier</title>
<style>
* { box-sizing: border-box; -webkit-tap-highlight-color: transparent; }
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Helvetica, Arial, sans-serif;
margin: 0;
padding: 15px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
min-height: 100vh;
touch-action: manipulation;
}
.container {
max-width: 600px;
margin: 0 auto;
background: white;
border-radius: 20px;
padding: 25px;
box-shadow: 0 20px 60px rgba(0,0,0,0.3);
}
h1 {
text-align: center;
color: #333;
margin: 0 0 10px 0;
font-size: 28px;
}
.subtitle {
text-align: center;
color: #666;
margin-bottom: 25px;
font-size: 16px;
}
.model-badge {
text-align: center;
background: #e3f2fd;
padding: 10px;
border-radius: 8px;
margin-bottom: 20px;
font-size: 14px;
color: #1976d2;
font-weight: 600;
}
.button-grid {
display: grid;
gap: 15px;
margin-bottom: 20px;
}
.btn {
background: #4CAF50;
color: white;
border: none;
padding: 18px;
border-radius: 12px;
font-size: 18px;
font-weight: 600;
cursor: pointer;
width: 100%;
transition: transform 0.2s, box-shadow 0.2s;
box-shadow: 0 4px 6px rgba(0,0,0,0.1);
display: block;
text-align: center;
text-decoration: none;
-webkit-touch-callout: none;
-webkit-user-select: none;
user-select: none;
}
.btn:active {
transform: translateY(2px) scale(0.98);
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
.btn-camera { background: #2196F3; }
.btn-gallery { background: #FF9800; }
input[type="file"] {
position: absolute;
opacity: 0;
pointer-events: none;
width: 0;
height: 0;
}
.preview-container {
margin-top: 20px;
text-align: center;
display: none;
}
#preview {
max-width: 100%;
max-height: 300px;
border-radius: 12px;
box-shadow: 0 4px 12px rgba(0,0,0,0.15);
}
.loading {
display: none;
text-align: center;
margin-top: 20px;
}
.spinner {
border: 5px solid #f3f3f3;
border-top: 5px solid #3498db;
border-radius: 50%;
width: 50px;
height: 50px;
animation: spin 1s linear infinite;
margin: 0 auto 15px auto;
}
@keyframes spin { 0% { transform: rotate(0deg); } 100% { transform: rotate(360deg); } }
.result {
margin-top: 20px;
padding: 25px;
background: #f8f9fa;
border-radius: 15px;
display: none;
text-align: center;
border: 2px solid #e9ecef;
}
.breed-name {
font-size: 24px;
font-weight: bold;
margin: 10px 0;
color: #333;
line-height: 1.3;
text-transform: capitalize;
}
.confidence {
font-size: 20px;
color: #666;
font-weight: 500;
}
.top-k {
margin-top: 15px;
text-align: left;
background: white;
padding: 15px;
border-radius: 8px;
border: 1px solid #e0e0e0;
}
.top-k-title {
font-size: 14px;
color: #999;
margin-bottom: 8px;
text-transform: uppercase;
letter-spacing: 0.5px;
}
.breed-item {
display: flex;
justify-content: space-between;
padding: 5px 0;
border-bottom: 1px solid #f0f0f0;
font-size: 14px;
}
.breed-item:last-child { border-bottom: none; }
.breed-prob { font-weight: 600; color: #2196F3; }
.error {
margin-top: 15px;
padding: 15px;
background: #fee;
color: #c33;
border-radius: 8px;
border-left: 4px solid #c33;
display: none;
}
.status-badge {
text-align: center;
padding: 8px;
border-radius: 6px;
margin-bottom: 15px;
font-size: 13px;
font-weight: 600;
}
.status-ok { background: #e8f5e9; color: #2e7d32; }
.status-error { background: #ffebee; color: #c62828; }
</style>
</head>
<body>
<div class="container">
<div class="status-badge {{ 'status-ok' if model_loaded else 'status-error' }}">
{{ '✓ Model Ready' if model_loaded else '✗ Model Not Loaded' }}
</div>
<h1>Dog Breed Classifier</h1>
<div class="button-grid">
<label for="cameraInput" class="btn btn-camera">
📷 Take Photo
</label>
<label for="galleryInput" class="btn btn-gallery">
🖼️ Upload from Gallery
</label>
</div>
<input type="file" id="cameraInput" accept="image/*" capture="environment" onchange="handleFile(this)">
<input type="file" id="galleryInput" accept="image/*" onchange="handleFile(this)">
<div class="preview-container" id="previewContainer">
<img id="preview" alt="Selected dog">
</div>
<div class="loading" id="loading">
<div class="spinner"></div>
<div>Identifying breed...</div>
</div>
<div class="result" id="result">
<div class="breed-name" id="breedName"></div>
<div class="confidence" id="confidence"></div>
<div class="top-k" id="topK" style="display: none;">
<div class="top-k-title">Also possible:</div>
<div id="topKList"></div>
</div>
</div>
<div class="error" id="error"></div>
</div>
<script>
function handleFile(input) {
const file = input.files[0];
if (!file) return;
document.getElementById('error').style.display = 'none';
document.getElementById('result').style.display = 'none';
const reader = new FileReader();
reader.onload = function(e) {
const preview = document.getElementById('preview');
preview.src = e.target.result;
document.getElementById('previewContainer').style.display = 'block';
uploadImage(file);
};
reader.readAsDataURL(file);
input.value = '';
}
function uploadImage(file) {
document.getElementById('loading').style.display = 'block';
const formData = new FormData();
formData.append('image', file);
fetch('/predict', {
method: 'POST',
body: formData,
headers: {
'Accept': 'application/json',
}
})
.then(response => {
if (!response.ok) throw new Error('Server error: ' + response.status);
return response.json();
})
.then(data => {
document.getElementById('loading').style.display = 'none';
if (data.error) {
showError(data.error);
} else {
showResult(data);
}
})
.catch(error => {
document.getElementById('loading').style.display = 'none';
console.error("Upload error:", error);
showError("Failed to analyze: " + error.message);
});
}
function showResult(data) {
document.getElementById('breedName').textContent = data.class.replace(/_/g, ' ');
document.getElementById('confidence').textContent =
`Confidence: ${(data.confidence * 100).toFixed(1)}%`;
document.getElementById('result').style.display = 'block';
if (data.top3 && data.top3.length > 1) {
const list = document.getElementById('topKList');
list.innerHTML = '';
data.top3.slice(1).forEach(item => {
const div = document.createElement('div');
div.className = 'breed-item';
div.innerHTML = `<span>${item.class.replace(/_/g, ' ')}</span><span class="breed-prob">${(item.confidence * 100).toFixed(1)}%</span>`;
list.appendChild(div);
});
document.getElementById('topK').style.display = 'block';
}
setTimeout(() => {
document.getElementById('result').scrollIntoView({ behavior: 'smooth', block: 'nearest' });
}, 100);
}
function showError(msg) {
const errorDiv = document.getElementById('error');
errorDiv.textContent = msg;
errorDiv.style.display = 'block';
errorDiv.scrollIntoView({ behavior: 'smooth', block: 'nearest' });
}
</script>
</body>
</html>
'''
@app.route('/')
def index():
    """Serve the single-page classifier UI with the model-status badge."""
    is_ready = model is not None
    return render_template_string(HTML_TEMPLATE, model_loaded=is_ready)
@app.route('/predict', methods=['POST'])
def predict():
    """Classify an uploaded dog photo.

    Expects a multipart form with an 'image' file field. Responds with JSON
    {'class', 'confidence', 'top3', 'success'} on success, or
    {'error': ...} with an appropriate 4xx/5xx status on failure.
    """
    if model is None:
        return jsonify({'error': 'Model not loaded'}), 500
    if 'image' not in request.files:
        return jsonify({'error': 'No image provided'}), 400
    file = request.files['image']
    if file.filename == '':
        return jsonify({'error': 'No file selected'}), 400
    try:
        image_bytes = file.read()
        image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        # Run the model on the image and its horizontal mirror and average
        # the softmax outputs (the same TTA scheme as predict_with_tta).
        # Deriving BOTH the headline prediction and top3 from this single
        # averaged distribution fixes a bug in the previous version, where
        # the headline used TTA but top3 came from a separate non-TTA
        # forward pass, so the two could disagree; it also saves one
        # redundant forward pass per request.
        orig_tensor = transform(image).unsqueeze(0).to(DEVICE)
        # Per-channel normalization commutes with a spatial flip.
        flip_tensor = torch.flip(orig_tensor, dims=[3])
        with torch.no_grad():
            probs = (torch.nn.functional.softmax(model(orig_tensor), dim=1) +
                     torch.nn.functional.softmax(model(flip_tensor), dim=1)) / 2
            # Guard against models with fewer than 3 output classes.
            k = min(3, probs.size(1))
            top_probs, top_indices = torch.topk(probs, k=k, dim=1)
        top3 = [
            {'class': IDX_TO_CLASS.get(top_indices[0][i].item(), "Unknown"),
             'confidence': top_probs[0][i].item()}
            for i in range(k)
        ]
        return jsonify({
            'class': top3[0]['class'],
            'confidence': top3[0]['confidence'],
            'top3': top3,
            'success': True
        })
    except Exception as e:
        # logger.exception logs the traceback; the client gets the message.
        logger.exception("Prediction error")
        return jsonify({'error': str(e)}), 500
@app.route('/health', methods=['GET'])
def health():
    """Report service liveness and basic model metadata as JSON."""
    payload = {
        'status': 'healthy',
        'model_loaded': model is not None,
        'model_type': 'MobileNetV3-Large',
        'num_classes': len(CLASSES),
        'input_size': IMAGE_SIZE,
        'device': str(DEVICE),
    }
    return jsonify(payload)
# Bind on all interfaces so the Spaces proxy can reach the container.
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=PORT, debug=False)