Spaces:
Sleeping
Sleeping
| import os | |
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from fastapi.responses import JSONResponse, HTMLResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.openapi.docs import get_swagger_ui_html | |
| from typing import List, Optional | |
| import io | |
| from PIL import Image | |
| import numpy as np | |
| from model import SkinClassifier | |
| # Инициализация FastAPI | |
| app = FastAPI( | |
| title="Skin Classification API", | |
| description="API для классификации кожных заболеваний (clear, acne, ros, black)", | |
| version="1.0.0", | |
| docs_url="/docs", | |
| redoc_url="/redoc" | |
| ) | |
| # Инициализация модели | |
| try: | |
| # Путь к модели в Hugging Face Spaces | |
| MODEL_PATH = "model/stage1_skin_classifier.pth" | |
| classifier = SkinClassifier(MODEL_PATH) | |
| print("✅ Model loaded successfully!") | |
| except Exception as e: | |
| print(f"❌ Error loading model: {e}") | |
| classifier = None | |
| async def root(): | |
| """Главная страница с информацией об API""" | |
| return "go to /docs" | |
| async def predict( | |
| file: UploadFile = File(..., description="Изображение для классификации (JPG, PNG)"), | |
| return_image: Optional[bool] = False | |
| ): | |
| """ | |
| Классификация одного изображения | |
| - **file**: Изображение в формате JPG или PNG | |
| - **return_image**: Возвращать ли base64 изображение (по умолчанию False) | |
| """ | |
| if classifier is None: | |
| raise HTTPException(status_code=503, detail="Model not loaded") | |
| # Проверка формата файла | |
| allowed_extensions = {".jpg", ".jpeg", ".png", ".bmp"} | |
| file_ext = os.path.splitext(file.filename)[1].lower() | |
| if file_ext not in allowed_extensions: | |
| raise HTTPException( | |
| status_code=400, | |
| detail=f"Unsupported file format. Allowed: {allowed_extensions}" | |
| ) | |
| try: | |
| # Чтение файла | |
| contents = await file.read() | |
| # Предсказание | |
| result = classifier.predict(contents) | |
| response_data = { | |
| "filename": file.filename, | |
| "prediction": result["predicted_class"], | |
| "confidence": result["confidence"], | |
| "probabilities": result["all_probabilities"] | |
| } | |
| # Добавляем base64 изображение если нужно | |
| if return_image: | |
| from io import BytesIO | |
| import base64 | |
| img = Image.open(BytesIO(contents)) | |
| buffered = BytesIO() | |
| img.save(buffered, format="JPEG" if file_ext in {".jpg", ".jpeg"} else "PNG") | |
| img_str = base64.b64encode(buffered.getvalue()).decode() | |
| response_data["image_base64"] = f"data:image/jpeg;base64,{img_str}" | |
| return response_data | |
| except Exception as e: | |
| raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}") | |
| # Middleware для обработки CORS | |
| from fastapi.middleware.cors import CORSMiddleware | |
| app.add_middleware( | |
| CORSMiddleware, | |
| allow_origins=["*"], | |
| allow_credentials=True, | |
| allow_methods=["*"], | |
| allow_headers=["*"], | |
| ) |