"""Student Finder: a password-protected Flask app that runs text-prompted
object detection on an uploaded photo.

The detector (a Grounding-DINO model class loaded from ``model_id``) draws
labelled bounding boxes on the image and produces one crop per detected
region; both are rendered inline as base64-encoded JPEGs.
"""

import base64
import contextlib
import datetime
import functools
import inspect
import io
import os
import secrets
import tempfile
from itertools import cycle

import torch
from flask import (
    Flask,
    redirect,
    render_template_string,
    request,
    send_file,
    session,
    url_for,
)
from PIL import Image, ImageDraw, ImageFont
from transformers import GroundingDinoProcessor

from modeling_grounding_dino import GroundingDinoForObjectDetection

app = Flask(__name__)
# NOTE: without a SECRET_KEY env var a random key is generated per process,
# so sessions are invalidated on restart and are not shared across workers.
app.secret_key = os.environ.get('SECRET_KEY', secrets.token_hex(16))
# Access password; prefer configuring it through the environment.
# NOTE(review): the hard-coded fallback keeps existing deployments working but
# exposes the password to anyone with source access - remove it once
# SECRET_PASSWORD is set in every environment.
SECRET_PASSWORD = os.environ.get('SECRET_PASSWORD', "VeronaTrento25!")
app.permanent_session_lifetime = datetime.timedelta(hours=24)

# Default detection thresholds, shared by the UI and the detection code.
DEFAULT_BOX_THRESHOLD = 0.14
DEFAULT_TEXT_THRESHOLD = 0.13


# ===== AUTHENTICATION FUNCTIONS =====

def is_authenticated():
    """Return True if the current session has passed the password check."""
    return session.get('authenticated', False)


def require_auth(f):
    """Decorate a view so unauthenticated visitors are redirected to /login."""
    @functools.wraps(f)  # keeps __name__, which Flask uses as the endpoint name
    def decorated_function(*args, **kwargs):
        if not is_authenticated():
            return redirect(url_for('login'))
        return f(*args, **kwargs)
    return decorated_function


# ===== ML MODEL SETUP =====

DEVICE = "cpu"
model_id = "fushh7/llmdet_swin_tiny_hf"

print(f"[INFO] Using device: {DEVICE}")
print(f"[INFO] Loading model from {model_id}...")
processor = GroundingDinoProcessor.from_pretrained(model_id)
model = GroundingDinoForObjectDetection.from_pretrained(model_id).to(DEVICE)
model.eval()
print("[INFO] Model loaded successfully.")

# Pre-defined palette; each distinct label takes the next colour in turn.
BOX_COLORS = [
    "deepskyblue", "red", "lime", "dodgerblue",
    "cyan", "magenta", "yellow", "orange", "chartreuse"
]


# ===== ML FUNCTIONS =====

def save_cropped_images(original_image, boxes, labels, scores):
    """Save one JPEG crop per detected box to its own temporary file.

    Args:
        original_image: PIL image the boxes refer to (pixel coordinates).
        boxes: iterable of ``(x_min, y_min, x_max, y_max)`` boxes.
        labels: per-box labels; only zipped with ``boxes`` (not written out).
        scores: per-box scores; only zipped with ``boxes`` (not written out).

    Returns:
        List of temporary file paths. The files are created with
        ``delete=False``, so the caller must remove them.
    """
    saved_paths = []
    try:
        for box, _label, _score in zip(boxes, labels, scores):
            cropped_img = original_image.crop(tuple(box))
            if cropped_img.width == 0 or cropped_img.height == 0:
                # Pillow cannot encode a zero-sized image; skip degenerate boxes.
                continue
            with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file:
                saved_paths.append(tmp_file.name)
                # Write through the open handle: re-opening the path while the
                # temp file is still open fails on Windows.
                cropped_img.save(tmp_file, format="JPEG")
    except Exception:
        # Don't leak the files already written when a later crop fails.
        for path in saved_paths:
            with contextlib.suppress(OSError):
                os.unlink(path)
        raise
    return saved_paths


