Spaces:
Running
Running
| import gradio as gr | |
| import sys | |
| import pickle | |
| import json | |
| import gc | |
| import torch | |
| from pathlib import Path | |
| import gdown | |
| import os | |
| import difflib | |
| from datetime import datetime | |
| import random | |
| # Import your existing modules | |
| from utils import * | |
| from options import args | |
| from models import model_factory | |
class LazyDict:
    """Read-only, dict-like view over a JSON file that is parsed on first use.

    The file is not opened until the first accessor call. Loading is
    best-effort: if the file cannot be read or parsed, or if its top-level
    JSON value is not an object, a warning is printed and the instance
    behaves like an empty dict instead of raising.
    """

    def __init__(self, file_path):
        """Remember ``file_path``; no I/O happens here."""
        self.file_path = file_path
        self._data = None
        self._loaded = False

    def _load_data(self):
        """Parse the backing file once; later calls are no-ops."""
        if self._loaded:
            return
        try:
            with open(self.file_path, "r", encoding="utf-8") as file:
                payload = json.load(file)
        except Exception as e:
            # Deliberate best-effort: a missing or corrupt metadata file
            # must degrade to "no data" rather than crash the application.
            print(f"Warning: Could not load {self.file_path}: {str(e)}")
            payload = {}
        if not isinstance(payload, dict):
            # A JSON list/scalar would break every accessor below
            # (.get/.items/.keys), so treat it like an unreadable file.
            print(
                f"Warning: Could not load {self.file_path}: "
                f"expected a JSON object, got {type(payload).__name__}"
            )
            payload = {}
        self._data = payload
        self._loaded = True

    def get(self, key, default=None):
        """Return the value stored under ``key``, or ``default`` if absent."""
        self._load_data()
        return self._data.get(key, default)

    def __getitem__(self, key):
        """Return the value stored under ``key``; raise ``KeyError`` if absent."""
        self._load_data()
        return self._data[key]

    def __contains__(self, key):
        self._load_data()
        return key in self._data

    def __iter__(self):
        self._load_data()
        return iter(self._data)

    def items(self):
        """Return the ``(key, value)`` view of the loaded mapping."""
        self._load_data()
        return self._data.items()

    def keys(self):
        """Return the key view of the loaded mapping."""
        self._load_data()
        return self._data.keys()

    def values(self):
        """Return the value view of the loaded mapping."""
        self._load_data()
        return self._data.values()

    def __len__(self):
        self._load_data()
        return len(self._data)
