File size: 3,823 Bytes
c715811
 
 
 
 
a707ddc
c715811
 
f30e238
c715811
 
 
f30e238
c715811
 
 
 
 
 
 
 
f30e238
c715811
 
 
 
f30e238
c715811
f30e238
 
 
a707ddc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c715811
a707ddc
c715811
 
 
 
 
 
 
f30e238
c715811
a707ddc
f30e238
a707ddc
f30e238
 
 
 
 
 
 
 
 
 
 
 
 
 
c715811
 
f30e238
c715811
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from transformers import pipeline
from PIL import Image
from bs4 import BeautifulSoup
import base64
import io
import requests

app = FastAPI(title="STOA Chest X-Ray API")

# --- CORS ---
# Any origin, method and header may call the API from a browser.
# NOTE(review): the CORS spec does not let browsers accept a wildcard origin on
# credentialed requests; check how the middleware resolves "*" together with
# allow_credentials=True, and list explicit origins if cookies/auth headers are
# ever relied upon.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)

# --- MODEL LOADING ---
# The pipeline is built once, at import time; every request handler shares `pipe`.
_MODEL_NAME = "dima806/chest_xray_pneumonia_detection"
print("Booting Pulmonology Agent. Loading ViT model into memory...")
pipe = pipeline("image-classification", model=_MODEL_NAME)
print("Agent Ready!")

# --- REQUEST SCHEMA ---
class PredictRequest(BaseModel):
    """Request body for ``POST /predict``.

    Supply one of the two fields. If both are set, ``/predict`` uses
    ``image_url`` and ignores ``image``; if neither is set it answers 400.
    """

    # Base64-encoded image bytes. A data-URL header such as
    # "data:image/png;base64," is tolerated: /predict drops the text before
    # the first comma.
    image: str | None = None
    # Link to either a raw image or an HTML page exposing an og:image meta tag.
    image_url: str | None = None

# --- SMART FETCHER HELPER ---
def get_image_from_any_url(url: str):
    """Download *url* and return the image it refers to as an RGB PIL image.

    Two kinds of link are accepted:

    * a direct image link (``Content-Type: image/*``): the body is decoded;
    * an HTML page (``Content-Type: text/html``): the page's
      ``<meta property="og:image">`` tag is read, and the image it points at
      is downloaded and decoded. Relative and protocol-relative tag values are
      resolved against the page's final (post-redirect) URL.

    Args:
        url: Address supplied by the API caller.

    Returns:
        The decoded image, converted to RGB mode.

    Raises:
        Exception: if a download does not answer HTTP 200, the page has no
            ``og:image`` tag, or the content type is neither image nor HTML.
        Network errors raised by ``requests`` and decoding errors raised by
        PIL propagate unchanged.
    """
    # NOTE(review): *url* comes straight from the request body, so this is a
    # server-side request forgery surface (internal hosts, metadata endpoints)
    # and downloads are unbounded in size. Consider an allow-list / private-IP
    # block and a maximum download size.
    from urllib.parse import urljoin

    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
        "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,*/*;q=0.8",
        "Referer": "https://google.com"
    }

    # 1. Fetch whatever is at the URL
    response = requests.get(url, headers=headers, timeout=10)
    if response.status_code != 200:
        raise Exception(f"Site blocked us (Error {response.status_code})")

    content_type = response.headers.get('Content-Type', '').lower()

    # 2. Already an image: decode it directly.
    if content_type.startswith('image/'):
        return Image.open(io.BytesIO(response.content)).convert("RGB")

    # 3. Anything other than a webpage is unsupported.
    if not content_type.startswith('text/html'):
        raise Exception(f"Unsupported link type: {content_type}")

    # 4. Webpage: follow its Open Graph image.
    print("Webpage detected! Scraping for the main image...")
    soup = BeautifulSoup(response.text, 'html.parser')
    og_image = soup.find('meta', property='og:image')
    if not (og_image and og_image.get('content')):
        raise Exception("Could not find a main image on this webpage.")

    # og:image is often relative ("/img/x.png") or protocol-relative
    # ("//cdn.example.com/x.png"); requests needs an absolute URL.
    actual_image_url = urljoin(response.url, og_image['content'].strip())
    print(f"Found hidden image at: {actual_image_url}")

    img_response = requests.get(actual_image_url, headers=headers, timeout=10)
    # Without this check an error page (404 HTML, etc.) would reach PIL and
    # surface as an opaque "cannot identify image file" error.
    if img_response.status_code != 200:
        raise Exception(f"Site blocked us (Error {img_response.status_code})")
    return Image.open(io.BytesIO(img_response.content)).convert("RGB")

# --- ENDPOINTS ---
@app.get("/health")
def health_check():
    """Liveness probe: always answers with the constant payload ``{"status": "ok"}``."""
    return dict(status="ok")

@app.post("/predict")
def predict(req: PredictRequest):
    """Classify a chest X-ray and return the model's label scores.

    The image comes from ``req.image_url`` when it is set (fetched with
    ``get_image_from_any_url``); otherwise from ``req.image``, a base64 string
    optionally prefixed by a data-URL header such as ``data:image/png;base64,``.

    Returns:
        ``{"prediction": <top label>, "confidence": <top score>,
        "scores": {<label>: <score>, ...}}`` with every score rounded to
        4 decimal places.

    Raises:
        HTTPException: status 400 when neither field is provided, or when
            fetching, decoding or classifying the image fails.
    """
    # Validate before the try block: raising this inside it would let the
    # generic handler below re-wrap it, producing the mangled detail
    # "Failed to process X-Ray: 400: Must provide ...".
    if not req.image_url and not req.image:
        raise HTTPException(status_code=400, detail="Must provide 'image' (base64) or 'image_url'.")

    try:
        # 1. URL input takes precedence over base64 input.
        if req.image_url:
            img = get_image_from_any_url(req.image_url)
        else:
            # 2. Base64 input; strip an optional "data:...;base64," header.
            b64_data = req.image
            if "," in b64_data:
                b64_data = b64_data.split(",")[1]
            image_bytes = base64.b64decode(b64_data)
            img = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        # 3. Run the classifier: a list of {"label", "score"} dicts.
        results = pipe(img)

        # 4. Format to exact Task 24 specifications
        top_pred = max(results, key=lambda x: x['score'])
        scores_dict = {res['label']: round(res['score'], 4) for res in results}

        return {
            "prediction": top_pred['label'],
            "confidence": round(top_pred['score'], 4),
            "scores": scores_dict
        }

    except Exception as e:
        # Chain the cause so server-side tracebacks keep the original error.
        raise HTTPException(status_code=400, detail=f"Failed to process X-Ray: {str(e)}") from e