Alic22 committed on
Commit
fa32eee
·
verified ·
1 Parent(s): cf800f2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +122 -104
app.py CHANGED
@@ -1,107 +1,125 @@
 
1
  import torch
2
- import numpy as np
3
  from torchvision import transforms
4
- from datasets import load_dataset
5
- from transformers import (
6
- SegformerForSemanticSegmentation,
7
- SegformerFeatureExtractor,
8
- Trainer,
9
- TrainingArguments
10
- )
11
  import evaluate
12
-
13
- # ------------------------------
14
- # 1️⃣ Parameter
15
- # ------------------------------
16
- DATA_DIR = "path_to_dataset"
17
- NUM_CLASSES = 3 # z.B. 3 Klassen: Hintergrund, Schaden, Rand
18
- IMAGE_SIZE = 256 # Bildgröße für Training
19
-
20
- # ------------------------------
21
- # 2️⃣ Dataset laden
22
- # ------------------------------
23
- # Annahme: Dataset im ImageFolder Format mit Unterordnern 'train' und 'validation'
24
- dataset = load_dataset("imagefolder", data_dir=DATA_DIR)
25
-
26
- # Transformationen für Bilder
27
- train_transforms = transforms.Compose([
28
- transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
29
- transforms.ToTensor(),
30
- ])
31
-
32
- mask_transforms = transforms.Compose([
33
- transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
34
- transforms.PILToTensor(), # Masken als Tensor
35
- ])
36
-
37
- # Preprocessing-Funktion
38
- def preprocess(example):
39
- example["pixel_values"] = train_transforms(example["image"])
40
- # Masken als LongTensor für CrossEntropyLoss
41
- example["labeks"] = mask_transforms(example["label"]).long().squeeze(0)
42
- return example
43
-
44
- dataset = dataset.map(preprocess)
45
-
46
- # ------------------------------
47
- # 3️⃣ Feature Extractor & Modell
48
- # ------------------------------
49
- feature_extractor = SegformerFeatureExtractor.from_pretrained("nvidia/mit-b1")
50
-
51
- model = SegformerForSemanticSegmentation.from_pretrained(
52
- "nvidia/mit-b1",
53
- num_labels=NUM_CLASSES,
54
- )
55
-
56
- # ------------------------------
57
- # 4️⃣ Metrics
58
- # ------------------------------
59
- metric = evaluate.load("mean_iou")
60
-
61
- def compute_metrics(p):
62
- preds = np.argmax(p.predictions, axis=1)
63
- return metric.compute(predictions=preds, references=p.label_ids, num_labels=NUM_CLASSES)
64
-
65
- # ------------------------------
66
- # 5️⃣ TrainingArguments
67
- # ------------------------------
68
- training_args = TrainingArguments(
69
- output_dir="./results",
70
- per_device_train_batch_size=4,
71
- per_device_eval_batch_size=4,
72
- num_train_epochs=10,
73
- learning_rate=5e-5,
74
- evaluation_strategy="steps",
75
- save_strategy="steps",
76
- save_steps=200,
77
- eval_steps=200,
78
- logging_steps=50,
79
- fp16=True, # Mixed Precision, falls GPU verfügbar
80
- remove_unused_columns=False, # wichtig für Segmentation
81
- )
82
-
83
- # ------------------------------
84
- # 6️⃣ Trainer
85
- # ------------------------------
86
- trainer = Trainer(
87
- model=model,
88
- args=training_args,
89
- train_dataset=dataset["train"],
90
- eval_dataset=dataset["validation"],
91
- compute_metrics=compute_metrics,
92
- )
93
-
94
- # ------------------------------
95
- # 7️⃣ Training starten
96
- # ------------------------------
97
- trainer.train()
98
-
99
- # ------------------------------
100
- # 8️⃣ Modell speichern
101
- # ------------------------------
102
- trainer.save_model("my_segformer_model")
103
- feature_extractor.save_pretrained("my_segformer_model")
104
-
105
- print("✅ Training abgeschlossen und Modell gespeichert!")
106
-
107
- n
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
  import torch
3
+ from PIL import Image
4
  from torchvision import transforms
5
+ import numpy as np
6
+ from matplotlib import pyplot as plt
7
+ from torch.utils.data import Dataset, DataLoader
 
 
 
 
8
  import evaluate
9
+ from torch import nn
10
+ from transformers import SegformerForSemanticSegmentation
11
+ import sys
12
+ import io
13
+
14
+
15
###################
# Setup label names
# The position of a name in `target_list` is its class id: both lookup tables
# below are derived from the list order, and `inference` pairs these names
# with the model's output channels by position.
# NOTE(review): the order presumably has to match the channel order of the
# trained checkpoint -- confirm before reordering or inserting entries.
target_list = [
    'Crack', 'ACrack', 'Wetspot', 'Efflorescence', 'Rust', 'Rockpocket',
    'Hollowareas', 'Cavity', 'Spalling', 'Graffiti', 'Weathering',
    'Restformwork', 'ExposedRebars', 'Bearing', 'EJoint', 'Drainage',
    'PEquipment', 'JTape', 'WConccor',
]
classes, nclasses = target_list, len(target_list)
# name -> id and id -> name, both in list order.
label2id = {label: idx for idx, label in enumerate(classes)}
id2label = {idx: label for idx, label in enumerate(classes)}
23
+
24
############
# Load model
# All tensors and weights are kept on the CPU.
device = torch.device("cpu")