class AnimeRecommendationSystem:
    """Serve anime recommendations from a pretrained sequence model.

    Holds the model, the id mapping read from the dataset pickle (its
    ``"smap"`` entry, used here as ``anime id -> model index``) and several
    lazily loaded JSON metadata tables keyed by the anime id as a string:

    * ``id_to_anime``    - title list (main title first, alternatives after)
                           or a single scalar title
    * ``id_to_url``      - image URL
    * ``id_to_mal_url``  - MyAnimeList URL
    * ``id_to_type_seq`` - sequence whose first item is the type and whose
                           second item is the "is sequel" flag
    * ``id_to_genres``   - sequence whose first item is the list of genres
    """

    def __init__(self, checkpoint_path, dataset_path, animes_path, images_path, mal_urls_path, type_seq_path, genres_path):
        """Store the paths, set up the lazy metadata tables and load the model.

        Raises whatever ``load_model_and_data`` raises when the dataset or
        the checkpoint cannot be loaded.
        """
        self.model = None
        self.dataset = None
        self.checkpoint_path = checkpoint_path
        self.dataset_path = dataset_path
        self.animes_path = animes_path
        # Metadata is parsed lazily to keep start-up memory low.
        self.id_to_anime = LazyDict(animes_path)
        self.id_to_url = LazyDict(images_path)
        self.id_to_mal_url = LazyDict(mal_urls_path)
        self.id_to_type_seq = LazyDict(type_seq_path)
        self.id_to_genres = LazyDict(genres_path)
        # Reserved for cached results; nothing in this class reads it yet.
        self._cache = {}
        self.load_model_and_data()

    def load_model_and_data(self):
        """Load the id mapping from the dataset pickle, then build the model.

        Side effects: sets ``args.bert_max_len`` and ``args.num_items`` on the
        shared ``args`` object before calling ``model_factory``.
        Any failure is logged and re-raised unchanged.
        """
        try:
            print("Loading model and data...")
            args.bert_max_len = 128
            # Load the dataset.
            # SECURITY: pickle.load executes arbitrary code embedded in the
            # file; only use dataset files from a trusted source.
            dataset_path = Path(self.dataset_path)
            with dataset_path.open('rb') as f:
                self.dataset = pickle.load(f)["smap"]
            args.num_items = len(self.dataset)
            print(args.num_items)
            # Build the model and restore its weights.
            self.model = model_factory(args)
            self.load_checkpoint()
            gc.collect()
            print("Model loaded successfully!")
        except Exception as e:
            print(f"Error loading model: {str(e)}")
            raise

    def load_checkpoint(self):
        """Restore ``model_state_dict`` from the checkpoint and switch to eval mode.

        Raises:
            Exception: wrapping (and chained to) the underlying error.
        """
        try:
            with open(self.checkpoint_path, 'rb') as f:
                # SECURITY: weights_only=False unpickles arbitrary objects;
                # only load checkpoints from a trusted source.
                checkpoint = torch.load(f, map_location='cpu', weights_only=False)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.model.eval()
            # Drop the checkpoint so its tensors can be reclaimed.
            del checkpoint
            gc.collect()
        except Exception as e:
            raise Exception(f"Failed to load checkpoint from {self.checkpoint_path}: {str(e)}") from e

    def get_anime_genres(self, anime_id):
        """Return the title-cased genres of ``anime_id``; ``[]`` when unknown."""
        entry = self.id_to_genres.get(str(anime_id))
        if not entry:
            # Previously `[][0]` raised IndexError for ids without genres.
            return []
        genres = entry[0]
        return [genre.title() for genre in genres] if genres else []

    def get_anime_image_url(self, anime_id):
        """Return the image URL of ``anime_id`` or ``None``."""
        return self.id_to_url.get(str(anime_id), None)

    def get_anime_mal_url(self, anime_id):
        """Return the MyAnimeList URL of ``anime_id`` or ``None``."""
        return self.id_to_mal_url.get(str(anime_id), None)

    def _get_type(self, anime_id):
        """Return the type label of ``anime_id`` or ``"Unknown"``."""
        type_seq_info = self.id_to_type_seq.get(str(anime_id))
        if not type_seq_info or len(type_seq_info) < 2:
            return "Unknown"
        return type_seq_info[0]

    def _build_title_index(self):
        """Map every lower-cased title (main and alternative) to ``(id, main_title)``."""
        anime_names = {}
        for k, v in self.id_to_anime.items():
            anime_id = int(k)
            if isinstance(v, list) and len(v) > 0:
                main_title = v[0]
                anime_names[main_title.lower().strip()] = (anime_id, main_title)
                for alt_title in v[1:]:
                    if alt_title and isinstance(alt_title, str):
                        alt_title_clean = alt_title.strip()
                        if alt_title_clean:
                            anime_names[alt_title_clean.lower()] = (anime_id, main_title)
            else:
                title = str(v).strip()
                anime_names[title.lower()] = (anime_id, title)
        return anime_names

    def find_closest_anime(self, input_name):
        """Find the anime whose title best matches ``input_name``.

        Tries, in order: exact (case-insensitive) match, first title that
        contains the input as a substring, then a ``difflib`` fuzzy match
        with cutoff 0.6.

        Returns:
            ``(anime_id, main_title)`` or ``None`` when nothing matches or
            the input is empty/blank.
        """
        if not input_name or not input_name.strip():
            # An empty string is a substring of every title, which used to
            # return an arbitrary first entry.
            return None
        anime_names = self._build_title_index()
        input_lower = input_name.lower().strip()
        # 1. Exact match
        if input_lower in anime_names:
            return anime_names[input_lower]
        # 2. Substring search
        for anime_name_lower, match in anime_names.items():
            if input_lower in anime_name_lower:
                return match
        # 3. Fuzzy matching
        close_matches = difflib.get_close_matches(input_lower, list(anime_names), n=1, cutoff=0.6)
        if close_matches:
            return anime_names[close_matches[0]]
        return None

    def search_animes(self, query):
        """Return up to 200 ``(anime_id, main_title)`` pairs matching ``query``.

        An empty query matches every entry. Matching is a case-insensitive
        substring test over all titles of an entry. The scan stops after 200
        hits; the hits are then sorted by main title.
        """
        animes = []
        query_lower = query.lower() if query else ""
        for k, v in self.id_to_anime.items():
            if len(animes) >= 200:  # Limit for performance
                break
            anime_names = v if isinstance(v, list) else [v]
            # Non-string titles are skipped instead of crashing on .lower().
            match_found = any(
                not query or (isinstance(name, str) and query_lower in name.lower())
                for name in anime_names
            )
            if match_found:
                main_name = anime_names[0] if anime_names else "Unknown"
                animes.append((int(k), main_name))
        animes.sort(key=lambda x: x[1])
        return animes

    def get_recommendations(self, favorite_anime_ids, num_recommendations=20, filters=None):
        """Score all items for the given favourites and return the best ones.

        Args:
            favorite_anime_ids: anime ids; ids missing from the model
                vocabulary are ignored.
            num_recommendations: maximum number of results.
            filters: optional dict understood by ``_should_include_anime``.

        Returns:
            ``(recommendations, scores, message)``. ``recommendations`` is a
            list of dicts with keys ``id``, ``name``, ``score``,
            ``image_url``, ``mal_url``, ``genres`` and ``type``. On any error
            both lists are empty and ``message`` describes the problem; this
            method never raises.
        """
        try:
            if not favorite_anime_ids:
                return [], [], "Please add some favorite animes first!"
            smap = self.dataset
            inverted_smap = {v: k for k, v in smap.items()}
            converted_ids = [smap[anime_id] for anime_id in favorite_anime_ids if anime_id in smap]
            if not converted_ids:
                return [], [], "None of the selected animes are in the model vocabulary!"
            target_len = 128
            # Keep only the most recent entries so the input never exceeds
            # the fixed sequence length, then right-pad with zeros.
            converted_ids = converted_ids[-target_len:]
            padded = converted_ids + [0] * (target_len - len(converted_ids))
            input_tensor = torch.tensor(padded, dtype=torch.long).unsqueeze(0)
            max_predictions = min(75, len(inverted_smap))
            with torch.no_grad():
                logits = self.model(input_tensor)
            # NOTE(review): with right padding, position -1 is a padding slot
            # unless the history is full; confirm this is the position the
            # model was trained to predict from.
            last_logits = logits[:, -1, :]
            top_scores, top_indices = torch.topk(last_logits, k=max_predictions, dim=1)
            favorites = set(favorite_anime_ids)
            recommendations = []
            scores = []
            for idx, score in zip(top_indices[0].tolist(), top_scores[0].tolist()):
                if idx not in inverted_smap:
                    continue
                anime_id = inverted_smap[idx]
                if anime_id in favorites:
                    continue
                if str(anime_id) not in self.id_to_anime:
                    continue
                if filters and not self._should_include_anime(anime_id, filters):
                    continue
                anime_data = self.id_to_anime.get(str(anime_id))
                anime_name = anime_data[0] if isinstance(anime_data, list) and len(anime_data) > 0 else str(anime_data)
                recommendations.append({
                    'id': anime_id,
                    'name': anime_name,
                    'score': float(score),
                    'image_url': self.get_anime_image_url(anime_id),
                    'mal_url': self.get_anime_mal_url(anime_id),
                    'genres': self.get_anime_genres(anime_id),
                    'type': self._get_type(anime_id)
                })
                scores.append(float(score))
                if len(recommendations) >= num_recommendations:
                    break
            # Memory cleanup
            del logits, last_logits, top_scores, top_indices
            gc.collect()
            return recommendations, scores, f"Found {len(recommendations)} recommendations!"
        except Exception as e:
            # Boundary for the UI layer: report instead of raising.
            return [], [], f"Error during prediction: {str(e)}"

    def _should_include_anime(self, anime_id, filters):
        """Return ``False`` when ``filters`` exclude ``anime_id``.

        ``filters`` may hold the boolean keys ``show_sequels``,
        ``show_movies``, ``show_tv`` and ``show_ova``; a missing key counts
        as ``True``. Entries without type/sequel metadata are always kept.
        """
        if not filters:
            return True
        type_seq_info = self.id_to_type_seq.get(str(anime_id))
        if not type_seq_info or len(type_seq_info) < 2:
            return True
        anime_type = type_seq_info[0]
        is_sequel = type_seq_info[1]
        # Sequel filter
        if not filters.get('show_sequels', True) and is_sequel:
            return False
        # Type filters
        if not filters.get('show_movies', True) and anime_type == 'MOVIE':
            return False
        if not filters.get('show_tv', True) and anime_type == 'TV':
            return False
        if not filters.get('show_ova', True) and anime_type in ['ONA', 'OVA', 'SPECIAL']:
            return False
        return True
