Habeeb Okunade commited on
Commit
cb24c7c
·
1 Parent(s): 472db94

Update Training script

Browse files
Files changed (1) hide show
  1. app.py +18 -6
app.py CHANGED
@@ -16,21 +16,33 @@ model = None
16
 
17
  def load_model():
18
  global processor, model, CLASSES
19
- processor = AutoImageProcessor.from_pretrained(MODEL_DIR)
20
- model = BeitForImageClassification.from_pretrained(MODEL_DIR)
21
- with open(os.path.join(MODEL_DIR, "labels.json")) as f:
22
- CLASSES = json.load(f)
 
 
 
 
 
 
 
23
 
24
  @app.on_event("startup")
25
  def startup_event():
26
  if os.path.exists(MODEL_DIR):
27
  load_model()
 
 
28
 
29
  @app.post("/predict")
30
  async def predict(file: UploadFile):
31
  if model is None:
32
  return {"error": "Model not trained yet"}
33
- img = Image.open(file.file).convert("RGB")
 
 
 
34
  inputs = processor(images=img, return_tensors="pt")
35
  with torch.no_grad():
36
  logits = model(**inputs).logits
@@ -38,7 +50,7 @@ async def predict(file: UploadFile):
38
  pred_id = int(torch.argmax(logits, dim=1).item())
39
  return {
40
  "class_id": CLASSES[pred_id],
41
- "probabilities": [{CLASSES[i]: float(p) for i, p in enumerate(probs)}]
42
  }
43
 
44
  @app.post("/train")
 
16
 
17
  def load_model():
18
  global processor, model, CLASSES
19
+ try:
20
+ processor = AutoImageProcessor.from_pretrained(MODEL_DIR)
21
+ model = BeitForImageClassification.from_pretrained(MODEL_DIR)
22
+ labels_path = os.path.join(MODEL_DIR, "labels.json")
23
+ if os.path.exists(labels_path):
24
+ with open(labels_path) as f:
25
+ CLASSES = json.load(f)
26
+ print("✅ Model and processor loaded successfully")
27
+ except Exception as e:
28
+ processor, model = None, None
29
+ print(f"⚠️ Skipping model load: {e}")
30
 
31
  @app.on_event("startup")
32
  def startup_event():
33
  if os.path.exists(MODEL_DIR):
34
  load_model()
35
+ else:
36
+ print("⚠️ MODEL_DIR not found, skipping model load")
37
 
38
  @app.post("/predict")
39
  async def predict(file: UploadFile):
40
  if model is None:
41
  return {"error": "Model not trained yet"}
42
+ try:
43
+ img = Image.open(file.file).convert("RGB")
44
+ except Exception as e:
45
+ return {"error": f"Invalid image: {str(e)}"}
46
  inputs = processor(images=img, return_tensors="pt")
47
  with torch.no_grad():
48
  logits = model(**inputs).logits
 
50
  pred_id = int(torch.argmax(logits, dim=1).item())
51
  return {
52
  "class_id": CLASSES[pred_id],
53
+ "probabilities": {CLASSES[i]: float(p) for i, p in enumerate(probs)}
54
  }
55
 
56
  @app.post("/train")