# sam_avatar / app.py
# Commit 9d16fe3 by ammar101:
#   Feature: Export skeleton inference raw data to Pickle file for clothes rigging
import hashlib
import os

import streamlit as st

from pipeline import AvatarExtractorPipeline
# Browser-tab title and wide layout, followed by the page heading and intro text.
st.set_page_config(
    layout="wide",
    page_title="3D Avatar Extractor",
)
st.title("🧍 3D Avatar Extractor")
st.write("Upload an image and we will extract your 3D avatar (`.obj`).")
@st.cache_resource(show_spinner="Downloading and initializing SAM-3D-Body...")
def get_pipeline():
    """Build the avatar-extraction pipeline once and reuse it across reruns.

    ``st.cache_resource`` keeps the returned object alive, so the model is only
    constructed on the first call (a cache miss).

    The loading message is passed as the decorator's ``show_spinner`` text
    instead of being emitted with ``st.info`` inside the function. Streamlit
    replays elements created inside a cached function on every cache hit. An
    ``st.info`` here would keep showing a stale "Downloading..." banner on
    every rerun, long after the model has loaded. The spinner text is shown
    only while the pipeline is actually being created.

    Returns:
        The shared ``AvatarExtractorPipeline`` instance.
    """
    return AvatarExtractorPipeline()
# Fetch the (cached) pipeline. If construction fails, report it and halt this
# script run, because nothing below can work without a model.
try:
    pipeline = get_pipeline()
except Exception as load_err:
    st.error(f"Failed to load model: {load_err}")
    st.stop()
else:
    st.success("Model loaded successfully!")
# Display name -> hex colour applied to the rendered avatar mesh.
# Insertion order is the order the options appear in the sidebar selectbox.
skin_tones = dict(
    Fair="#FDF1E8",
    Light="#F3D8C4",
    Medium="#D5AC8A",
    Tan="#BB8D6A",
    Coco="#7E4E30",
    Deep="#3D2314",
)
with st.sidebar:
    st.header("Avatar Settings")
    # Options follow skin_tones' insertion order; index=2 preselects the third entry.
    selected_tone_name = st.selectbox(
        label="Select Skin Tone Category:",
        options=list(skin_tones),
        index=2,
    )
    # Hex colour used later to paint the avatar mesh.
    selected_color_hex = skin_tones[selected_tone_name]
def _save_upload(upload, directory):
    """Write an uploaded image into *directory*; return ``(path, digest)``.

    The file is named by the SHA-256 hex digest of its bytes plus the original
    (lower-cased) extension, not by the client-supplied ``upload.name``. That
    name is untrusted input. Joined directly into a path, a value such as
    ``../../x.png`` could escape *directory*. Two sessions uploading different
    files with the same name would also overwrite each other.
    """
    data = upload.getbuffer()
    digest = hashlib.sha256(data).hexdigest()
    extension = os.path.splitext(os.path.basename(upload.name))[1].lower()
    os.makedirs(directory, exist_ok=True)
    path = os.path.join(directory, digest + extension)
    with open(path, "wb") as fh:
        fh.write(data)
    return path, digest


def _extract_once(img_path, digest):
    """Return ``(mesh_path, pkl_path)``, running the pipeline at most once per image per session.

    Streamlit re-executes the whole script on every widget interaction, such as
    changing the skin tone in the sidebar or clicking a download button.
    Without memoisation, each of those would re-run the model. The last
    result is kept in ``st.session_state``, keyed by the image's content
    digest. The digest is recorded only after ``extract_avatar`` returns, so
    a failed extraction is retried on the next rerun.

    NOTE(review): if ``extract_avatar`` writes to a fixed output path,
    concurrent sessions can overwrite each other's files. Confirm in
    ``pipeline``.
    """
    if st.session_state.get("_avatar_digest") != digest:
        st.session_state["_avatar_result"] = pipeline.extract_avatar(img_path)
        st.session_state["_avatar_digest"] = digest
    return st.session_state["_avatar_result"]


def _mesh_figure(mesh_path, color_hex):
    """Load the mesh at *mesh_path* and return a Plotly 3D figure painted *color_hex*."""
    # Imported lazily (as originally) so a missing rendering dependency is
    # reported through the per-image error message below, not at start-up.
    import plotly.graph_objects as go
    import trimesh

    # force="mesh" concatenates a multi-geometry Scene into a single Trimesh,
    # so .vertices / .faces are always available.
    mesh = trimesh.load(mesh_path, force="mesh")
    vertices, faces = mesh.vertices, mesh.faces
    fig = go.Figure(data=[go.Mesh3d(
        x=vertices[:, 0],
        y=vertices[:, 2],  # Native depth mapped to Plotly Y
        z=-vertices[:, 1],  # Native downward Y inverted and mapped to Plotly Up-Z
        i=faces[:, 0],
        j=faces[:, 1],
        k=faces[:, 2],
        color=color_hex,
        opacity=1.0,
        lighting=dict(ambient=0.4, diffuse=0.8, specular=0.2, roughness=0.5),
    )])
    fig.update_layout(
        scene=dict(
            xaxis=dict(visible=False),
            yaxis=dict(visible=False),
            zaxis=dict(visible=False),
            aspectmode="data",  # keep the body's true proportions
        ),
        margin=dict(l=0, r=0, b=0, t=0),
        height=500,
    )
    return fig


uploaded_file = st.file_uploader("Upload Full Body Picture", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    temp_dir = "temp_processing"
    temp_img_path, image_digest = _save_upload(uploaded_file, temp_dir)

    col1, col2 = st.columns(2)

    with col1:
        st.header("Input Photo")
        # use_container_width replaces the deprecated use_column_width and
        # matches the other elements in this app.
        st.image(temp_img_path, use_container_width=True)

    with col2:
        st.header("🧑‍🦲 3D Avatar")
        with st.spinner("Extracting 3D Avatar..."):
            try:
                mesh_path, pkl_path = _extract_once(temp_img_path, image_digest)

                # Render 3D Model Interactively
                st.plotly_chart(_mesh_figure(mesh_path, selected_color_hex), use_container_width=True)

                col_dl1, col_dl2 = st.columns(2)
                with col_dl1:
                    with open(mesh_path, "rb") as f:
                        st.download_button("Download 3D Model (.obj)", data=f, file_name="avatar.obj", mime="application/octet-stream", use_container_width=True)
                with col_dl2:
                    with open(pkl_path, "rb") as f:
                        st.download_button("Download Skeleton Data (.pkl)", data=f, file_name="avatar_skeleton.pkl", mime="application/octet-stream", use_container_width=True)

                st.success("Avatar Extraction Complete!")
            except Exception as e:
                # UI boundary: surface any extraction/rendering failure to the user.
                st.error(f"Error processing image: {e}")