# Global recommendation system (created once by initialize_system()).
recommendation_system = None


def initialize_system():
    """Download any missing data files and build the global recommender.

    Does nothing when the global ``recommendation_system`` already exists.
    Files are fetched from Google Drive with ``gdown`` only when they are not
    already present in the working directory. Downloads are best-effort: a
    failed download is reported on stdout and initialisation continues, so a
    missing model or dataset file surfaces as an error from the
    ``AnimeRecommendationSystem`` constructor.

    Returns:
        ``"System ready!"`` on success, otherwise ``"Error: <details>"``.
        This function does not raise.
    """
    global recommendation_system
    if recommendation_system is not None:
        return "System ready!"
    try:
        args.num_items = 15687
        # Google Drive file id -> local file name.
        file_ids = {
            "1X1jUSbE4x6DbccP7mHz-nAeGcfOjSHwe": "pretrained_bert.pth",
            "1J1RmuJE5OjZUO0z1irVb2M-xnvuVvvHR": "animes.json",
            "1xGxUCbCDUnbdnJa6Ab8wgM9cpInpeQnN": "dataset.pkl",
            "1PtB6o_91tNWAb4zN0xj-Kf8SKvVAJp1c": "id_to_url.json",
            "1xVfTB_CmeYEqq6-l_BkQXo-QAUEyBfbW": "anime_to_malurl.json",
            "1zMbL9TpCbODKfVT5ahiaYILlnwBZNJc1": "anime_to_typenseq.json",
            "1LLMRhYyw82GOz3d8SUDZF9YRJdybgAFA": "id_to_genres.json",
            "1bW-UlKiGplb2jTt7uD-dfIx3CMXD3iWT": "id_to_genreids.json"
        }

        def download_from_gdrive(file_id, output_path):
            """Fetch one file; return True only if it really exists afterwards."""
            url = f"https://drive.google.com/uc?id={file_id}"
            try:
                print(f"Downloading: {output_path}")
                downloaded = gdown.download(url, output_path, quiet=False)
            except Exception as e:
                print(f"Error downloading {output_path}: {e}")
                return False
            # Some gdown versions signal failure by returning None instead of
            # raising, so success is confirmed by checking the file itself.
            if downloaded is None or not os.path.isfile(output_path):
                print(f"Error downloading {output_path}: no file was produced")
                return False
            print(f"Downloaded: {output_path}")
            return True

        for file_id, filename in file_ids.items():
            if not os.path.isfile(filename):
                download_from_gdrive(file_id, filename)
        recommendation_system = AnimeRecommendationSystem(
            "pretrained_bert.pth",
            "dataset.pkl",
            "animes.json",
            "id_to_url.json",
            "anime_to_malurl.json",
            "anime_to_typenseq.json",
            "id_to_genres.json"
        )
        print("Recommendation system initialized successfully!")
    except Exception as e:
        print(f"Failed to initialize recommendation system: {e}")
        return f"Error: {str(e)}"
    return "System ready!"
