Spaces:
Build error
Build error
import os
import tempfile

import cv2
import streamlit as st
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
# Hugging Face checkpoint used for facial-emotion classification.
MODEL_ID = "RickyIG/emotion_face_image_classification"


@st.cache_resource
def load_emotion_model(model_id=MODEL_ID):
    """Load the image processor and classification model for *model_id*.

    Streamlit re-executes this whole script on every widget interaction;
    ``st.cache_resource`` makes the download/deserialisation happen once and
    shares the resulting objects across reruns and sessions.

    Returns:
        A ``(processor, model)`` tuple.
    """
    image_processor = AutoImageProcessor.from_pretrained(model_id)
    classifier = AutoModelForImageClassification.from_pretrained(model_id)
    return image_processor, classifier


# Module-level names consumed by the rest of the app.
processor, model = load_emotion_model()
# Page heading followed by a selector for where the input comes from.
st.title("Emotion Detection App")
input_sources = ("Upload Image", "Upload Video", "Use Live Camera")
option = st.radio("Select an option", input_sources)
def predict_emotion(pil_image):
    """Return the predicted emotion label for a single PIL image.

    The image is converted to RGB first so RGBA / palette / grayscale inputs
    (e.g. PNGs with transparency) reach the processor as 3-channel images.
    NOTE(review): assumes the checkpoint expects RGB input — confirm.

    Returns:
        The ``model.config.id2label`` entry for the highest-scoring logit.
    """
    if pil_image.mode != "RGB":
        pil_image = pil_image.convert("RGB")
    inputs = processor(images=pil_image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits  # raw scores, before softmax
    return model.config.id2label[logits.argmax(-1).item()]


def stream_capture(cap, error_message):
    """Annotate and display every frame of *cap* with its predicted emotion.

    Shows *error_message* via ``st.error`` if the capture could not be opened.
    The capture is released in ``finally`` so the device/file handle is freed
    even when Streamlit interrupts this loop (e.g. on a rerun).
    """
    try:
        if not cap.isOpened():
            st.error(error_message)
            return
        placeholder = st.empty()  # updated in place to show the live feed
        while True:
            ok, frame_bgr = cap.read()
            if not ok:
                break
            # Convert once: the RGB frame feeds both the model and the display.
            frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
            label = predict_emotion(Image.fromarray(frame_rgb))
            # (0, 255, 0) is green in both BGR and RGB channel order.
            cv2.putText(frame_rgb, f"Emotion: {label}", (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2, cv2.LINE_AA)
            placeholder.image(frame_rgb, channels="RGB", use_container_width=True)
    finally:
        cap.release()


if option == "Upload Image":
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is not None:
        image = Image.open(uploaded_file)
        st.image(image, caption="Uploaded Image", use_container_width=True)
        label = predict_emotion(image)
        st.write(f"Predicted Emotion: {label}")
elif option == "Upload Video":
    uploaded_video = st.file_uploader("Choose a video...", type=["mp4", "avi", "mov"])
    if uploaded_video is not None:
        # OpenCV needs a filesystem path. Use a unique temp file (a fixed path
        # would be shared by concurrent sessions) that keeps the upload's own
        # extension, and delete it once playback finishes.
        suffix = os.path.splitext(uploaded_video.name)[1] or ".mp4"
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
            tmp.write(uploaded_video.getvalue())  # independent of read position
            temp_video_path = tmp.name
        try:
            stream_capture(cv2.VideoCapture(temp_video_path), "Error: Could not open video.")
        finally:
            os.remove(temp_video_path)
elif option == "Use Live Camera":
    # NOTE(review): device 0 is a camera attached to the machine running the
    # Streamlit server, not the visitor's browser webcam; on a hosted
    # deployment this will usually fail to open. Consider st.camera_input.
    stream_capture(cv2.VideoCapture(0), "Error: Could not open webcam.")
# Centred author credit rendered as raw HTML at the bottom of the page.
footer_html = "<br><br><h5 style='text-align: center;'>Developed by M.Nabeel</h5>"
st.markdown(footer_html, unsafe_allow_html=True)