# shubham142000's picture
# Update app.py
# 59af5c4 verified
import streamlit as st
import pandas as pd
import numpy as np
from transformers import AutoTokenizer, AutoModel
import torch
from sklearn.metrics.pairwise import cosine_similarity
from scipy.spatial.distance import cosine
import joblib
# Load a pre-trained model and tokenizer
model_name = "sentence-transformers/all-MiniLM-L6-v2"


@st.cache_resource
def _load_encoder(name):
    """Return the ``(tokenizer, model)`` pair for the checkpoint ``name``.

    Streamlit re-executes this script on every widget interaction, so the
    loaded objects are cached as a shared resource instead of being
    re-instantiated on each rerun.
    """
    return AutoTokenizer.from_pretrained(name), AutoModel.from_pretrained(name)


tokenizer, model = _load_encoder(model_name)
# Function to get embedding
def get_embedding(text):
    """Embed ``text`` using the module-level ``tokenizer`` and ``model``.

    The input is tokenized (truncated to at most 512 tokens), run through the
    model without gradient tracking, and the per-token hidden states are
    averaged into one vector per input.

    Args:
        text: The text to embed. A single string yields a 1-D result after
            the final ``squeeze()``.

    Returns:
        A NumPy array holding the mean-pooled last hidden state.
    """
    inputs = tokenizer(text, return_tensors="pt", max_length=512, truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)

    token_embeddings = outputs.last_hidden_state
    # Weight every token by its attention mask so that padding positions
    # (present when several texts of different length are batched) do not
    # dilute the average. For a single text the mask is all ones and this is
    # the plain mean over tokens.
    mask = inputs["attention_mask"].unsqueeze(-1).to(token_embeddings.dtype)
    summed = (token_embeddings * mask).sum(dim=1)
    token_counts = mask.sum(dim=1).clamp(min=1e-9)  # guard against division by zero
    return (summed / token_counts).squeeze().numpy()
# Function to classify text using cosine similarity
def classify_text_cosine(embedding, mean_embeddings, threshold=0.5):
    """Return the label whose mean embedding is closest to ``embedding``.

    Args:
        embedding: 1-D vector for the text being classified.
        mean_embeddings: Mapping of label -> mean embedding vector.
        threshold: Largest cosine distance still accepted as a match.

    Returns:
        The label with the smallest cosine distance to ``embedding``, or
        ``"neither"`` when that distance exceeds ``threshold`` or when
        ``mean_embeddings`` is empty.
    """
    if not mean_embeddings:
        # Nothing to compare against; min() on an empty mapping would raise.
        return "neither"

    distances = {
        label: cosine(embedding, mean_embedding)
        for label, mean_embedding in mean_embeddings.items()
    }
    # Single pass: pick the closest label together with its distance.
    predicted_label, min_distance = min(distances.items(), key=lambda item: item[1])
    if min_distance > threshold:
        return "neither"
    return predicted_label
# Function to classify text using MLP model
def classify_text_mlp(embedding, mlp_model, labels=None):
    """Classify ``embedding`` with a fitted model and return the label name.

    Args:
        embedding: Feature vector for one text.
        mlp_model: Object exposing ``predict``; it is called with a one-row
            batch and its first prediction is used.
        labels: Optional mapping of label name -> integer class. Defaults to
            the module-level ``label_mapping``.

    Returns:
        The label name whose mapped integer equals the model's prediction.

    Raises:
        IndexError: If the prediction does not correspond to any label.
    """
    mapping = label_mapping if labels is None else labels
    prediction = mlp_model.predict([embedding])[0]
    # Resolve the class by its mapped value, not by its position in the dict,
    # so the result stays correct even if the mapping is reordered.
    names_by_index = {index: name for name, index in mapping.items()}
    try:
        return names_by_index[int(prediction)]
    except KeyError:
        raise IndexError(f"Model predicted unknown class index {prediction!r}") from None
# Page heading
st.title('Biryani, Pizza, Milk, Pasta, Potatos, Tomato, or Neither Classifier')

# Read the table of precomputed embeddings together with their labels
df = pd.read_csv("embeddings_receipes_final.csv")

