File size: 4,111 Bytes
6a3d3b0
 
 
8023b76
 
 
6a3d3b0
8023b76
6a3d3b0
 
 
8023b76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a3d3b0
 
 
8023b76
 
 
6a3d3b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8023b76
 
 
 
6a3d3b0
 
 
8023b76
 
 
6a3d3b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8023b76
6a3d3b0
 
 
 
 
 
8023b76
6a3d3b0
 
 
 
 
 
 
 
8023b76
 
 
 
6a3d3b0
8023b76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# app.py
import os
from typing import List, Dict

import pandas as pd
import torch
from fastapi import FastAPI, HTTPException
from torch import nn
from torch.utils.data import Dataset, DataLoader

from train import MangaRecommender, MangaDataset  # Import from train.py

app = FastAPI()

# Load the trained model and its id <-> index mappings once, at import time.
# On any failure the service still starts with model=None; the endpoints
# below then refuse work with 500 "Model not loaded".
try:
    # map_location='cpu': request handlers build CPU tensors, and a checkpoint
    # saved from a GPU run would otherwise fail to load on a CPU-only host.
    # NOTE(security): torch.load uses pickle-based deserialization; only load
    # checkpoints from a trusted location.
    checkpoint = torch.load('manga_recommender.pt', map_location='cpu')
    model = MangaRecommender(
        num_users=len(checkpoint['user_mapping']),
        num_items=len(checkpoint['manga_mapping'])
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    # Start in inference mode (disables dropout / batch-norm updates, if any).
    model.eval()
    user_mapping = checkpoint['user_mapping']
    manga_mapping = checkpoint['manga_mapping']
    # Inverse lookup: model output index -> external manga id.
    reverse_manga_mapping = {v: k for k, v in manga_mapping.items()}
    print("Model loaded successfully")
except Exception as e:
    print(f"Error loading model: {e}")
    model = None
    user_mapping = {}
    manga_mapping = {}
    reverse_manga_mapping = {}

@app.get("/")
async def root():
    """Health check: report that the service is up and whether a model loaded."""
    loaded = model is not None
    return dict(status="running", model_loaded=loaded)

@app.post("/predict")
async def predict(user_id: str, top_k: int = 10):
    """Return the ``top_k`` highest-scoring manga for a known user.

    Args:
        user_id: External user id, looked up in ``user_mapping``.
        top_k: Number of recommendations requested. Values larger than the
            number of scored items are clamped to that number.

    Returns:
        ``{"manga_ids": [...], "scores": [...]}`` in descending score order,
        or ``{"error": "User not found"}`` (HTTP 200) for an unknown user.

    Raises:
        HTTPException: 500 if the model is not loaded or prediction fails;
            400 if ``top_k`` is negative.
    """
    if model is None:
        raise HTTPException(status_code=500, detail="Model not loaded")
    # Validated before the try block so this 400 is not rewrapped as a 500.
    if top_k < 0:
        raise HTTPException(status_code=400, detail="top_k must be non-negative")

    # NOTE(review): user_id always arrives as a str; if the training mappings
    # were built with integer ids this lookup can never hit — confirm the key
    # type used in train.py.
    user_idx = user_mapping.get(user_id)
    if user_idx is None:
        # Cold start: no embedding exists for this user. Kept as a 200 with an
        # "error" key for compatibility with existing clients.
        return {"error": "User not found"}

    try:
        model.eval()
        with torch.no_grad():
            user_tensor = torch.tensor([user_idx])
            predictions = model.predict(user_tensor)
            # assumes predictions is (1, num_items) — TODO confirm against
            # MangaRecommender.predict in train.py.
            user_scores = predictions[0]
            # torch.topk raises if k exceeds the number of items, so clamp it.
            k = min(top_k, user_scores.size(-1))
            scores, indices = torch.topk(user_scores, k=k)

        # Convert model indices back to external manga ids.
        manga_ids = [reverse_manga_mapping[idx] for idx in indices.tolist()]
        return {
            "manga_ids": manga_ids,
            "scores": scores.tolist()
        }
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Prediction error: {str(e)}"
        ) from e

