Spaces:
Runtime error
Runtime error
import os
import random
import sys
from datetime import datetime
from pathlib import Path

import mysql.connector
import paho.mqtt.client as mqtt
import yolov5
# MySQL configuration (TiDB Serverless).
# Secrets are read from the environment (e.g. the Space's repository secrets)
# rather than committed to source. The password that used to be hard-coded
# here must be treated as leaked and rotated.
db_config = {
    'host': os.environ.get('DB_HOST', 'gateway01.ap-southeast-1.prod.aws.tidbcloud.com'),
    'user': os.environ.get('DB_USER', '5ztcqT1EBcgYB5u.root'),
    'password': os.environ.get('DB_PASSWORD', ''),
    'database': os.environ.get('DB_NAME', 'satwa'),
    'port': int(os.environ.get('DB_PORT', '4000')),
    'ssl_ca': '/etc/ssl/certs/ca-certificates.crt',  # Works on Spaces
    'ssl_verify_cert': True,
    'ssl_verify_identity': True,
}
if not db_config['password']:
    print("Failed to connect to MySQL: DB_PASSWORD environment variable is not set.")
    sys.exit(1)

# Connect once at start-up. Failure is fatal, so exit with a non-zero status
# that the host process can detect (the interactive exit() helper exits 0).
try:
    db = mysql.connector.connect(**db_config)
    cursor = db.cursor()
    print("Connected to MySQL database successfully.")
except mysql.connector.Error as err:
    print(f"Failed to connect to MySQL: {err}")
    sys.exit(1)
# Ensure the results table exists. It holds one row per MQTT message:
# when it arrived, the raw integer payload and the predicted class label.
create_table_query = """
CREATE TABLE IF NOT EXISTS garbage_classification (
    id INT AUTO_INCREMENT PRIMARY KEY,
    timestamp DATETIME,
    payload_value INT,
    predicted_class VARCHAR(255)
)
"""
try:
    cursor.execute(create_table_query)
    db.commit()
    print("Table checked/created.")
except mysql.connector.Error as err:
    print(f"Error creating table: {err}")
    # Non-zero status so the failure is visible to whatever runs the script.
    sys.exit(1)
# Inference is pinned to the CPU: the Spaces free tier offers no GPU.
device = 'cpu'
print(f"Using device: {device}")

# Fetch the pretrained garbage-detection weights and place them on the device.
model = yolov5.load('keremberke/yolov5m-garbage')
model.to(device)

# Inference-time settings, applied to the model in this order.
# Comments follow yolov5's documented meaning of these AutoShape attributes.
inference_settings = {
    'conf': 0.25,         # NMS confidence threshold
    'iou': 0.45,          # NMS IoU threshold
    'agnostic': False,    # class-agnostic NMS disabled
    'multi_label': False, # at most one label per box
    'max_det': 1000,      # maximum detections per image
}
for setting_name, setting_value in inference_settings.items():
    setattr(model, setting_name, setting_value)

# Class-id -> label lookup used when reporting predictions.
class_names = model.names
# Dataset root on the Space's storage. Its subdirectories, sorted, are the
# targets that get_folder_index() indexes into.
base_dir = "/data/garbage_classification"
base_path = Path(base_dir)
# Path.is_dir() is already False for a missing path, so one check suffices.
if not base_path.is_dir():
    print(f"Error: Directory '{base_dir}' not found.")
    sys.exit(1)
# NOTE(review): rglob('*') is recursive, so nested directories are counted
# too and shift the index -> folder mapping. Confirm the dataset is exactly
# one level deep (or switch to iterdir()).
subfolders = sorted(p for p in base_path.rglob('*') if p.is_dir())
if not subfolders:
    print(f"Error: No subfolders found in '{base_dir}'.")
    sys.exit(1)
print(f"Found {len(subfolders)} subfolders.")
def get_folder_index(value):
    """Map a payload value to one of nine folder buckets.

    Bucket ``b`` (0-8) covers the inclusive range ``11*b + 1`` .. ``11*b + 11``:
    1-11 -> 0, 12-22 -> 1, ..., 89-99 -> 8. Any value that falls in no
    bucket yields -1.
    """
    for bucket in range(9):
        low = 11 * bucket + 1
        if low <= value <= low + 10:
            return bucket
    return -1
