rjsmoke00 commited on
Commit
bbfee05
Β·
verified Β·
1 Parent(s): 502ade4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -3
app.py CHANGED
@@ -7,7 +7,7 @@ import torch
7
  import uvicorn
8
  from fastapi import FastAPI, UploadFile, File
9
  from fastapi.responses import JSONResponse
10
- from transformers import AutoModelForImageClassification, AutoImageProcessor
11
  from PIL import Image
12
 
13
  # ================== Config ==================
@@ -25,24 +25,38 @@ else:
25
 
26
  # ================== Load Model ==================
27
  print("πŸ“¦ Loading model...")
 
 
 
 
 
28
  model = AutoModelForImageClassification.from_pretrained(
29
  str(MODEL_DIR),
 
30
  local_files_only=True
31
  )
 
 
32
  processor = AutoImageProcessor.from_pretrained(
33
  str(MODEL_DIR),
34
  local_files_only=True
35
  )
 
36
  model.eval()
37
  print("βœ… Model loaded successfully")
38
 
39
  # ================== FastAPI App ==================
40
- app = FastAPI()
41
 
42
  @app.get("/")
43
  def home():
44
  return {"message": "πŸš€ Cattle/Buffalo Model API is running!"}
45
 
 
 
 
 
 
46
  @app.post("/predict")
47
  async def predict(file: UploadFile = File(...)):
48
  try:
@@ -52,12 +66,13 @@ async def predict(file: UploadFile = File(...)):
52
  # Preprocess
53
  inputs = processor(images=image, return_tensors="pt")
54
 
 
55
  with torch.no_grad():
56
  outputs = model(**inputs)
57
  logits = outputs.logits
58
  predicted_class_idx = logits.argmax(-1).item()
59
 
60
- # Map to labels
61
  label = model.config.id2label[predicted_class_idx]
62
 
63
  return JSONResponse(content={"prediction": label})
 
7
  import uvicorn
8
  from fastapi import FastAPI, UploadFile, File
9
  from fastapi.responses import JSONResponse
10
+ from transformers import AutoConfig, AutoModelForImageClassification, AutoImageProcessor
11
  from PIL import Image
12
 
13
  # ================== Config ==================
 
25
 
26
  # ================== Load Model ==================
27
  print("πŸ“¦ Loading model...")
28
+
29
+ # Load config first to avoid HF hub errors
30
+ config = AutoConfig.from_pretrained(str(MODEL_DIR), local_files_only=True)
31
+
32
+ # Load model with local config
33
  model = AutoModelForImageClassification.from_pretrained(
34
  str(MODEL_DIR),
35
+ config=config,
36
  local_files_only=True
37
  )
38
+
39
+ # Load image processor
40
  processor = AutoImageProcessor.from_pretrained(
41
  str(MODEL_DIR),
42
  local_files_only=True
43
  )
44
+
45
  model.eval()
46
  print("βœ… Model loaded successfully")
47
 
48
  # ================== FastAPI App ==================
49
+ app = FastAPI(title="Cattle/Buffalo Classifier API")
50
 
51
  @app.get("/")
52
  def home():
53
  return {"message": "πŸš€ Cattle/Buffalo Model API is running!"}
54
 
55
+ @app.get("/labels")
56
+ def get_labels():
57
+ """Return all possible class labels."""
58
+ return {"labels": list(model.config.id2label.values())}
59
+
60
  @app.post("/predict")
61
  async def predict(file: UploadFile = File(...)):
62
  try:
 
66
  # Preprocess
67
  inputs = processor(images=image, return_tensors="pt")
68
 
69
+ # Forward pass
70
  with torch.no_grad():
71
  outputs = model(**inputs)
72
  logits = outputs.logits
73
  predicted_class_idx = logits.argmax(-1).item()
74
 
75
+ # Map to label
76
  label = model.config.id2label[predicted_class_idx]
77
 
78
  return JSONResponse(content={"prediction": label})