Spaces:
Paused
Paused
| from flask import Flask, render_template, request, redirect, url_for | |
| from transformers import DetrImageProcessor, DetrForObjectDetection | |
| from PIL import Image, ImageDraw | |
| import torch | |
| import os | |
| import uuid | |
| import logging | |
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = Flask(__name__)

# Uploaded and annotated images are stored inside the app's static folder so
# that url_for('static', filename='uploads/...') can serve them. Anchoring the
# path to app.static_folder (instead of a working-directory-relative
# 'static/uploads') keeps writes and serving in sync even when the server is
# launched from a different directory.
UPLOAD_FOLDER = os.path.join(app.static_folder, 'uploads')
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
# Pretrained checkpoint shared by the image processor and the detector, so the
# two are always loaded from the same weights/config. Both are created once at
# import time and reused by every request.
_DETR_CHECKPOINT = "facebook/detr-resnet-50"

logger.info("Loading DETR model and processor...")
processor = DetrImageProcessor.from_pretrained(_DETR_CHECKPOINT)
model = DetrForObjectDetection.from_pretrained(_DETR_CHECKPOINT)
logger.info("Model and processor loaded successfully.")
# Registered explicitly: without a route decorator Flask never exposes this
# view and every request would 404.
@app.route('/')
def index():
    """Render the landing page template ``index.html``."""
    return render_template('index.html')
# Client-supplied extensions we are willing to keep on disk. Uploads land in the
# publicly served static folder, so anything else (e.g. ".html", ".svg") is
# refused up front instead of being stored and served back to browsers.
_ALLOWED_EXTENSIONS = frozenset({'.png', '.jpg', '.jpeg', '.bmp', '.gif', '.webp'})
# Images are shrunk (never enlarged) to fit inside this (width, height) box.
_MAX_IMAGE_SIZE = (800, 600)
# Minimum DETR confidence for a detection to be drawn.
_SCORE_THRESHOLD = 0.9


def _draw_detections(image):
    """Run DETR on *image* and draw every detection above the threshold onto it.

    The image is modified in place: a red rectangle plus a ``label: score``
    caption per detection.

    Args:
        image: An RGB ``PIL.Image.Image``.
    """
    inputs = processor(images=image, return_tensors="pt")
    # Inference only: disabling autograd avoids recording a graph (and keeping
    # activations) for a backward pass that never happens.
    with torch.no_grad():
        outputs = model(**inputs)
    # PIL reports size as (width, height); reversed here to (height, width) for target_sizes.
    target_sizes = torch.tensor([image.size[::-1]])
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=_SCORE_THRESHOLD
    )[0]
    draw = ImageDraw.Draw(image)
    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        box = [round(coord, 2) for coord in box.tolist()]
        label_str = model.config.id2label[label.item()]
        draw.rectangle(box, outline="red", width=3)
        draw.text((box[0], box[1]), f"{label_str}: {score.item():.2f}", fill="red")


def _discard(*paths):
    """Best-effort removal of files left behind by a failed request."""
    for path in paths:
        try:
            os.remove(path)
        except FileNotFoundError:
            pass
        except OSError:
            logger.warning("Could not remove %s", path, exc_info=True)


# NOTE(review): '/upload' must match the upload form's action in index.html
# (or the form should post to url_for('upload_file')) — confirm.
@app.route('/upload', methods=['POST'])
def upload_file():
    """Accept an uploaded image, annotate DETR detections, and show the result.

    Reads the multipart field ``file``. The upload is stored in
    ``app.config['UPLOAD_FOLDER']`` under a random UUID name (keeping its
    lower-cased extension), and an annotated copy named ``output_<name>`` is
    written next to it.

    Returns:
        A redirect to ``request.url`` when no file was provided; ``index.html``
        with an ``error`` message for unsupported or unprocessable files; or
        ``results.html`` with the URLs of the original and annotated images.
    """
    if 'file' not in request.files:
        logger.warning("No file part in request.")
        return redirect(request.url)
    file = request.files['file']
    # `not` covers both an empty and a missing (None) filename.
    if not file.filename:
        logger.warning("No file selected.")
        return redirect(request.url)

    extension = os.path.splitext(file.filename)[1].lower()
    if extension not in _ALLOWED_EXTENSIONS:
        logger.warning("Rejected upload with extension %r.", extension)
        return render_template(
            'index.html',
            error="Unsupported file type. Please upload a PNG, JPEG, BMP, GIF or WebP image.",
        )

    filename = f"{uuid.uuid4()}{extension}"
    filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
    output_filename = f"output_{filename}"
    output_filepath = os.path.join(app.config['UPLOAD_FOLDER'], output_filename)
    try:
        file.save(filepath)
        logger.info("File saved: %s", filename)
        # Decode inside `with` so the file handle is closed right away rather
        # than whenever the lazily-opened image object is garbage-collected.
        with Image.open(filepath) as uploaded:
            image = uploaded.convert("RGB")
        # Downscale for performance while preserving aspect ratio; a fixed
        # resize((800, 600)) stretched non-4:3 images (and the drawn boxes).
        image.thumbnail(_MAX_IMAGE_SIZE)
        _draw_detections(image)
        image.save(output_filepath)
        logger.info("Processed image saved: %s", output_filename)
    except Exception as e:
        # Request boundary: keep the traceback in the log, remove partial
        # files so failed uploads don't accumulate in the static folder, and
        # report the failure on the upload page.
        logger.exception("Error processing file: %s", e)
        _discard(filepath, output_filepath)
        return render_template('index.html', error=f"Error processing file: {str(e)}")

    return render_template('results.html',
                           original_image=url_for('static', filename=f'uploads/{filename}'),
                           processed_image=url_for('static', filename=f'uploads/{output_filename}'))
# Script entry point: start Flask's built-in server on port 7860, bound to
# every network interface (0.0.0.0) rather than only localhost.
if __name__ == '__main__':
    app.run(port=7860, host='0.0.0.0')