def search_and_add_anime(query, favorites_state):
    """Look up ``query`` and add the best match to the favourites.

    Args:
        query: free-text anime name typed by the user (may be ``None``).
        favorites_state: list of anime ids currently marked as favourites.
            It is never modified in place.

    Returns:
        ``(status_message, favorites, search_box_value)``. ``favorites`` is a
        new list when an anime was added and the unchanged input otherwise.
        ``search_box_value`` is always ``""`` so the search box is cleared.
    """
    if not recommendation_system:
        return "System not initialized", favorites_state, ""
    # `query` can be None when the component has no value; treat it as empty.
    if not query or not query.strip():
        return "Please enter an anime name to search", favorites_state, ""
    result = recommendation_system.find_closest_anime(query.strip())
    if not result:
        return f"No anime found matching '{query}'", favorites_state, ""
    anime_id, anime_name = result
    current = list(favorites_state) if favorites_state else []
    if anime_id in current:
        return f"'{anime_name}' is already in your favorites", favorites_state, ""
    if len(current) >= 15:
        return "Maximum 15 favorite animes allowed", favorites_state, ""
    # Return a fresh list instead of appending to the caller's object.
    current.append(anime_id)
    return f"Added '{anime_name}' to favorites", current, ""
def get_favorites_display(favorites_state):
    """Render the favourites as a numbered, newline-separated text block.

    Returns ``"No favorites added yet"`` when there are no favourites or the
    recommender is not initialised. Ids without a title entry are left out,
    but they still consume their position number.
    """
    if not (favorites_state and recommendation_system):
        return "No favorites added yet"
    titles = recommendation_system.id_to_anime
    lines = ["Your Favorite Animes:\n"]
    for position, anime_id in enumerate(favorites_state, 1):
        entry = titles.get(str(anime_id))
        if not entry:
            continue
        if isinstance(entry, list):
            label = entry[0]
        else:
            label = str(entry)
        lines.append(f"{position}. {label}\n")
    return "".join(lines)
