Asimuddin11 commited on
Commit
399b72d
·
verified ·
1 Parent(s): f6766d8

UPDATE app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -14
app.py CHANGED
@@ -1,26 +1,76 @@
1
-
2
  import streamlit as st
3
  from transformers import pipeline
4
  from PIL import Image
 
 
 
 
 
 
 
 
5
 
6
- st.set_page_config(page_title="ViT Image Classifier")
7
- st.title("ViT Image Classification")
8
 
9
  @st.cache_resource
10
  def load_model():
11
- return pipeline("image-classification", model="google/vit-base-patch16-224")
 
 
 
 
 
 
 
 
 
12
 
13
- pipe = load_model()
 
14
 
15
- uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
 
 
 
 
 
16
 
17
  if uploaded_file is not None:
18
- image = Image.open(uploaded_file).convert("RGB")
19
- st.image(image, caption="Uploaded Image", use_column_width=True)
20
-
21
- with st.spinner("Classifying..."):
22
- preds = pipe(image)
 
 
 
 
 
 
23
 
24
- st.subheader("Predictions")
25
- for i, pred in enumerate(preds):
26
- st.write(f"{i+1}. {pred['label']} ({pred['score']:.3f})")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  from transformers import pipeline
3
  from PIL import Image
4
+ import os
5
+
6
+ # Configure Streamlit for Hugging Face Spaces
7
+ st.set_page_config(
8
+ page_title="ViT Image Classifier",
9
+ page_icon="🖼️",
10
+ layout="centered"
11
+ )
12
 
13
+ st.title("🖼️ ViT Image Classification")
14
+ st.markdown("Upload an image to classify it using Google's Vision Transformer model.")
15
 
16
  @st.cache_resource
17
  def load_model():
18
+ """Load the ViT model with error handling."""
19
+ try:
20
+ return pipeline("image-classification", model="google/vit-base-patch16-224")
21
+ except Exception as e:
22
+ st.error(f"Error loading model: {str(e)}")
23
+ return None
24
+
25
+ # Initialize the model
26
+ with st.spinner("Loading model..."):
27
+ pipe = load_model()
28
 
29
+ if pipe is None:
30
+ st.stop()
31
 
32
+ # File uploader
33
+ uploaded_file = st.file_uploader(
34
+ "Choose an image file",
35
+ type=["jpg", "jpeg", "png", "bmp", "tiff"],
36
+ help="Upload an image in JPG, PNG, BMP, or TIFF format"
37
+ )
38
 
39
  if uploaded_file is not None:
40
+ try:
41
+ # Display the uploaded image
42
+ image = Image.open(uploaded_file).convert("RGB")
43
+ st.image(image, caption="Uploaded Image", use_column_width=True)
44
+
45
+ # Perform classification
46
+ with st.spinner("Classifying image..."):
47
+ preds = pipe(image)
48
+
49
+ # Display results
50
+ st.subheader("🎯 Classification Results")
51
 
52
+ # Create columns for better layout
53
+ col1, col2 = st.columns(2)
54
+
55
+ with col1:
56
+ st.metric("Top Prediction", preds[0]['label'])
57
+
58
+ with col2:
59
+ st.metric("Confidence", f"{preds[0]['score']:.1%}")
60
+
61
+ # Show all predictions
62
+ st.subheader("📊 All Predictions")
63
+ for i, pred in enumerate(preds):
64
+ confidence = pred['score']
65
+ st.progress(confidence)
66
+ st.write(f"**{i+1}. {pred['label']}** - {confidence:.1%}")
67
+
68
+ except Exception as e:
69
+ st.error(f"Error processing image: {str(e)}")
70
+
71
+ else:
72
+ st.info("👆 Please upload an image to get started!")
73
+
74
+ # Add footer
75
+ st.markdown("---")
76
+ st.markdown("Built with Streamlit and 🤗 Transformers")