Mohansai2004 commited on
Commit
6a82842
·
verified ·
1 Parent(s): fee3c3b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -16
app.py CHANGED
@@ -2,42 +2,35 @@ from fastapi import FastAPI, UploadFile, File
2
  from fastapi.responses import JSONResponse
3
  from PIL import Image
4
  import io
5
-
6
- from transformers import AutoProcessor, AutoModelForImageClassification
7
  import torch
 
 
 
 
 
 
8
 
9
- # Load model and processor (only once at startup)
10
  MODEL_NAME = "google/vit-base-patch16-224"
11
- processor = AutoProcessor.from_pretrained(MODEL_NAME)
12
  model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)
13
 
14
  app = FastAPI()
15
 
16
- @app.get("/")
17
- def root():
18
- return {"message": "Send an image to POST /analyze-image"}
19
-
20
  @app.post("/analyze-image")
21
  async def analyze_image(file: UploadFile = File(...)):
22
  try:
23
- # Read and convert the uploaded image
24
  contents = await file.read()
25
  image = Image.open(io.BytesIO(contents)).convert("RGB")
26
 
27
- # Preprocess
28
- inputs = processor(images=image, return_tensors="pt")
29
  with torch.no_grad():
30
  outputs = model(**inputs)
31
 
32
- # Get top prediction
33
  logits = outputs.logits
34
  predicted_class_id = logits.argmax(-1).item()
35
  label = model.config.id2label[predicted_class_id]
36
 
37
- return JSONResponse(content={
38
- "filename": file.filename,
39
- "predicted_label": label
40
- })
41
 
42
  except Exception as e:
43
  return JSONResponse(content={"error": str(e)}, status_code=400)
 
2
  from fastapi.responses import JSONResponse
3
  from PIL import Image
4
  import io
 
 
5
  import torch
6
+ import os
7
+
8
+ from transformers import AutoFeatureExtractor, AutoModelForImageClassification
9
+
10
+ # Fix permissions by setting cache location
11
+ os.environ["HF_HOME"] = "/app/hf_cache"
12
 
 
13
  MODEL_NAME = "google/vit-base-patch16-224"
14
+ extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
15
  model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)
16
 
17
  app = FastAPI()
18
 
 
 
 
 
19
  @app.post("/analyze-image")
20
  async def analyze_image(file: UploadFile = File(...)):
21
  try:
 
22
  contents = await file.read()
23
  image = Image.open(io.BytesIO(contents)).convert("RGB")
24
 
25
+ inputs = extractor(images=image, return_tensors="pt")
 
26
  with torch.no_grad():
27
  outputs = model(**inputs)
28
 
 
29
  logits = outputs.logits
30
  predicted_class_id = logits.argmax(-1).item()
31
  label = model.config.id2label[predicted_class_id]
32
 
33
+ return JSONResponse(content={"filename": file.filename, "predicted_label": label})
 
 
 
34
 
35
  except Exception as e:
36
  return JSONResponse(content={"error": str(e)}, status_code=400)