def process_random_image(subfolder, payload_value):
    """Classify one randomly chosen image from *subfolder* and record the result.

    Args:
        subfolder: ``pathlib.Path`` directory searched non-recursively for
            ``.jpg``/``.jpeg``/``.png``/``.bmp`` files (case-insensitive).
        payload_value: The MQTT integer that selected this folder. It is
            stored alongside the prediction.

    Returns:
        One of:
        - the label of the highest-confidence detection;
        - ``"None"`` when nothing was detected;
        - ``"Unknown"`` when the class id is outside ``class_names``;
        - ``None`` when the folder contains no images (nothing is written
          to the database in that case).

    Database errors are printed, not raised.
    """
    image_extensions = ('.jpg', '.jpeg', '.png', '.bmp')
    images = [f for f in subfolder.iterdir() if f.is_file() and f.suffix.lower() in image_extensions]
    if not images:
        print(f"No images in {subfolder}")
        return None
    random_image = random.choice(images)
    print(f"Selected image: {random_image}")
    results = model(str(random_image), size=640)
    # NOTE(review): rows are indexed as [..., 4] = confidence and
    # [..., 5] = class id, i.e. the yolov5 (x1, y1, x2, y2, conf, cls) layout.
    # Confirm against the installed yolov5 version.
    predictions = results.pred[0]
    predicted_class = "None"
    if len(predictions) > 0:
        scores = predictions[:, 4]
        top_idx = scores.argmax()
        top_score = scores[top_idx]
        top_cat = int(predictions[top_idx, 5])
        predicted_class = class_names[top_cat] if top_cat < len(class_names) else "Unknown"
        print(f"Class: {predicted_class}, Confidence: {top_score:.2f}")
    else:
        print("No objects detected.")
    timestamp = datetime.now()
    insert_query = "INSERT INTO garbage_classification (timestamp, payload_value, predicted_class) VALUES (%s, %s, %s)"
    try:
        # This listener runs indefinitely on one module-level connection, and
        # the server may close it while idle. ping(reconnect=True)
        # re-establishes it before the insert instead of failing every
        # insert from then on.
        db.ping(reconnect=True, attempts=3, delay=2)
        cursor.execute(insert_query, (timestamp, payload_value, predicted_class))
        db.commit()
        print("Data saved to database.")
    except mysql.connector.Error as err:
        print(f"Error saving to database: {err}")
    return predicted_class
def on_message(client, userdata, msg):
    """Handle one MQTT message whose payload should be an integer.

    If the value maps to an existing subfolder via get_folder_index(), a
    random image from that folder is classified and logged. Any other
    integer is logged to the database as "Out of Range". Non-integer
    payloads are reported and ignored.

    All errors are handled here and never propagated. paho-mqtt re-raises
    exceptions from on_message callbacks (unless ``suppress_exceptions`` is
    set), which would stop ``loop_forever()`` and kill the listener.
    """
    try:
        value = int(msg.payload.decode())
    except ValueError:  # also covers UnicodeDecodeError, a ValueError subclass
        print("Invalid payload - must be an integer")
        return
    print(f"Received MQTT value: {value}")
    folder_idx = get_folder_index(value)
    if folder_idx != -1 and folder_idx < len(subfolders):
        selected_folder = subfolders[folder_idx]
        print(f"Mapped to folder: {selected_folder}")
        try:
            process_random_image(selected_folder, value)
        except Exception as err:  # callback boundary: keep the listener alive
            print(f"Error processing image from {selected_folder}: {err}")
        return
    timestamp = datetime.now()
    insert_query = "INSERT INTO garbage_classification (timestamp, payload_value, predicted_class) VALUES (%s, %s, %s)"
    try:
        cursor.execute(insert_query, (timestamp, value, "Out of Range"))
        db.commit()
        print("Value out of range. Saved to database.")
    except mysql.connector.Error as err:
        print(f"Error saving to database: {err}")
# MQTT setup: listen on a public test broker for integer payloads.
broker = "test.mosquitto.org"
topic = "garbage/index"


def _on_connect(client, userdata, flags, *args):
    """Subscribe on every successful (re)connection.

    Subscribing here instead of once after connect() means the subscription
    survives the automatic reconnects performed by loop_forever(). The
    default clean session discards subscriptions on disconnect. ``*args``
    absorbs the tail of the callback signature, which differs between
    paho-mqtt callback API v1 (``rc``) and v2 (``reason_code, properties``).
    """
    client.subscribe(topic)
    print(f"Subscribed to {topic}, waiting for messages...")


# paho-mqtt >= 2.0 requires an explicit callback API version: a bare
# mqtt.Client() raises ValueError there. Older releases have no
# CallbackAPIVersion and take no such argument. on_message has the same
# signature under both APIs.
if hasattr(mqtt, "CallbackAPIVersion"):
    client = mqtt.Client(mqtt.CallbackAPIVersion.VERSION2)
else:
    client = mqtt.Client()
client.on_connect = _on_connect
client.on_message = on_message
client.connect(broker, 1883, 60)
client.loop_forever()  # Runs indefinitely, reconnecting on connection loss.