# Build the SegFormer network from the "nvidia/mit-b1" checkpoint, handing
# both label lookup tables to `from_pretrained`.
segformer = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/mit-b1",
    label2id=label2id,
    id2label=id2label,
)
30
+
31
+ # SegModel
32
class SegModel(nn.Module):
    """Wrap a segmentation network so that ``forward`` returns upsampled logits.

    Args:
        segformer: Callable/module whose result exposes a ``logits`` tensor.
        scale_factor: Nearest-neighbour upsampling factor applied to the
            logits. Defaults to 4, the value this wrapper has always used.
            NOTE(review): presumably chosen because SegFormer emits logits at
            a quarter of the input resolution -- confirm against the
            Transformers documentation.
    """

    def __init__(self, segformer, scale_factor=4):
        super().__init__()
        self.segformer = segformer
        self.upsample = nn.Upsample(scale_factor=scale_factor, mode='nearest')

    def forward(self, x):
        """Return ``self.segformer(x).logits`` after upsampling."""
        return self.upsample(self.segformer(x).logits)
40
+
41
# Wrap the architecture. This instance is replaced a few lines below by the
# object stored in the checkpoint file.
model = SegModel(segformer)
path = "runs/2023-08-31_rich-paper-12/best_model_cpu.pth"
print(f"Load Segformer weights from {path}")
# The checkpoint is used directly as the model (its `.eval()` is called next),
# i.e. it holds a whole pickled module rather than a state dict, so
# `load_state_dict` is not applicable here.
# `map_location=device` pins every stored tensor to the configured device
# regardless of where the checkpoint was saved.
# NOTE: torch.load unpickles arbitrary Python objects -- only load checkpoint
# files from a trusted source.
# NOTE(review): recent PyTorch releases may require `weights_only=False` to
# load a full pickled module -- confirm against the pinned torch version.
model = torch.load(path, map_location=device)
# Switch to inference mode.
model.eval()
47
+
48
+ ##################
49
+ # Image preprocess
50
+ ##################
51
+
52
# Individual preprocessing steps, kept as module-level callables.
to_tensor = transforms.ToTensor()
resize = transforms.Resize((512, 512))
# NOTE(review): these look like the usual ImageNet channel statistics --
# confirm they match what the checkpoint was trained with.
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)


def process_pil(img):
    """Turn an input image into a normalized 512x512 tensor.

    Applies, in this order, ``to_tensor``, ``resize`` and ``normalize`` and
    returns the result of the last step.
    """
    for step in (to_tensor, resize, normalize):
        img = step(img)
    return img
62
+
63
+ ###########
64
+ # Inference
65
+
66
def inference(img, name=None):
    """Segment one image and render all predicted masks as a single picture.

    Args:
        img: Input image; it is passed straight to ``process_pil``.
        name: Never read by this function. It is optional so that the function
            can be called with the image alone, which is what an interface
            with a single input component does.

    Returns:
        A PIL image of a 4-column grid of panels. The first panel ("ALL")
        shows, per pixel, how many classes exceeded the threshold; the
        remaining panels show one binary mask per entry of ``target_list``.
    """
    # The model expects a batch, so add an extra dimension at position 0.
    batch = process_pil(img).unsqueeze(0)

    # Pure inference: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        logits = model(batch)[0]
        # Independent per-class probabilities (logits -> probs).
        mask_probs = torch.sigmoid(logits).numpy()

    # Make binary masks.
    THRESHOLD = 0.5
    mask_preds = mask_probs > THRESHOLD

    # Per-pixel count of active classes, stacked in front of the class masks.
    mask_all = np.expand_dims(mask_preds.sum(axis=0), axis=0)
    mask_preds = np.concatenate((mask_all, mask_preds), axis=0)
    labs = ["ALL"] + target_list

    # Enough rows for every label (5 rows for the 20 current panels).
    n_cols = 4
    n_rows = -(-len(labs) // n_cols)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(10, 10))

    # zip() pairs panels, labels and channels by position and stops at the
    # shortest, so a channel/label count mismatch cannot raise IndexError.
    for ax, label, channel in zip(axes.flat, labs, mask_preds):
        ax.imshow(channel)
        ax.set_title(label)

    fig.tight_layout()

    # Figure -> PNG bytes -> PIL image.
    img_buf = io.BytesIO()
    fig.savefig(img_buf, format='png')
    # pyplot keeps every figure alive until it is closed; without this each
    # call would leave one more figure in memory.
    plt.close(fig)
    img_buf.seek(0)
    return Image.open(img_buf)
104
+
105
+
106
+
107
# Texts shown in the web UI.
title = "Masterarbeit"
description = """


"""

article = (
    "<p style='text-align: center'>"
    "<a href='https://github.com/phiyodr/dacl10k-toolkit' target='_blank'>Github Repo</a>"
    "</p>"
)

# One row per bundled sample: [path below assets/, bare file name].
# NOTE(review): each row carries two values while the interface below declares
# a single input component -- confirm how the installed Gradio version treats
# the extra column.
examples = [
    [f"assets/{file_name}", file_name]
    for file_name in (
        "dacl10k_v2_validation_0037.jpg",
        "dacl10k_v2_validation_0068.jpg",
        "dacl10k_v2_validation_0053.jpg",
    )
]

# NOTE(review): the `gr.inputs` / `gr.outputs` namespaces belong to the legacy
# Gradio API -- confirm they exist in the pinned Gradio version.
demo = gr.Interface(
    fn=inference,
    inputs=gr.inputs.Image(type="pil"),
    outputs=gr.outputs.Image(type="pil"),
    title=title,
    description=description,
    article=article,
    examples=examples,
)

# Start the web server for the demo.
demo.launch()