def clear_favorites(favorites_state):
    """Reset the favourites.

    The incoming state is ignored. Returns the status message, a new empty
    favourites list and an empty string for the search box.
    """
    status_message = "Favorites cleared"
    emptied_favorites = []
    search_box_value = ""
    return status_message, emptied_favorites, search_box_value
def get_recommendations_gradio(favorites_state, num_recs, show_sequels, show_movies, show_tv, show_ova):
    """Build the HTML shown in the recommendations panel.

    Args:
        favorites_state: list of favourite anime ids.
        num_recs: requested number of recommendations (converted with ``int``).
        show_sequels, show_movies, show_tv, show_ova: filter switches passed
            to the recommender as a ``filters`` dict.

    Returns:
        An HTML string with one card per recommendation, or a short plain
        message when the system is not initialised, there are no favourites,
        or nothing was recommended.

    All text and URLs coming from the metadata are HTML-escaped, so titles
    containing quotes or markup cannot break the single-quoted attributes or
    inject elements.
    """
    import html  # local import keeps this function self-contained

    if not recommendation_system:
        return "System not initialized"
    if not favorites_state:
        return "Please add some favorite animes first!"

    def esc(value):
        """Escape a value for use in HTML text and quoted attributes."""
        return html.escape(str(value), quote=True)

    filters = {
        'show_sequels': show_sequels,
        'show_movies': show_movies,
        'show_tv': show_tv,
        'show_ova': show_ova
    }
    recommendations, _scores, message = recommendation_system.get_recommendations(
        favorites_state,
        num_recommendations=int(num_recs),
        filters=filters
    )
    if not recommendations:
        return f"No recommendations found. {esc(message)}"

    # Collect fragments and join once instead of repeated string +=.
    parts = [f"<div style='padding: 20px;'><h2>🎌 {esc(message)}</h2><br>"]
    for i, rec in enumerate(recommendations, 1):
        name = esc(rec['name'])
        # Card shell
        parts.append("""
<div style='border: 2px solid #e0e0e0; border-radius: 10px; padding: 15px; margin: 15px 0; background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%); box-shadow: 0 4px 6px rgba(0,0,0,0.1);'>
<div style='display: flex; align-items: flex-start; gap: 15px;'>
<div style='flex-shrink: 0;'>
""")
        # Cover image, or a placeholder tile when no URL is known
        if rec.get('image_url'):
            parts.append(f"""
<img src='{esc(rec["image_url"])}' alt='{name}'
style='width: 120px; height: 160px; object-fit: cover; border-radius: 8px; border: 2px solid #fff; box-shadow: 0 2px 4px rgba(0,0,0,0.2);'
onerror="this.style.display='none';">
""")
        else:
            parts.append("""
<div style='width: 120px; height: 160px; background: linear-gradient(45deg, #667eea 0%, #764ba2 100%); border-radius: 8px; display: flex; align-items: center; justify-content: center; color: white; font-weight: bold; text-align: center; border: 2px solid #fff; box-shadow: 0 2px 4px rgba(0,0,0,0.2);'>
No Image
</div>
""")
        # Title and type badge
        parts.append(f"""
</div>
<div style='flex: 1; min-width: 0;'>
<h3 style='margin: 0 0 10px 0; color: #2c3e50; font-size: 1.2em; line-height: 1.3;'>{i}. {name}</h3>
<div style='margin-bottom: 8px;'>
<span style='background: #e74c3c; color: white; padding: 4px 8px; border-radius: 15px; font-size: 0.85em; font-weight: bold;'>
Type: {esc(rec.get('type', 'Unknown'))}
</span>
</div>
""")
        # Genre chips
        if rec['genres']:
            parts.append("""
<div style='margin-bottom: 10px;'>
<strong style='color: #7f8c8d;'>Genres:</strong>
<div style='margin-top: 4px;'>
""")
            for genre in rec['genres']:
                parts.append(f"""
<span style='background: #95a5a6; color: white; padding: 2px 6px; border-radius: 10px; font-size: 0.8em; margin-right: 4px; margin-bottom: 2px; display: inline-block;'>
{esc(genre)}
</span>
""")
            parts.append("</div></div>")
        # MyAnimeList link
        if rec.get('mal_url'):
            parts.append(f"""
<div>
<a href='{esc(rec["mal_url"])}' target='_blank'
style='background: #2e7d32; color: white; padding: 8px 12px; border-radius: 6px; text-decoration: none; font-weight: bold; font-size: 0.9em; display: inline-block;'>
📖 View on MyAnimeList
</a>
</div>
""")
        # Close the text column, the flex row and the card
        parts.append("""
</div>
</div>
</div>
""")
    parts.append("</div>")
    return "".join(parts)