# Integer code assigned to every class name
label_mapping = {
    'pizza': 0,
    'biryani': 1,
    'milk': 2,
    'pasta': 3,
    'potatos': 4,
    'tomato': 5,
    'neither': 6,
}
df['label_int'] = df['label'].map(label_mapping)

# Per-class mean embedding. The embedding columns are taken positionally:
# everything except the first column and the last two columns of df
# (presumably an id column, then 'label' and the 'label_int' added above;
# verify against the CSV layout).
labels = df['label']
embeddings = df.iloc[:, 1:-2]
mean_embeddings = {}
for _class_name in label_mapping:
    # 'neither' has no reference embedding; it is the fallback prediction.
    if _class_name == 'neither':
        continue
    mean_embeddings[_class_name] = embeddings[labels == _class_name].mean(axis=0)
# Try to restore the serialized MLP classifier; on any failure the error is
# shown in the page and mlp_model stays None so the app keeps running.
mlp_model = None
try:
    mlp_model = joblib.load("mlp_model2.joblib")
except Exception as load_error:
    st.error(f"Error loading MLP model: {load_error}")
# Image displayed next to each predicted label.
image_mapping = {
    'pizza': 'pizza.jpg',
    'biryani': 'biryani.jpg',
    'milk': 'milk.jpg',
    'pasta': 'pasta.jpg',
    'potatos': 'potatos.jpg',
    'tomato': 'tomato.jpg',
    'neither': 'other.jpg',
}

# Check if the DataFrame is loaded correctly
# NOTE(review): the breakdown below adds up to 387 while the check and the
# message use 386 -- confirm the intended column count.
if df.shape[1] < 386:  # 384 embeddings + 1 label + 1 recipe_id + 1 label_int
    st.error("Expected DataFrame with 386 columns, but got less than that. Please check your CSV file.")
else:
    # Select classification method
    classification_method = st.selectbox("Select classification method", ["Cosine Similarity", "MLP Model"])

    # Input text
    input_text = st.text_area("Enter text to classify")

    if st.button("Classify"):
        # Whitespace-only input carries nothing to classify, so treat it as empty.
        if input_text and input_text.strip():
            # Get the embedding for the input text
            embedding = get_embedding(input_text)

            # Ensure the embedding is of the correct dimension
            if embedding.shape[0] != 384:
                st.error(f"Expected embedding of dimension 384, but got {embedding.shape[0]}.")
            else:
                # Classify the input text using the selected method
                if classification_method == "Cosine Similarity":
                    predicted_label = classify_text_cosine(embedding, mean_embeddings)
                elif mlp_model is not None:
                    predicted_label = classify_text_mlp(embedding, mlp_model)
                else:
                    st.error("MLP model is not available.")
                    predicted_label = "unknown"

                # Display the result
                st.write(f"The predicted label is: **{predicted_label}**")

                # "unknown" (MLP model missing) has no image; indexing the
                # mapping directly would raise KeyError, so look it up safely
                # and only render an image when one exists.
                image_path = image_mapping.get(predicted_label)
                if image_path is not None:
                    st.image(image_path, caption=f"Predicted Label: {predicted_label}", use_column_width=True)
        else:
            st.write("Please enter text to classify.")
# # Footer
# st.markdown(
# """
# <style>
# .footer {
# position: fixed;
# bottom: 0;
# width: 100%;
# text-align: center;
# padding: 10px;
# background-color: #f1f1f1;
# color: black;
# }
# </style>
# <div class="footer">
# <b>&copy; Shubham Kale and Prof Ganesh Bagler, IIIT Delhi</b>
# </div>
# """,
# unsafe_allow_html=True
# )
# Footer pinned to the bottom of the page: the CSS rules come first, followed
# by the markup they style.
_FOOTER_HTML = """
<style>
.footer {
position: fixed;
left: 0;
bottom: 0;
width: 100%;
background-color: #f1f1f1;
color: black;
text-align: center;
padding: 10px;
}
.footer p {
font-size: 1.2em;
font-weight: bold;
}
</style>
<div class="footer">
<p>© Shubham Kale and Prof.Ganesh Bagler, IIIT Delhi</p>
</div>
"""

# unsafe_allow_html is passed so the raw <style>/<div> tags are rendered.
st.markdown(_FOOTER_HTML, unsafe_allow_html=True)