@app.post("/update")
async def update_model(ratings: List[Dict]):
    """Fine-tune the loaded model on a batch of new ratings and save it.

    Args:
        ratings: Rating records. Each must carry ``user_id`` and ``manga_id``
            plus whatever fields ``MangaDataset`` reads (presumably a rating
            value — confirm in train.py).

    Rows whose user or manga is absent from the training-time mappings are
    skipped, because the model has no embedding row for them.

    Returns:
        ``message``, ``average_loss`` (mean per-batch MSE over one pass) and
        ``skipped_ratings`` (number of rows dropped for unknown ids).

    Raises:
        HTTPException: 400 for an empty payload, missing id keys, or no rows
            with known ids; 500 if the model is not loaded or if training or
            saving fails.
    """
    if model is None:
        raise HTTPException(status_code=500, detail="Model not loaded")
    if not ratings:
        raise HTTPException(status_code=400, detail="No ratings provided")

    try:
        df = pd.DataFrame(ratings)
        missing = {'user_id', 'manga_id'}.difference(df.columns)
        if missing:
            raise HTTPException(
                status_code=400,
                detail=f"Ratings missing required keys: {sorted(missing)}"
            )

        # Map external ids to embedding indices; unknown ids become NaN.
        df['user_idx'] = df['user_id'].map(user_mapping)
        df['manga_idx'] = df['manga_id'].map(manga_mapping)
        known = df['user_idx'].notna() & df['manga_idx'].notna()
        skipped = int((~known).sum())
        # reset_index so positional and label indexing agree in MangaDataset.
        df = df.loc[known].reset_index(drop=True)
        if df.empty:
            raise HTTPException(
                status_code=400,
                detail="No ratings reference known users and manga"
            )
        # Any NaN forces .map() output to float64; restore integer indices.
        df['user_idx'] = df['user_idx'].astype('int64')
        df['manga_idx'] = df['manga_idx'].astype('int64')

        dataset = MangaDataset(df)
        loader = DataLoader(dataset, batch_size=64)

        # A fresh optimizer per request: Adam state is not kept between updates.
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        criterion = nn.MSELoss()

        model.train()
        total_loss = 0.0
        try:
            for user, item, rating in loader:
                optimizer.zero_grad()
                pred = model(user, item)
                loss = criterion(pred, rating)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
        finally:
            # The module-level model is shared with /predict; always leave it
            # in inference mode, even if a batch fails.
            model.eval()

        # Write to a sibling temp file, then atomically swap it in, so a crash
        # mid-write cannot leave a truncated checkpoint for the next startup.
        tmp_path = 'manga_recommender.pt.tmp'
        torch.save({
            'model_state_dict': model.state_dict(),
            'user_mapping': user_mapping,
            'manga_mapping': manga_mapping
        }, tmp_path)
        os.replace(tmp_path, 'manga_recommender.pt')

        return {
            "message": "Model updated successfully",
            "average_loss": total_loss / len(loader),
            "skipped_ratings": skipped,
        }
    except HTTPException:
        # Client errors raised above must keep their 4xx status.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Update error: {str(e)}"
        ) from e

@app.get("/model-info")
async def model_info():
    """Describe the loaded model: mapping sizes and user-embedding width.

    Returns:
        ``num_users`` / ``num_manga`` (sizes of the id mappings) and
        ``embedding_size`` (``embedding_dim`` of ``model.user_factors``).

    Raises:
        HTTPException: 500 when no model was loaded at startup.
    """
    if model is None:
        raise HTTPException(status_code=500, detail="Model not loaded")

    embedding_size = model.user_factors.embedding_dim
    info = {
        "num_users": len(user_mapping),
        "num_manga": len(manga_mapping),
    }
    info["embedding_size"] = embedding_size
    return info

if __name__ == "__main__":
    import uvicorn

    # Serve the app; uvicorn.run() blocks until the server is stopped.
    uvicorn.run(app, host="0.0.0.0", port=8000)