def draw_boxes(image, boxes, labels, scores, colors=BOX_COLORS, font_size=16):
    """Draw labelled bounding boxes onto ``image`` in place and return it.

    Args:
        image: PIL image to draw on (modified in place).
        boxes: iterable of ``(x_min, y_min, x_max, y_max)`` boxes.
        labels: per-box label strings; each distinct label gets one colour.
        scores: per-box confidence values, shown with three decimals.
        colors: palette cycled through as new labels appear.
        font_size: point size for the TrueType font, when available.

    Returns:
        The same ``image`` object, now annotated.
    """
    colour_cycle = cycle(colors)
    draw = ImageDraw.Draw(image)
    try:
        font = ImageFont.truetype("arial.ttf", size=font_size)
    except OSError:
        # arial.ttf may not be installed (e.g. on Linux hosts); use PIL's font.
        font = ImageFont.load_default()

    label_to_colour = {}
    for box, label, score in zip(boxes, labels, scores):
        # Advance the palette only for labels not seen before, so each label
        # keeps one colour and distinct labels don't share one needlessly.
        colour = label_to_colour.get(label)
        if colour is None:
            colour = label_to_colour[label] = next(colour_cycle)

        x_min, y_min, x_max, y_max = map(int, box)
        draw.rectangle([x_min, y_min, x_max, y_max], outline=colour, width=2)

        text = f"{label} ({score:.3f})"
        left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
        text_width = right - left
        text_height = bottom - top

        # Place the label just above the box; if that would go off the top of
        # the image, place it just inside the box instead.
        label_top = y_min - text_height - 4
        if label_top < 0:
            label_top = y_min
        draw.rectangle(
            [x_min, label_top, x_min + text_width + 4, label_top + text_height + 4],
            fill=colour,
        )
        draw.text((x_min + 2, label_top + 2), text, fill="black", font=font)

    return image


def resize_image_max_dimension(image, max_size=1024):
    """Downscale ``image`` so its longer side is at most ``max_size`` pixels.

    The aspect ratio is preserved. Images already within the limit are
    returned unchanged (the same object).
    """
    width, height = image.size
    longest = max(width, height)
    if longest <= max_size:
        return image
    ratio = max_size / longest
    # Clamp to 1 px so extreme aspect ratios never yield a zero-sized image.
    new_size = (max(1, int(width * ratio)), max(1, int(height * ratio)))
    return image.resize(new_size, Image.Resampling.LANCZOS)


@functools.lru_cache(maxsize=1)
def _box_threshold_kwarg():
    """Return the keyword the installed processor uses for the box threshold.

    Recent transformers releases call it ``threshold``; older ones call it
    ``box_threshold``. Using the right name ensures the caller's value is
    applied instead of the library default.
    """
    try:
        params = inspect.signature(
            processor.post_process_grounded_object_detection
        ).parameters
    except (TypeError, ValueError):
        return "box_threshold"
    return "threshold" if "threshold" in params else "box_threshold"


