rjsmoke00 commited on
Commit
99e2cbf
Β·
verified Β·
1 Parent(s): bbfee05

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -85
app.py CHANGED
@@ -1,85 +1 @@
1
- import os
2
- import zipfile
3
- from pathlib import Path
4
- from io import BytesIO
5
-
6
- import torch
7
- import uvicorn
8
- from fastapi import FastAPI, UploadFile, File
9
- from fastapi.responses import JSONResponse
10
- from transformers import AutoConfig, AutoModelForImageClassification, AutoImageProcessor
11
- from PIL import Image
12
-
13
- # ================== Config ==================
14
- MODEL_ZIP = "cattle_model.zip"
15
- MODEL_DIR = Path(__file__).parent / "cattle_model"
16
-
17
- # ================== Unzip if needed ==================
18
- if not MODEL_DIR.exists():
19
- print(f"πŸ“¦ Extracting {MODEL_ZIP}...")
20
- with zipfile.ZipFile(MODEL_ZIP, 'r') as zip_ref:
21
- zip_ref.extractall(MODEL_DIR.parent)
22
- print("βœ… Extraction complete")
23
- else:
24
- print("πŸ“‚ Model folder already exists, skipping extraction")
25
-
26
- # ================== Load Model ==================
27
- print("πŸ“¦ Loading model...")
28
-
29
- # Load config first to avoid HF hub errors
30
- config = AutoConfig.from_pretrained(str(MODEL_DIR), local_files_only=True)
31
-
32
- # Load model with local config
33
- model = AutoModelForImageClassification.from_pretrained(
34
- str(MODEL_DIR),
35
- config=config,
36
- local_files_only=True
37
- )
38
-
39
- # Load image processor
40
- processor = AutoImageProcessor.from_pretrained(
41
- str(MODEL_DIR),
42
- local_files_only=True
43
- )
44
-
45
- model.eval()
46
- print("βœ… Model loaded successfully")
47
-
48
- # ================== FastAPI App ==================
49
- app = FastAPI(title="Cattle/Buffalo Classifier API")
50
-
51
- @app.get("/")
52
- def home():
53
- return {"message": "πŸš€ Cattle/Buffalo Model API is running!"}
54
-
55
- @app.get("/labels")
56
- def get_labels():
57
- """Return all possible class labels."""
58
- return {"labels": list(model.config.id2label.values())}
59
-
60
- @app.post("/predict")
61
- async def predict(file: UploadFile = File(...)):
62
- try:
63
- # Read uploaded image
64
- image = Image.open(BytesIO(await file.read())).convert("RGB")
65
-
66
- # Preprocess
67
- inputs = processor(images=image, return_tensors="pt")
68
-
69
- # Forward pass
70
- with torch.no_grad():
71
- outputs = model(**inputs)
72
- logits = outputs.logits
73
- predicted_class_idx = logits.argmax(-1).item()
74
-
75
- # Map to label
76
- label = model.config.id2label[predicted_class_idx]
77
-
78
- return JSONResponse(content={"prediction": label})
79
-
80
- except Exception as e:
81
- return JSONResponse(content={"error": str(e)}, status_code=500)
82
-
83
- # ================== Run Locally ==================
84
- if __name__ == "__main__":
85
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
+ ls /home/user/app/cattle_model