"""Flask API that classifies an uploaded image, or the first page of a PDF.

Classification uses the Hugging Face image-classification model
``AsmaaElnagger/Diabetic_RetinoPathy_detection`` through a ``transformers``
pipeline.

POST the file as the multipart field ``file`` to ``/classify``. On success
the JSON response is ``{"prediction": <label>, "confidence": <percent>}``;
on failure it is ``{"error": <message>}`` with a non-2xx status code.
"""

import io
import logging
import os
import tempfile

import fitz  # PyMuPDF
from flask import Flask, jsonify, request
from PIL import Image
from transformers import pipeline
from werkzeug.utils import secure_filename

logger = logging.getLogger(__name__)

app = Flask(__name__)

# Load the model once at import time so every request reuses it.
model_name = "AsmaaElnagger/Diabetic_RetinoPathy_detection"
classifier = pipeline("image-classification", model=model_name)

# File extensions the API accepts (compared lower-cased).
IMAGE_EXTENSIONS = frozenset({".jpg", ".jpeg", ".png", ".gif"})
PDF_EXTENSION = ".pdf"
SUPPORTED_EXTENSIONS = IMAGE_EXTENSIONS | {PDF_EXTENSION}


def pdf_to_images_pymupdf(pdf_data, max_pages=None):
    """Render the pages of a PDF to JPEG-encoded bytes.

    Args:
        pdf_data: Raw bytes of a PDF document.
        max_pages: If given, render at most this many pages (from the
            first one). ``None`` renders every page.

    Returns:
        A list with one ``bytes`` object per rendered page (empty for a
        PDF with no pages), or ``None`` if the PDF could not be opened or
        rendered.
    """
    try:
        # The context manager closes the document even if rendering fails.
        with fitz.open(stream=pdf_data, filetype="pdf") as pdf_document:
            page_count = pdf_document.page_count
            if max_pages is not None:
                page_count = min(page_count, max_pages)
            return [
                pdf_document.load_page(page_num).get_pixmap().tobytes("jpeg")
                for page_num in range(page_count)
            ]
    except Exception:
        # Best-effort conversion: report failure to the caller as None.
        logger.exception("Error converting PDF")
        return None


def _classify_image(image):
    """Run the classifier on an RGB PIL image and format its top prediction."""
    result = classifier(image)[0]  # Get the top prediction
    return {
        "prediction": result["label"],
        "confidence": result["score"] * 100,  # score in [0, 1] -> percent
    }


def _load_rgb(source):
    """Open an image from a path or file object and return an RGB copy.

    The file handle is closed before returning.
    """
    with Image.open(source) as img:
        return img.convert("RGB")


def classify_file(file_path):
    """Classify an image file, or the first page of a PDF file.

    The file type is chosen by its (case-insensitive) extension.

    Args:
        file_path: Path to the file on disk.

    Returns:
        ``{"prediction": str, "confidence": float}`` on success, where
        ``confidence`` is a percentage. Otherwise ``{"error": str}``; no
        exception is raised to the caller.
    """
    try:
        file_ext = os.path.splitext(file_path)[-1].lower()
        if file_ext in IMAGE_EXTENSIONS:
            return _classify_image(_load_rgb(file_path))
        if file_ext == PDF_EXTENSION:
            with open(file_path, "rb") as f:
                pdf_data = f.read()
            # Only the first page is classified, so only render that one.
            images = pdf_to_images_pymupdf(pdf_data, max_pages=1)
            if not images:
                return {"error": "PDF conversion failed."}
            return _classify_image(_load_rgb(io.BytesIO(images[0])))
        return {"error": "Unsupported file type."}
    except Exception as e:
        logger.exception("Classification failed for %s", file_path)
        return {"error": f"An error occurred: {e}"}


@app.route("/classify", methods=["POST"])
def classify():
    """Classify the uploaded multipart ``file`` and return a JSON result.

    Status codes:
        200 on success.
        400 if no file was sent.
        415 for an unsupported extension.
        422 if the file could not be processed or classified.
    """
    if "file" not in request.files:
        return jsonify({"error": "No file part"}), 400
    file = request.files["file"]
    if not file.filename:
        return jsonify({"error": "No file selected"}), 400

    # Take the extension from the client's name directly. secure_filename()
    # can drop it (e.g. all-non-ASCII stems) or return "" entirely.
    file_ext = os.path.splitext(file.filename)[-1].lower()
    if file_ext not in SUPPORTED_EXTENSIONS:
        return jsonify({"error": "Unsupported file type."}), 415

    # mkstemp gives a unique, exclusively-created path, so concurrent
    # uploads with the same name cannot collide or follow a planted symlink.
    fd, filepath = tempfile.mkstemp(suffix=file_ext)
    try:
        with os.fdopen(fd, "wb") as tmp:
            file.save(tmp)
        result = classify_file(filepath)
    finally:
        os.remove(filepath)  # always clean up, even if saving fails

    status = 422 if "error" in result else 200
    return jsonify(result), status


if __name__ == "__main__":
    # Werkzeug development server; use a WSGI server in production.
    app.run(host="0.0.0.0", port=5000)