def create_interface():
    """Initialise the recommender and assemble the Gradio Blocks UI.

    Prints the status string returned by ``initialize_system`` and returns
    the (not yet launched) ``gr.Blocks`` application with three tabs:
    "Add Favorites", "Get Recommendations" and "Examples".

    NOTE(review): the nesting of rows/columns below was reconstructed from a
    source whose indentation was lost; confirm the layout against the
    running app.
    """
    print(initialize_system())

    header_html = """
<div style="text-align: center; margin-bottom: 20px;">
<h1>🎌 Anime Recommendation System</h1>
<p>Add your favorite animes and get personalized recommendations!</p>
</div>
"""
    empty_results_html = "<div style='padding: 20px; text-align: center; color: #7f8c8d;'>Add some favorite animes and click 'Get Recommendations'</div>"
    usage_markdown = """
### How to use:
1. **Add Favorites**: Search and add your favorite animes
2. **Set Filters**: Choose what types of anime to include
3. **Get Recommendations**: Click to get personalized suggestions
### Example Searches:
- Mushoku Tensei
- Attack on Titan
- Demon Slayer
- Your Name
- Spirited Away
- One Piece
- Naruto
"""

    with gr.Blocks(title="Anime Recommendation System", theme=gr.themes.Soft()) as demo:
        # Per-session list of favourite anime ids.
        favorites_state = gr.State([])
        gr.HTML(header_html)

        with gr.Tab("Add Favorites"):
            with gr.Row():
                with gr.Column(scale=2):
                    search_input = gr.Textbox(
                        label="Search Anime",
                        placeholder="Enter anime name (e.g., 'Mushoku Tensei', 'Attack on Titan')",
                        lines=1
                    )
                    with gr.Row():
                        add_btn = gr.Button("Add to Favorites", variant="primary")
                        clear_btn = gr.Button("Clear All Favorites", variant="secondary")
                with gr.Column(scale=2):
                    status_output = gr.Textbox(label="Status", lines=2)
                    favorites_display = gr.Textbox(
                        label="Your Favorites",
                        lines=10,
                        interactive=False,
                        value="No favorites added yet"
                    )

        with gr.Tab("Get Recommendations"):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### Recommendation Settings")
                    num_recs = gr.Slider(
                        minimum=5,
                        maximum=50,
                        value=20,
                        step=5,
                        label="Number of Recommendations"
                    )
                    gr.Markdown("### Filters")
                    show_movies = gr.Checkbox(label="Include Movies", value=True)
                    show_tv = gr.Checkbox(label="Include TV Series", value=True)
                    show_ova = gr.Checkbox(label="Include OVA/ONA/Special", value=True)
                    show_sequels = gr.Checkbox(label="Include Sequels", value=True)
                    recommend_btn = gr.Button("Get Recommendations", variant="primary")
                with gr.Column(scale=2):
                    recommendations_output = gr.HTML(
                        label="Recommendations",
                        value=empty_results_html
                    )

        def refresh_favorites_after(event):
            """Chain a redraw of the favourites box onto ``event``."""
            return event.then(
                fn=get_favorites_display,
                inputs=[favorites_state],
                outputs=[favorites_display]
            )

        # Both favourites buttons update the same three outputs and then
        # refresh the favourites box.
        favorites_outputs = [status_output, favorites_state, search_input]
        refresh_favorites_after(
            add_btn.click(
                fn=search_and_add_anime,
                inputs=[search_input, favorites_state],
                outputs=favorites_outputs
            )
        )
        refresh_favorites_after(
            clear_btn.click(
                fn=clear_favorites,
                inputs=[favorites_state],
                outputs=favorites_outputs
            )
        )
        # Input order must match the parameter order of
        # get_recommendations_gradio.
        recommend_btn.click(
            fn=get_recommendations_gradio,
            inputs=[
                favorites_state, num_recs, show_sequels,
                show_movies, show_tv, show_ova
            ],
            outputs=[recommendations_output]
        )

        # Usage notes
        with gr.Tab("Examples"):
            gr.Markdown(usage_markdown)

    return demo
if __name__ == "__main__":
    # Script entry point: build the UI, then serve it on every network
    # interface at port 7860.
    demo = create_interface()
    demo.launch(
        server_port=7860,
        server_name="0.0.0.0",
    )