from fastapi import FastAPI, File, UploadFile, HTTPException from fastapi.responses import JSONResponse from PIL import Image import numpy as np import tensorflow as tf import requests import io app = FastAPI(title="رتینوپاتی دیابتی API", description="تشخیص ۵ مرحله DR با مدل CNN") # --- دانلود مدل از Space قدیمی --- MODEL_URL = "https://huggingface.co/megsciip/eye1/resolve/main/eye_modelv2.tflite" MODEL_PATH = "eye_modelv2.tflite" def download_model(): if not tf.io.gfile.exists(MODEL_PATH): print("در حال دانلود مدل...") response = requests.get(MODEL_URL) response.raise_for_status() with open(MODEL_PATH, "wb") as f: f.write(response.content) print("مدل دانلود شد!") download_model() # --- لود مدل --- interpreter = tf.lite.Interpreter(model_path=MODEL_PATH) interpreter.allocate_tensors() input_details = interpreter.get_input_details() output_details = interpreter.get_output_details() IMG_SIZE = (224, 224) CLASS_NAMES = [ "۰ - بدون رتینوپاتی (چشم سالم)", "۱ - رتینوپاتی خفیف (Mild DR)", "۲ - رتینوپاتی متوسط (Moderate DR)", "۳ - رتینوپاتی شدید (Severe DR)", "۴ - رتینوپاتی تکثیری (Proliferative DR)" ] def preprocess_image(img: Image.Image): img = img.resize(IMG_SIZE) if img.mode != 'RGB': img = img.convert('RGB') img_array = np.array(img, dtype=np.float32) img_array = np.expand_dims(img_array, axis=0) img_array = tf.keras.applications.resnet50.preprocess_input(img_array) return img_array @app.post("/predict") async def predict(file: UploadFile = File(...)): if not file.content_type.startswith("image/"): raise HTTPException(400, detail="فقط فایل تصویری مجاز است!") try: contents = await file.read() img = Image.open(io.BytesIO(contents)) input_data = preprocess_image(img) interpreter.set_tensor(input_details[0]['index'], input_data) interpreter.invoke() predictions = interpreter.get_tensor(output_details[0]['index'])[0] predicted_idx = int(np.argmax(predictions)) confidence = float(predictions[predicted_idx]) result = { "diagnosis": CLASS_NAMES[predicted_idx], "confidence": f"{confidence:.1%}", "probabilities": { CLASS_NAMES[i].split(" - ")[1]: f"{float(predictions[i]):.1%}" for i in range(len(predictions)) }, "recommended_action": "مشاوره با چشم‌پزشک ضروری است." if predicted_idx > 0 else "چشم سالم به نظر می‌رسد." } return JSONResponse(result) except Exception as e: raise HTTPException(500, detail=f"خطا در پردازش: {str(e)}") @app.get("/") def home(): return {"message": "API تشخیص رتینوپاتی دیابتی فعال است! POST به /predict"} # --- راه‌اندازی سرور (ضروری برای Hugging Face) --- if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=7860)