jacopo22295 commited on
Commit
03adcde
·
verified ·
1 Parent(s): a071556

Upload 5 files

Browse files
Files changed (5) hide show
  1. .gitattributes +1 -35
  2. README.md +7 -11
  3. app.py +97 -0
  4. classes.json +11 -0
  5. requirements.txt +6 -0
.gitattributes CHANGED
@@ -1,35 +1 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ *.pth filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
README.md CHANGED
@@ -1,12 +1,8 @@
1
- ---
2
- title: Corrobotv2
3
- emoji: 🐢
4
- colorFrom: green
5
- colorTo: yellow
6
- sdk: gradio
7
- sdk_version: 5.44.1
8
- app_file: app.py
9
- pinned: false
10
- ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
1
+ # Corrosion Classifier (ViT-B/16 · CPU)
2
+ Space Gradio pronto per Zero GPU. Carica `vit_b16_best.pth` nella root.
 
 
 
 
 
 
 
 
3
 
4
+ ## Uso
5
+ 1. Crea uno Space Gradio su Hugging Face.
6
+ 2. Carica: app.py, requirements.txt, classes.json, vit_b16_best.pth.
7
+ 3. Runtime: CPU.
8
+ 4. Avvia.
app.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import json
4
+ import torch
5
+ import torchvision.transforms as T
6
+ from PIL import Image
7
+ import numpy as np
8
+ import gradio as gr
9
+
10
+ MODEL_WEIGHTS = os.getenv("MODEL_WEIGHTS_PATH", "vit_b16_best.pth")
11
+ CLASSES_PATH = os.getenv("CLASSES_PATH", "classes.json")
12
+ IMAGE_SIZE = 224
13
+
14
+ def load_classes(path: str):
15
+ with open(path, "r", encoding="utf-8") as f:
16
+ return json.load(f)
17
+
18
+ def build_transforms(img_size=IMAGE_SIZE):
19
+ return T.Compose([
20
+ T.Resize((img_size, img_size), interpolation=T.InterpolationMode.BICUBIC),
21
+ T.ToTensor(),
22
+ T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
23
+ ])
24
+
25
+ def create_model(num_classes: int):
26
+ import timm
27
+ model = timm.create_model("vit_base_patch16_224", pretrained=False, num_classes=num_classes)
28
+ return model
29
+
30
+ def load_model(weights_path: str, classes):
31
+ device = torch.device("cpu")
32
+ model = create_model(num_classes=len(classes))
33
+ state = torch.load(weights_path, map_location=device)
34
+ if isinstance(state, dict) and "state_dict" in state:
35
+ state = state["state_dict"]
36
+ cleaned = {k.replace("module.", "").replace("model.", ""): v for k, v in state.items()}
37
+ model.load_state_dict(cleaned, strict=False)
38
+ model.eval()
39
+ return model
40
+
41
+ _model = None
42
+ _classes = None
43
+ _tfm = None
44
+
45
+ def setup():
46
+ global _model, _classes, _tfm
47
+ if _classes is None:
48
+ _classes = load_classes(CLASSES_PATH)
49
+ if _tfm is None:
50
+ _tfm = build_transforms()
51
+ if _model is None:
52
+ if not os.path.exists(MODEL_WEIGHTS):
53
+ raise FileNotFoundError(f"File pesi non trovato: {MODEL_WEIGHTS}")
54
+ _model = load_model(MODEL_WEIGHTS, _classes)
55
+
56
+ @torch.inference_mode()
57
+ def predict(image: Image.Image):
58
+ setup()
59
+ if image.mode != "RGB":
60
+ image = image.convert("RGB")
61
+ x = _tfm(image).unsqueeze(0)
62
+ logits = _model(x)
63
+ probs = torch.softmax(logits, dim=1).cpu().numpy().squeeze(0)
64
+ top_idx = np.argsort(-probs)[:3]
65
+ top_labels = [_classes[i] for i in top_idx]
66
+ top_scores = [float(probs[i]) for i in top_idx]
67
+ pred_label = top_labels[0]
68
+ pred_conf = top_scores[0]
69
+ result = {
70
+ "prediction": pred_label,
71
+ "confidence": round(pred_conf * 100, 2),
72
+ "top3": [
73
+ {"label": top_labels[j], "confidence": round(top_scores[j] * 100, 2)}
74
+ for j in range(3)
75
+ ]
76
+ }
77
+ human = f"Tipo: {pred_label} — Affidabilità: {result['confidence']}%"
78
+ return human, result
79
+
80
+ title = "Corrosion Classifier (ViT-B/16 • CPU)"
81
+ description = "Carica o scatta una foto del pezzo corroso. Predizione su CPU."
82
+
83
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
84
+ gr.Markdown(f"# {title}")
85
+ gr.Markdown(description)
86
+ with gr.Row():
87
+ with gr.Column():
88
+ inp = gr.Image(label="Immagine", type="pil", sources=["upload", "camera"], image_mode="RGB")
89
+ analyze_btn = gr.Button("Analizza immagine", variant="primary")
90
+ with gr.Column():
91
+ out_text = gr.Textbox(label="Risultato", interactive=False)
92
+ out_json = gr.JSON(label="Dettagli (top-3)")
93
+ analyze_btn.click(fn=predict, inputs=[inp], outputs=[out_text, out_json])
94
+ gr.Examples(examples=[], inputs=[inp])
95
+
96
+ if __name__ == "__main__":
97
+ demo.launch(server_name="0.0.0.0", server_port=7860)
classes.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ "crevice_corrosion",
3
+ "erosion_corrosion",
4
+ "galvanic_corrosion",
5
+ "mic_corrosion",
6
+ "no_corrosion",
7
+ "pitting_corrosion",
8
+ "stress_corrosion",
9
+ "under_insulation_corrosion",
10
+ "uniform_corrosion"
11
+ ]
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio>=4.44.0
2
+ torch>=2.3.0
3
+ torchvision>=0.18.0
4
+ timm>=0.9.12
5
+ pillow>=10.3.0
6
+ numpy>=1.26.4