Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import torch | |
| from transformers import DistilBertTokenizer, DistilBertModel | |
| import torch.nn as nn | |
| import pandas as pd | |
| from datasets import load_dataset | |
| from analisis import mostrar_analisis_aerolinea | |
| # ------------------------------- | |
| # Custom regression model | |
| # ------------------------------- | |
class DistilBertRegressor(nn.Module):
    """DistilBERT encoder followed by a single-output linear regression head.

    Args:
        model_name: Checkpoint identifier passed to
            ``DistilBertModel.from_pretrained``. Defaults to
            ``"distilbert-base-uncased"``, the value that used to be
            hard-coded, so ``DistilBertRegressor()`` behaves as before.
    """

    def __init__(self, model_name: str = "distilbert-base-uncased") -> None:
        super().__init__()
        self.bert = DistilBertModel.from_pretrained(model_name)
        # The head maps one hidden-size vector to a single scalar.
        self.regressor = nn.Linear(self.bert.config.hidden_size, 1)

    def forward(self, input_ids=None, attention_mask=None):
        """Run the encoder and regress from the first position's hidden state.

        Args:
            input_ids: Forwarded unchanged to the encoder.
            attention_mask: Forwarded unchanged to the encoder.

        Returns:
            The output of the linear head; its last dimension has size 1.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Keep only index 0 along the second axis of last_hidden_state.
        # NOTE(review): presumably the [CLS] token embedding -- confirm
        # against the tokenizer in use.
        pooled_output = outputs.last_hidden_state[:, 0]
        return self.regressor(pooled_output)
| # ------------------------------- | |
| # Load the model and tokenizer with modern caching | |
| # ------------------------------- | |
@st.cache_resource
def load_model() -> "tuple[DistilBertRegressor, DistilBertTokenizer]":
    """Build the regressor, load its trained weights, and load the tokenizer.

    Wrapped in ``st.cache_resource`` so the download and model construction
    happen once per server process instead of on every Streamlit rerun
    (each widget interaction re-executes the whole script).

    Returns:
        A ``(model, tokenizer)`` tuple; the model is in evaluation mode.
    """
    model = DistilBertRegressor()
    # Load the trained weights from Hugging Face.
    # NOTE(review): this is a pickle-based checkpoint fetched over the
    # network; only load it from a repository you trust.
    state_dict = torch.hub.load_state_dict_from_url(
        "https://huggingface.co/ManuelMC/Modelo_TFM/resolve/main/pytorch_model.bin",
        map_location="cpu",
    )
    model.load_state_dict(state_dict)
    model.eval()
    tokenizer = DistilBertTokenizer.from_pretrained("ManuelMC/Modelo_TFM")
    return model, tokenizer
| # ------------------------------- | |
| # Load the dataset from Hugging Face | |
| # ------------------------------- | |
@st.cache_data
def cargar_dataset() -> "pd.DataFrame":
    """Load the ``train`` split of the reviews dataset as a pandas DataFrame.

    Wrapped in ``st.cache_data`` so the dataset is fetched and converted once
    rather than on every Streamlit rerun.

    Returns:
        The result of ``Dataset.to_pandas()`` for
        ``ManuelMC/dataset_scraping``.
    """
    dataset = load_dataset("ManuelMC/dataset_scraping", split="train")
    return dataset.to_pandas()
# Load the shared resources before drawing any widget.
df_reviews = cargar_dataset()
model, tokenizer = load_model()

# -------------------------------
# User interface
# -------------------------------
# Multiplier applied to the raw model output (the original "scale adjustment").
RATING_SCALE = 1.65
# Bounds of the displayed score; the result is shown as "<rating>/10".
MIN_RATING, MAX_RATING = 0, 10

st.title("✈️ Airline Review Rating Prediction")
airlines = [
    "Iberia", "Vueling", "Ryanair", "Air Europa", "EasyJet", "Eurowings",
    "Grupo AirFrance-KLM", "Grupo IAG", "Iberia Express", "Jet2.com",
    "Lufthansa", "Norwegian", "Pegasus Airlines", "SAS", "Turkish Airlines",
    "Wizz Air",
]
selected_airline = st.selectbox("Select an airline", airlines)
review = st.text_area("Write your flight review here:")

if st.button("Predict Rating"):
    if not review.strip():
        # An empty text box would still be tokenized and scored; ask for
        # input instead of showing a prediction for no text.
        st.warning("Please write a review before requesting a prediction.")
    else:
        inputs = tokenizer(
            review,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=256,
        )
        with torch.no_grad():
            output = model(**inputs).squeeze()
        rating = output.item() * RATING_SCALE  # Scale adjustment
        # Clamp on both sides: an unbounded regression output can fall below
        # zero as well as above the maximum.
        rating = max(MIN_RATING, min(rating, MAX_RATING))
        rating = round(rating, 2)
        # NOTE(review): the original indentation was lost; both calls are
        # assumed to belong inside the expander -- confirm.
        with st.expander(f"Results for {selected_airline}"):
            st.success(f"🌟 Predicted Rating: **{rating}/10**")
            mostrar_analisis_aerolinea(df_reviews, selected_airline)