# Espacio_TFM / app.py
# Author: ManuelMC
# Last commit: 0a12911 ("Update app.py", verified)
import streamlit as st
import torch
from transformers import DistilBertTokenizer, DistilBertModel
import torch.nn as nn
import pandas as pd
from datasets import load_dataset
from analisis import mostrar_analisis_aerolinea
# -------------------------------
# Custom regression model
# -------------------------------
class DistilBertRegressor(nn.Module):
    """DistilBERT encoder topped with a one-unit linear regression head.

    The encoder's hidden state at sequence position 0 (the [CLS] slot for
    BERT-style tokenizers) is projected to a single scalar per input.

    NOTE: the submodule attribute names ``bert`` and ``regressor`` are part of
    the checkpoint format — the fine-tuned ``state_dict`` is keyed on them.
    """

    def __init__(self):
        super().__init__()
        self.bert = DistilBertModel.from_pretrained("distilbert-base-uncased")
        hidden_width = self.bert.config.hidden_size
        self.regressor = nn.Linear(hidden_width, 1)

    def forward(self, input_ids=None, attention_mask=None):
        """Return a ``(batch, 1)`` tensor of raw (unscaled) predictions.

        Args:
            input_ids: token ids as produced by the tokenizer.
            attention_mask: mask as produced by the tokenizer.
        """
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # presumably (batch, seq_len, hidden); take position 0 of every sequence
        first_token_state = encoded.last_hidden_state[:, 0]
        return self.regressor(first_token_state)
# -------------------------------
# Load model and tokenizer (cached with st.cache_resource)
# -------------------------------
@st.cache_resource
def load_model():
    """Build the regressor, load its fine-tuned weights, and fetch the tokenizer.

    Decorated with ``st.cache_resource`` so Streamlit keeps the returned
    objects across reruns instead of rebuilding them.

    Returns:
        tuple: ``(model, tokenizer)`` — the ``DistilBertRegressor`` in eval
        mode on CPU, and the ``DistilBertTokenizer`` from ``ManuelMC/Modelo_TFM``.
    """
    checkpoint_url = "https://huggingface.co/ManuelMC/Modelo_TFM/resolve/main/pytorch_model.bin"
    regressor = DistilBertRegressor()
    # NOTE(review): this deserializes a remote .bin through torch.load (pickle);
    # consider weights_only=True or a safetensors checkpoint.
    weights = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")
    regressor.load_state_dict(weights)
    regressor.eval()  # inference only: disable dropout
    return regressor, DistilBertTokenizer.from_pretrained("ManuelMC/Modelo_TFM")
# -------------------------------
# Load the review dataset from Hugging Face
# -------------------------------
@st.cache_data
def cargar_dataset():
    """Fetch the scraped review dataset and return it as a pandas DataFrame.

    Reads the ``train`` split of ``ManuelMC/dataset_scraping`` with
    ``datasets.load_dataset`` and converts it via ``to_pandas()``. Decorated
    with ``st.cache_data`` so Streamlit reuses the result across reruns.
    """
    train_split = load_dataset("ManuelMC/dataset_scraping", split="train")
    frame = train_split.to_pandas()
    return frame
df_reviews = cargar_dataset()
model, tokenizer = load_model()

# -------------------------------
# User interface
# -------------------------------
# Factor applied to the raw regressor output before display ("scale adjustment").
# NOTE(review): the origin of 1.65 is not visible here — presumably chosen to map
# the model's output range onto the 0-10 scale shown to the user; confirm.
RATING_SCALE = 1.65
MIN_RATING = 0
MAX_RATING = 10

st.title("✈️ Airline Review Rating Prediction")

airlines = [
    "Iberia", "Vueling", "Ryanair", "Air Europa", "EasyJet", "Eurowings",
    "Grupo AirFrance-KLM", "Grupo IAG", "Iberia Express", "Jet2.com",
    "Lufthansa", "Norwegian", "Pegasus Airlines", "SAS", "Turkish Airlines", "Wizz Air",
]
selected_airline = st.selectbox("Select an airline", airlines)
review = st.text_area("Write your flight review here:")

if st.button("Predict Rating"):
    rating = None
    # Only run the model on a non-blank review: the regressor always emits a
    # number, so scoring an empty text would display a meaningless rating.
    if review and review.strip():
        inputs = tokenizer(
            review,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=256,
        )
        with torch.no_grad():
            output = model(**inputs).squeeze()
        rating = output.item() * RATING_SCALE
        # The linear head is unbounded in both directions, so clamp to the
        # displayed 0-10 range (the original only capped the upper end).
        rating = min(max(rating, MIN_RATING), MAX_RATING)
        rating = round(rating, 2)
    with st.expander(f"Results for {selected_airline}"):
        if rating is None:
            st.warning("Please write a review to get a predicted rating.")
        else:
            st.success(f"🌟 Predicted Rating: **{rating}/10**")
        # The airline analysis does not depend on the review text, so show it either way.
        mostrar_analisis_aerolinea(df_reviews, selected_airline)