Mohansai2004 commited on
Commit
fee3c3b
·
verified ·
1 Parent(s): 3c7e039

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -9
app.py CHANGED
@@ -3,24 +3,41 @@ from fastapi.responses import JSONResponse
3
  from PIL import Image
4
  import io
5
 
 
 
 
 
 
 
 
 
6
  app = FastAPI()
7
 
8
  @app.get("/")
9
- def read_root():
10
- return {"message": "Space is running. Use POST /process-image"}
11
 
12
- @app.post("/process-image")
13
- async def process_image(file: UploadFile = File(...)):
14
  try:
 
15
  contents = await file.read()
16
- image = Image.open(io.BytesIO(contents))
17
- width, height = image.size
 
 
 
 
 
 
 
 
 
18
 
19
  return JSONResponse(content={
20
  "filename": file.filename,
21
- "width": width,
22
- "height": height,
23
- "format": image.format
24
  })
 
25
  except Exception as e:
26
  return JSONResponse(content={"error": str(e)}, status_code=400)
 
3
  from PIL import Image
4
  import io
5
 
6
+ from transformers import AutoProcessor, AutoModelForImageClassification
7
+ import torch
8
+
9
+ # Load model and processor (only once at startup)
10
+ MODEL_NAME = "google/vit-base-patch16-224"
11
+ processor = AutoProcessor.from_pretrained(MODEL_NAME)
12
+ model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)
13
+
14
  app = FastAPI()
15
 
16
  @app.get("/")
17
+ def root():
18
+ return {"message": "Send an image to POST /analyze-image"}
19
 
20
+ @app.post("/analyze-image")
21
+ async def analyze_image(file: UploadFile = File(...)):
22
  try:
23
+ # Read and convert the uploaded image
24
  contents = await file.read()
25
+ image = Image.open(io.BytesIO(contents)).convert("RGB")
26
+
27
+ # Preprocess
28
+ inputs = processor(images=image, return_tensors="pt")
29
+ with torch.no_grad():
30
+ outputs = model(**inputs)
31
+
32
+ # Get top prediction
33
+ logits = outputs.logits
34
+ predicted_class_id = logits.argmax(-1).item()
35
+ label = model.config.id2label[predicted_class_id]
36
 
37
  return JSONResponse(content={
38
  "filename": file.filename,
39
+ "predicted_label": label
 
 
40
  })
41
+
42
  except Exception as e:
43
  return JSONResponse(content={"error": str(e)}, status_code=400)