"""HTTP API that classifies chest X-ray images with a Hugging Face pipeline.

Exposes ``GET /health`` and ``POST /predict``. ``/predict`` accepts either a
base64-encoded image or a URL; a URL may point at an image directly or at an
HTML page whose ``og:image`` meta tag names the image.
"""

import base64
import io
from urllib.parse import urljoin

import requests
from bs4 import BeautifulSoup
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from pydantic import BaseModel
from transformers import pipeline

app = FastAPI(title="STOA Chest X-Ray API")

# --- CORS ---
# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site make credentialed requests. Left unchanged to avoid breaking existing
# clients; consider listing explicit origins or disabling credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# --- MODEL LOADING ---
# The pipeline is built once at import time and shared by every request.
print("Booting Pulmonology Agent. Loading ViT model into memory...")
pipe = pipeline("image-classification", model="dima806/chest_xray_pneumonia_detection")
print("Agent Ready!")


# --- REQUEST SCHEMA ---
class PredictRequest(BaseModel):
    """Body of ``POST /predict``; ``image_url`` wins when both fields are set."""

    # Base64 image data, optionally prefixed with a "data:...;base64," header.
    image: str | None = None
    # URL of an image, or of a webpage exposing an og:image meta tag.
    image_url: str | None = None


# --- SMART FETCHER HELPER ---
# Browser-like headers sent with every outbound fetch.
_FETCH_HEADERS = {
    "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
    "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,*/*;q=0.8",
    "Referer": "https://google.com",
}
_FETCH_TIMEOUT_SECONDS = 10


class ImageFetchError(Exception):
    """Raised when no usable image can be retrieved from a URL."""


def _download(url: str) -> requests.Response:
    """GET *url* with the shared headers/timeout; raise unless the status is 200."""
    response = requests.get(url, headers=_FETCH_HEADERS, timeout=_FETCH_TIMEOUT_SECONDS)
    if response.status_code != 200:
        raise ImageFetchError(f"Site blocked us (Error {response.status_code})")
    return response


def _decode_image(raw: bytes) -> Image.Image:
    """Open *raw* bytes with PIL and convert the result to RGB."""
    return Image.open(io.BytesIO(raw)).convert("RGB")


def get_image_from_any_url(url: str) -> Image.Image:
    """Smart fetcher that handles both raw images and HTML webpages.

    Args:
        url: Address of an image, or of an HTML page with an ``og:image`` tag.

    Returns:
        The fetched image converted to RGB.

    Raises:
        ImageFetchError: If a response status is not 200, the content type is
            neither ``image/*`` nor ``text/html``, or the page has no
            ``og:image`` content.
        requests.RequestException: Propagated from the underlying HTTP call.
    """
    # NOTE(security): both `url` and the og:image target come from untrusted
    # input and are fetched server-side (SSRF risk). Restrict reachable hosts
    # before exposing this service on a network with internal-only endpoints.

    # 1. Fetch whatever is at the URL
    response = _download(url)
    content_type = response.headers.get("Content-Type", "").lower()

    # 2. If it's already an image, just return it
    if content_type.startswith("image/"):
        return _decode_image(response.content)

    if not content_type.startswith("text/html"):
        raise ImageFetchError(f"Unsupported link type: {content_type}")

    # 3. It's a webpage: hunt for the main Open Graph image
    print("Webpage detected! Scraping for the main image...")
    soup = BeautifulSoup(response.text, "html.parser")
    og_image = soup.find("meta", property="og:image")
    og_content = og_image.get("content") if og_image else None
    if not og_content:
        raise ImageFetchError("Could not find a main image on this webpage.")

    # og:image may be relative; resolve it against the final (post-redirect) URL.
    actual_image_url = urljoin(response.url, og_content)
    print(f"Found hidden image at: {actual_image_url}")
    # Status is checked here too, so an error page is never handed to PIL.
    return _decode_image(_download(actual_image_url).content)


# --- ENDPOINTS ---
@app.get("/health")
def health_check():
    """Liveness probe; always returns ``{"status": "ok"}``."""
    return {"status": "ok"}


@app.post("/predict")
def predict(req: PredictRequest):
    """Classify the supplied X-ray and return the top label with all scores.

    Args:
        req: Request carrying ``image_url`` and/or base64 ``image``.

    Returns:
        A dict with ``prediction`` (top label), ``confidence`` (its score,
        rounded to 4 places) and ``scores`` (label -> rounded score).

    Raises:
        HTTPException: 400 when neither input is provided, or when fetching,
            decoding or classification fails.
    """
    # Validate outside the try block so this 400 reaches the client verbatim
    # instead of being re-wrapped by the generic handler below.
    if not req.image_url and not req.image:
        raise HTTPException(status_code=400, detail="Must provide 'image' (base64) or 'image_url'.")

    try:
        if req.image_url:
            # 1. Handle URL input
            img = get_image_from_any_url(req.image_url)
        else:
            # 2. Handle base64 input, dropping a "data:...;base64," prefix if present
            b64_data = req.image
            if "," in b64_data:
                b64_data = b64_data.split(",", 1)[1]
            img = _decode_image(base64.b64decode(b64_data))

        # 3. Run the classifier
        results = pipe(img)

        # 4. Format to exact Task 24 specifications
        top_pred = max(results, key=lambda x: x["score"])
        scores_dict = {res["label"]: round(res["score"], 4) for res in results}
        return {
            "prediction": top_pred["label"],
            "confidence": round(top_pred["score"], 4),
            "scores": scores_dict,
        }
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to process X-Ray: {str(e)}") from e