# app/utils/prediction.py
# Committed by Maulidaaa — "Update app/utils/prediction.py" (7097f16, verified)
import math
import os
from collections import OrderedDict

import torch
from transformers import BertTokenizer, BertForSequenceClassification
# Hugging Face access token, taken from the environment so it is never
# hard-coded in the repo; None when HF_TOKEN is unset.
HF_TOKEN = os.getenv("HF_TOKEN")

# Load the pre-trained model and tokenizer once, at import time, so every
# call to the prediction helpers in this module reuses the same objects.
_MODEL_REPO = "Maulidaaa/bert-safe-model"
tokenizer = BertTokenizer.from_pretrained(_MODEL_REPO, token=HF_TOKEN)
model = BertForSequenceClassification.from_pretrained(_MODEL_REPO, token=HF_TOKEN)
def predict(desc):
    """Classify an ingredient description as "Safe" or "Not Safe".

    Args:
        desc: Free-text description to classify. Anything that is not a
            non-blank string (None, NaN from an empty DataFrame cell, "",
            whitespace only) is treated as a missing description.

    Returns:
        "Safe" when the classifier's highest-scoring label is 1, otherwise
        "Not Safe". A missing description is reported as "Not Safe" without
        running the model.
    """
    # pandas yields NaN (a float) for empty cells, and `not NaN` is False, so a
    # plain truthiness check would hand NaN to the tokenizer, which rejects it.
    if not isinstance(desc, str) or not desc.strip():
        return "Not Safe"
    # Cap the encoded input at 512 tokens so long descriptions are truncated
    # instead of failing.
    inputs = tokenizer(desc, return_tensors="pt", truncation=True, padding=True, max_length=512)
    with torch.no_grad():  # inference only: no gradients needed
        logits = model(**inputs).logits
    # One input string -> one row of logits; pick the top class along dim 1.
    label = torch.argmax(logits, dim=1).item()
    return "Safe" if label == 1 else "Not Safe"
def _cell(row, column, default):
    """Return ``row[column]``, or *default* when the column is absent or the cell is empty."""
    value = row.get(column, default)
    # Empty cells come back from pandas as NaN (a float) or None; neither is
    # meaningful to callers, and NaN is not valid JSON.
    if value is None or (isinstance(value, float) and math.isnan(value)):
        return default
    return value


def _find_ingredient_row(df, ingredient):
    """Return the first row whose INCI or IUPAC name matches *ingredient*, else None."""
    key = ingredient.strip().lower()
    if not key:
        return None
    # Compare with boolean masks instead of copying the frame and adding
    # helper columns; empty (NaN) names simply never compare equal.
    inci = df['INCI name'].str.strip().str.lower()
    iupac = df['IUPAC Name'].str.strip().str.lower()
    matches = df[(inci == key) | (iupac == key)]
    return None if matches.empty else matches.iloc[0]


def predict_with_description(ingredient, df):
    """Look up *ingredient* in *df* and return its details plus a safety prediction.

    Matching ignores case and surrounding whitespace. A row matches when its
    'INCI name' or its 'IUPAC Name' equals the ingredient, and the first
    matching row is used. *df* is not modified.

    Args:
        ingredient: Ingredient name as supplied by the caller.
        df: DataFrame with 'INCI name' and 'IUPAC Name' string columns. The
            'Description', 'Function', 'Restriction', 'Risk Level' and
            'Risk Description' columns are read when present.

    Returns:
        OrderedDict with keys "Ingredient Name", "Description", "Function",
        "Risk Level", "Restriction", "Risk Description", "Prediction" (in that
        order). Empty cells of a matched row are reported as '' (None for
        "Restriction"). For an unknown ingredient, placeholder texts are
        returned. "Prediction" is predict() applied to the description text.
    """
    row = _find_ingredient_row(df, ingredient)
    if row is not None:
        # The row may have matched on its IUPAC name while its INCI name is
        # empty; fall back to the caller's text instead of crashing on NaN.
        inci_name = str(_cell(row, 'INCI name', ingredient.strip())).title()
        desc = _cell(row, 'Description', '')
        func = _cell(row, 'Function', '')
        restriction = _cell(row, 'Restriction', None)
        risk_lvl = _cell(row, 'Risk Level', '')
        risk_desc = _cell(row, 'Risk Description', '')
    else:
        inci_name = ingredient.strip().title()
        desc = "Description not found"
        func = "Function not found"
        restriction = "Restriction not found"
        risk_lvl = "Unknown"
        risk_desc = "Risk info not available"

    result = predict(desc)
    return OrderedDict([
        ("Ingredient Name", inci_name),
        ("Description", desc),
        ("Function", func),
        ("Risk Level", risk_lvl),
        ("Restriction", restriction),
        ("Risk Description", risk_desc),
        ("Prediction", result),
    ])