tampee commited on
Commit
cfd4ead
·
1 Parent(s): 3fc3b69

fix: catch all inference exceptions and OSError in SSRF check to surface real 500 cause

Browse files
Files changed (1) hide show
  1. app/main.py +18 -6
app/main.py CHANGED
@@ -82,10 +82,15 @@ def _validate_url(url: str) -> str:
82
  try:
83
  for info in socket.getaddrinfo(hostname, None):
84
  addr = info[4][0]
85
- ip = ipaddress.ip_address(addr)
86
- if ip.is_private or ip.is_loopback or ip.is_link_local:
87
- raise HTTPException(status_code=400, detail="URL resolves to a private address")
88
- except socket.gaierror as exc:
 
 
 
 
 
89
  raise HTTPException(status_code=400, detail="Cannot resolve hostname") from exc
90
  return url
91
 
@@ -93,6 +98,9 @@ def _validate_url(url: str) -> str:
93
  @app.post("/analyze", response_model=PredictResponse, dependencies=[Depends(_require_api_key)])
94
  def analyze(body: AnalyzeRequest) -> PredictResponse:
95
  """Accept a public image URL, download it, and run inference."""
 
 
 
96
  _validate_url(body.image_url)
97
  try:
98
  resp = http_requests.get(body.image_url, timeout=30)
@@ -105,5 +113,9 @@ def analyze(body: AnalyzeRequest) -> PredictResponse:
105
  except UnidentifiedImageError as exc:
106
  raise HTTPException(status_code=400, detail="URL did not return a valid image") from exc
107
 
108
- result = model.predict(image)
109
- return PredictResponse(**result)
 
 
 
 
 
82
  try:
83
  for info in socket.getaddrinfo(hostname, None):
84
  addr = info[4][0]
85
+ try:
86
+ ip = ipaddress.ip_address(addr)
87
+ if ip.is_private or ip.is_loopback or ip.is_link_local:
88
+ raise HTTPException(status_code=400, detail="URL resolves to a private address")
89
+ except ValueError:
90
+ pass # skip unparseable addresses (e.g. scoped IPv6)
91
+ except HTTPException:
92
+ raise
93
+ except OSError as exc:
94
  raise HTTPException(status_code=400, detail="Cannot resolve hostname") from exc
95
  return url
96
 
 
98
  @app.post("/analyze", response_model=PredictResponse, dependencies=[Depends(_require_api_key)])
99
  def analyze(body: AnalyzeRequest) -> PredictResponse:
100
  """Accept a public image URL, download it, and run inference."""
101
+ import logging
102
+ logger = logging.getLogger(__name__)
103
+
104
  _validate_url(body.image_url)
105
  try:
106
  resp = http_requests.get(body.image_url, timeout=30)
 
113
  except UnidentifiedImageError as exc:
114
  raise HTTPException(status_code=400, detail="URL did not return a valid image") from exc
115
 
116
+ try:
117
+ result = model.predict(image)
118
+ return PredictResponse(**result)
119
+ except Exception as exc:
120
+ logger.exception("Model inference failed")
121
+ raise HTTPException(status_code=500, detail=f"Model inference error: {type(exc).__name__}: {exc}") from exc