shimaa22 commited on
Commit
131d395
·
verified ·
1 Parent(s): f487755

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -17
app.py CHANGED
@@ -1,32 +1,33 @@
1
  import gradio as gr
2
  from PIL import Image
3
- from ultralytics import YOLO
4
  import torch
5
 
6
- # ------------------------
7
- # تحميل الموديل التصنيفي
8
- # ------------------------
9
  MODEL_PATH = "best.pt"
10
- model = YOLO(MODEL_PATH) # موديل تصنيف من Ultralytics
 
 
11
  CLASS_NAMES = ["ax", "co", "sa"]
12
 
13
  # ------------------------
14
  # prediction
15
  # ------------------------
16
  def predict_orientation(image: Image.Image):
17
- # نستخدم predict على صورة PIL
18
- results = model.predict(source=image, imgsz=224, device="cpu", verbose=False)
19
-
20
- # Ultralytics ClassificationModel بيرجع probs لكل class
21
- # التأكد من وجود probs
22
- if hasattr(results[0], "probs"):
23
- probs = results[0].probs # tensor
24
- pred = torch.argmax(probs)
 
 
 
25
  orientation = CLASS_NAMES[pred.item()]
26
- confidence = round(probs[pred].item(), 2)
27
- return f"Orientation: {orientation} | Confidence: {confidence}"
28
- else:
29
- return "Error: Could not get probabilities from model."
30
 
31
  # ------------------------
32
  # Gradio Interface
 
1
  import gradio as gr
2
  from PIL import Image
3
+ from ultralytics.nn.tasks import ClassificationModel
4
  import torch
5
 
 
 
 
6
  MODEL_PATH = "best.pt"
7
+ model = ClassificationModel(MODEL_PATH)
8
+ model.eval()
9
+
10
  CLASS_NAMES = ["ax", "co", "sa"]
11
 
12
  # ------------------------
13
  # prediction
14
  # ------------------------
15
  def predict_orientation(image: Image.Image):
16
+ import torchvision.transforms as transforms
17
+ transform = transforms.Compose([
18
+ transforms.Resize((224,224)),
19
+ transforms.ToTensor(),
20
+ ])
21
+ img_tensor = transform(image).unsqueeze(0) # batch dimension
22
+
23
+ with torch.no_grad():
24
+ outputs = model(img_tensor)
25
+ probs = torch.softmax(outputs, dim=1)
26
+ conf, pred = torch.max(probs, 1)
27
  orientation = CLASS_NAMES[pred.item()]
28
+ confidence = round(conf.item(), 2)
29
+
30
+ return f"Orientation: {orientation} | Confidence: {confidence}"
 
31
 
32
  # ------------------------
33
  # Gradio Interface