Spaces:
Build error
Build error
| import gradio as gr | |
| import pickle | |
| import pandas as pd | |
| import torch | |
| import numpy as np | |
| import torch.nn as nn | |
| # === Define the model class === | |
class FactorizationMachineMultiTask(nn.Module):
    """Factorization machine that emits several ratings per input row at once.

    Each input column (feature) owns two embedding tables:
      * ``linear_embeddings[j]`` maps a category id to ``n_outputs`` first-order weights;
      * ``factor_embeddings[j]`` maps a category id to ``k * n_outputs`` latent
        factors, later viewed as a ``(k, n_outputs)`` matrix.

    Args:
        cardinalities: number of distinct ids for each input column, in column order.
        k: latent-factor dimension per output.
        n_outputs: number of ratings predicted per row.
    """

    def __init__(self, cardinalities, k=20, n_outputs=4):
        super().__init__()
        self.n_features = len(cardinalities)
        self.k = k
        self.n_outputs = n_outputs
        # Attribute names are part of the state_dict keys; keep them stable so
        # saved checkpoints still load.
        self.linear_embeddings = nn.ModuleList(
            [nn.Embedding(size, n_outputs) for size in cardinalities]
        )
        self.factor_embeddings = nn.ModuleList(
            [nn.Embedding(size, k * n_outputs) for size in cardinalities]
        )
        # Small-variance re-initialisation: all linear tables first, then all factor tables.
        for table in (*self.linear_embeddings, *self.factor_embeddings):
            nn.init.normal_(table.weight, mean=0.0, std=0.01)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Score a batch of id rows.

        Args:
            x: integer tensor of shape ``(batch, n_features)``; column ``j`` holds ids
               for ``linear_embeddings[j]`` / ``factor_embeddings[j]``.

        Returns:
            Tensor of shape ``(batch, n_outputs)`` with values squashed into (1, 10).
        """
        n_rows = x.size(0)
        columns = [x[:, j] for j in range(self.n_features)]

        # First-order term: per-feature (batch, n_outputs) weights summed over features.
        linear = torch.stack(
            [table(col) for table, col in zip(self.linear_embeddings, columns)],
            dim=2,
        ).sum(dim=2)

        # Latent factors arranged as (batch, n_features, k, n_outputs).
        factors = torch.stack(
            [table(col) for table, col in zip(self.factor_embeddings, columns)],
            dim=1,
        ).view(n_rows, self.n_features, self.k, self.n_outputs)

        # Pairwise interactions via the FM identity
        # sum_{i<j} <v_i, v_j> = 0.5 * sum_k [(sum_i v_ik)^2 - sum_i v_ik^2],
        # evaluated independently for every output.
        total = factors.sum(dim=1)
        squares_summed = (factors * factors).sum(dim=1)
        pairwise = 0.5 * (total * total - squares_summed).sum(dim=1)

        # Map the raw score into the (1, 10) rating range.
        return 1 + 9 * self.sigmoid(linear + pairwise)
# === Load encoders and data first ===
def _unpickle(path):
    """Deserialize and return the object stored in the pickle file at *path*.

    NOTE: ``pickle.load`` can execute arbitrary code while deserializing; only
    point this at artifacts shipped with the app, never at user-supplied files.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


user_encoder = _unpickle("user_encoder.pkl")
title_encoder = _unpickle("title_encoder.pkl")
df = pd.read_csv("df.csv")