def detect_and_draw(img, text_query, box_threshold=DEFAULT_BOX_THRESHOLD,
                    text_threshold=DEFAULT_TEXT_THRESHOLD):
    """Detect ``text_query`` phrases in ``img``; annotate and crop the hits.

    Args:
        img: PIL RGB image; downscaled to at most 1024 px on its longer side.
        text_query: phrases to detect, e.g. ``"heads. faces."``. The query is
            lowercased and given a trailing '.', as the UI instructions require.
        box_threshold: minimum box score passed to post-processing.
        text_threshold: minimum text-token score passed to post-processing.

    Returns:
        ``(annotated_image, crop_paths)``: a copy of the (resized) image with
        boxes drawn, and temporary JPEG paths the caller must delete.
    """
    text_query = text_query.strip().lower()
    if not text_query.endswith("."):
        text_query += "."

    img = resize_image_max_dimension(img, max_size=1024)

    inputs = processor(images=img, text=text_query, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = model(**inputs)

    results = processor.post_process_grounded_object_detection(
        outputs,
        inputs.input_ids,
        text_threshold=text_threshold,
        target_sizes=[img.size[::-1]],  # PIL gives (w, h); processor wants (h, w)
        **{_box_threshold_kwarg(): box_threshold},
    )[0]

    boxes = results["boxes"].cpu().numpy()
    # Newer transformers versions expose string labels as "text_labels".
    labels = results.get("text_labels", results.get("labels", []))
    scores = results["scores"].cpu().tolist()

    img_out = draw_boxes(img.copy(), boxes=boxes, labels=labels, scores=scores)
    crop_paths = save_cropped_images(img, boxes=boxes, labels=labels, scores=scores)
    return img_out, crop_paths


# ===== TEMPLATES =====
# NOTE(review): as seen here these templates contain no HTML markup (no form,
# input or img elements), so the upload form and images cannot render -
# confirm against the deployed version.

INDEX_TEMPLATE = ''' Student Finder - Protetto

🎓 Student Finder

Carica una foto di classe e trova gli studenti

🔓 Logout
Testo in lowercase, ogni concetto termina con '.' (es. 'heads. faces.')
{% if result_image %}

Risultati:

Risultato {% if crops %}

Ritagli individuati ({{ crops|length }}):

{% endif %}
{% endif %}
'''

RESULTS_TEMPLATE = ''' Risultati - Student Finder

🎓 Risultati Student Finder

🔓 Logout
← Nuova Analisi

Immagine con bounding box:

Risultato {% if crops %}

Ritagli individuati ({{ crops|length }}):

{% else %}

Nessun ritaglio individuato.

{% endif %}
'''

LOGIN_TEMPLATE = ''' Login - Student Finder

🔒 Student Finder - Accesso Protetto

Inserisci la password per accedere

{% if error %}
{{ error }}
{% endif %}
'''


# ===== FLASK ROUTES =====

def _jpeg_base64(image):
    """Encode a PIL image as a base64 JPEG string for inline display."""
    buffer = io.BytesIO()
    image.save(buffer, format='JPEG')
    return base64.b64encode(buffer.getvalue()).decode()


@app.route('/')
@require_auth
def index():
    """Render the upload page with the default thresholds."""
    return render_template_string(
        INDEX_TEMPLATE,
        box_threshold=DEFAULT_BOX_THRESHOLD,
        text_threshold=DEFAULT_TEXT_THRESHOLD,
    )


@app.route('/detect', methods=['POST'])
@require_auth
def detect():
    """Run detection on the uploaded image and render the results page.

    Redirects to the index when no file was uploaded. Returns 400 for
    non-numeric thresholds and 500 when processing fails.
    """
    image_file = request.files.get('image')
    if image_file is None or image_file.filename == '':
        return redirect(url_for('index'))

    # An empty query field falls back to the same default as a missing one.
    text_query = request.form.get('text_query', '').strip() or 'heads.'
    try:
        box_threshold = float(request.form.get('box_threshold', DEFAULT_BOX_THRESHOLD))
        text_threshold = float(request.form.get('text_threshold', DEFAULT_TEXT_THRESHOLD))
    except ValueError as e:
        return f"Errore durante l'elaborazione: {str(e)}", 400

    crop_paths = []
    try:
        image = Image.open(image_file.stream).convert('RGB')
        result_image, crop_paths = detect_and_draw(
            image, text_query, box_threshold, text_threshold
        )

        result_b64 = _jpeg_base64(result_image)
        crops_b64 = []
        for crop_path in crop_paths:
            with open(crop_path, 'rb') as f:
                crops_b64.append(base64.b64encode(f.read()).decode())

        return render_template_string(
            RESULTS_TEMPLATE, result_image=result_b64, crops=crops_b64
        )
    except Exception as e:
        app.logger.exception("Detection request failed")
        return f"Errore durante l'elaborazione: {str(e)}", 500
    finally:
        # Always remove the temporary crop files, even if encoding failed.
        for crop_path in crop_paths:
            with contextlib.suppress(OSError):
                os.unlink(crop_path)


@app.route('/login', methods=['GET', 'POST'])
def login():
    """Show the password form; start a 24h session on a correct password."""
    if is_authenticated():
        return redirect(url_for('index'))

    error = None
    if request.method == 'POST':
        supplied = request.form.get('password', '')
        # Constant-time comparison; compare bytes so non-ASCII input can't
        # make compare_digest raise TypeError.
        if secrets.compare_digest(supplied.encode('utf-8'),
                                  SECRET_PASSWORD.encode('utf-8')):
            session.permanent = True
            session['authenticated'] = True
            return redirect(url_for('index'))
        error = "❌ Password errata. Riprova."

    return render_template_string(LOGIN_TEMPLATE, error=error)


@app.route('/logout')
def logout():
    """Clear the session and return to the login page."""
    session.clear()
    return redirect(url_for('login'))


if __name__ == '__main__':
    port = int(os.environ.get('PORT', 7860))
    app.run(host='0.0.0.0', port=port, debug=False)