Dua Rajper commited on
Commit
a707553
·
verified ·
1 Parent(s): dad8e4f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -0
app.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # streamlit_app.py
2
+
3
+ import streamlit as st
4
+ from PIL import Image
5
+ from transformers import AutoModelForImageSegmentation, AutoProcessor
6
+ import numpy as np
7
+ import matplotlib.pyplot as plt
8
+
9
+ # Title of the app
10
+ st.title("Image Segmentation App with Hugging Face and Streamlit")
11
+
12
+ # Description
13
+ st.write("Upload an image, and the Hugging Face model will segment it.")
14
+
15
+ # Load the Hugging Face model and processor
16
+ @st.cache_resource # Cache the model to avoid reloading every time
17
+ def load_model():
18
+ model_name = "ZhengPeng7/BiRefNet"
19
+ model = AutoModelForImageSegmentation.from_pretrained(model_name, trust_remote_code=True)
20
+ processor = AutoProcessor.from_pretrained(model_name)
21
+ return model, processor
22
+
23
+ model, processor = load_model()
24
+
25
+ # Upload an image
26
+ uploaded_image = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
27
+
28
+ if uploaded_image:
29
+ # Display the uploaded image
30
+ image = Image.open(uploaded_image)
31
+ st.image(image, caption="Uploaded Image", use_column_width=True)
32
+
33
+ # Perform segmentation
34
+ st.write("Performing segmentation... Please wait!")
35
+ inputs = processor(images=image, return_tensors="pt")
36
+ outputs = model(**inputs)
37
+
38
+ # Generate segmentation mask
39
+ segmentation = outputs.logits.argmax(dim=1)[0].detach().cpu().numpy()
40
+
41
+ # Display the segmentation mask
42
+ st.write("Segmentation mask:")
43
+ plt.figure(figsize=(10, 10))
44
+ plt.imshow(segmentation, cmap="viridis")
45
+ plt.axis("off")
46
+ st.pyplot(plt)