# === Instantiate and load the model ===
n_users = len(user_encoder.classes_)
n_items = len(title_encoder.classes_)
# One embedding table per input column, in [user, item] order.
cardinalities = [n_users, n_items]
# k and n_outputs fix the embedding shapes, so they must match the values the
# checkpoint was trained with or load_state_dict raises a size-mismatch error.
fm_model = FactorizationMachineMultiTask(cardinalities, k=50, n_outputs=4)
fm_model.load_state_dict(
    torch.load("fm_model.pth", map_location=torch.device("cpu"))
)
fm_model.eval()
| # === Prediction functions === | |
def predict_subratings(user_name, drama_title, fm_model, user_encoder, title_encoder):
    """Predict the sub-ratings one user would give one drama.

    Args:
        user_name: raw username; must be one of ``user_encoder.classes_``.
        drama_title: raw title; must be one of ``title_encoder.classes_``.
        fm_model: callable model taking a ``(1, 2)`` long tensor holding
            ``[user_id, item_id]``.
        user_encoder: fitted encoder exposing ``classes_`` and ``transform``.
        title_encoder: fitted encoder exposing ``classes_`` and ``transform``.

    Returns:
        The model output for the pair flattened to a list of floats, or
        ``None`` when either the user or the title is unknown to its encoder.
    """
    # User is checked before title, short-circuiting like two separate guards.
    if user_name not in user_encoder.classes_ or drama_title not in title_encoder.classes_:
        return None
    user_id = user_encoder.transform([user_name])[0]
    item_id = title_encoder.transform([drama_title])[0]
    # Build the (1, 2) id row directly; a DataFrame round-trip adds nothing here.
    row_tensor = torch.tensor([[user_id, item_id]], dtype=torch.long)
    with torch.no_grad():
        preds = fm_model(row_tensor)
    return preds.view(-1).tolist()
def get_top_recommendations_web(user_name, top_n=10):
    """Return a ranked, human-readable list of unwatched dramas for a user.

    Every title in ``title_encoder.classes_`` that the user has no row for in
    ``df`` is scored with ``fm_model`` in one batched forward pass. Titles are
    ranked by the mean of the model's predicted sub-ratings, highest first;
    ties keep ``classes_`` order.

    Args:
        user_name: username exactly as stored in ``user_encoder.classes_``.
        top_n: maximum number of recommendations to list.

    Returns:
        A newline-separated ranking, or a message string when the user is
        unknown or no unwatched titles remain.
    """
    if user_name not in user_encoder.classes_:
        return f"Unknown user: {user_name}"

    # Only this user's rows are needed; grouping the whole frame by every
    # user on each request is wasted work.
    watched_shows = set(df.loc[df["Username"] == user_name, "Title"])
    candidates = [title for title in title_encoder.classes_ if title not in watched_shows]
    if not candidates:
        return "No recommendations available."

    # Encode once and score all candidates in a single batch rather than one
    # forward pass (plus two linear membership scans) per title.
    user_id = user_encoder.transform([user_name])[0]
    item_ids = np.asarray(title_encoder.transform(candidates), dtype=np.int64)
    user_ids = np.full(len(item_ids), user_id, dtype=np.int64)
    batch = torch.tensor(np.column_stack([user_ids, item_ids]), dtype=torch.long)
    with torch.no_grad():
        preds = fm_model(batch)
    # One row of sub-ratings per candidate; average in float64.
    avg_ratings = preds.reshape(len(candidates), -1).double().mean(dim=1).tolist()

    # Stable descending sort, so equal scores keep classes_ order.
    ranked = sorted(zip(candidates, avg_ratings), key=lambda pair: pair[1], reverse=True)
    return "\n".join(
        f"{rank}. {title} — Predicted Rating: {rating:.2f}"
        for rank, (title, rating) in enumerate(ranked[:top_n], start=1)
    )
| # === Gradio interface === | |
demo = gr.Interface(
    fn=get_top_recommendations_web,
    inputs=gr.Textbox(label="Enter your username"),
    outputs=gr.Textbox(label="Top Recommendations"),
    title="K-Drama Recommender",
    description="Type your username to get personalized drama suggestions.",
)

# Start the web server only when this file is executed as a script
# (e.g. `python app.py`), so importing the module does not launch it.
# NOTE(review): share=True requests a public tunnel for local runs; hosted
# platforms such as Hugging Face Spaces may ignore it — confirm it is wanted.
if __name__ == "__main__":
    demo.launch(share=True)