Spaces:
Runtime error
Runtime error
File size: 1,925 Bytes
11118c6 b9fb926 11118c6 69d9c01 11118c6 b9fb926 69d9c01 11118c6 69d9c01 7616bd7 11118c6 7616bd7 69d9c01 861bf9a 571904c 11118c6 aea93fc 932fdef aea93fc f66f574 11118c6 6867fa8 11118c6 6867fa8 69d9c01 11118c6 3446916 69d9c01 3446916 11118c6 45d4b39 d63d73a 11118c6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
import gradio as gr
import torch
from torch import nn
import numpy as np
import pandas as pd
from utils import compute_features
class NegBinomialModel(nn.Module):
    """Single-layer model whose output is an exponentiated linear score.

    The attribute names ``linear`` and ``alpha`` are the keys of the saved
    state dict loaded via ``load_state_dict``, so they must not be renamed.
    """

    def __init__(self, in_features):
        super().__init__()
        # One affine map from the feature vector to a log-scale value.
        self.linear = nn.Linear(in_features, 1)
        # Learnable scalar, initialised to 0.5 (presumably the
        # negative-binomial dispersion -- confirm against training code).
        self.alpha = nn.Parameter(torch.tensor(0.5))

    def forward(self, x):
        """Return ``(mu, alpha)`` for the feature tensor ``x``.

        ``mu`` is ``exp`` of the linear output after clamping it to
        ``[-5, 5]``. This bounds ``mu`` to ``[e**-5, e**5]`` and prevents
        overflow. Size-1 dimensions are squeezed away. ``alpha`` is the
        learnable scalar clamped to ``[1e-3, 10]``.
        """
        log_mu = self.linear(x).clamp(min=-5, max=5)
        dispersion = self.alpha.clamp(min=1e-3, max=10)
        return log_mu.exp().squeeze(), dispersion
# Rebuild the network and restore its trained parameters onto the CPU.
# in_features=16 must match the checkpoint's layer shape, or
# load_state_dict (strict by default) raises.
# NOTE(review): torch.load unpickles the file; only point it at a trusted
# checkpoint (weights_only=True is available on torch >= 1.13).
model = NegBinomialModel(in_features=16)
model.load_state_dict(
    torch.load("model_weights.pt", map_location=torch.device("cpu"))
)
# Inference mode; no dropout/batch-norm here, but kept for correctness.
model.eval()
def predict_score(lat, lon):
    """Score a candidate bank location from its coordinates.

    Args:
        lat: Latitude of the candidate site.
        lon: Longitude of the candidate site.

    Returns:
        A 3-tuple ``(score, num_banks, mu)``, in the order of the Gradio
        outputs:

        - ``score`` is ``|mu + 0.1| * 100`` rounded to 3 decimals.
        - ``num_banks`` is the ``"num_banks_in_radius"`` value reported by
          ``compute_features``, or 0 if that key is absent.
        - ``mu`` is the model's predicted mean rounded to 3 decimals.
    """
    # Work on a copy so popping the bank count cannot mutate whatever
    # object compute_features handed back (it may be cached or shared).
    features = dict(compute_features((lat, lon)))
    num_banks = features.pop("num_banks_in_radius", 0)

    # NOTE(review): the model consumes features positionally, so this relies
    # on compute_features yielding them in the same order used in training.
    # dtype is forced to float32: without it torch.tensor infers int64 when
    # every value is an int, and nn.Linear rejects integer input.
    x = torch.tensor([lat, lon, *features.values()], dtype=torch.float32)

    with torch.no_grad():
        mu_pred, _alpha = model(x)  # alpha is not used for scoring

    # .item() extracts the single value as a Python float. This avoids the
    # NumPy >= 1.25 deprecation of float() on a 1-element array. Doing the
    # arithmetic in float64 also keeps the 3rd decimal exact at large mu.
    mu = mu_pred.item()
    # NOTE(review): the +0.1 offset and x100 scale are unexplained constants.
    score = abs(mu + 0.1) * 100

    return round(score, 3), num_banks, round(mu, 3)
# ======== Gradio Interface ========
# The two numeric inputs map positionally onto predict_score(lat, lon). The
# three outputs map, in order, onto the tuple that predict_score returns.
interface = gr.Interface(
    fn=predict_score,
    inputs=[
        gr.Number(label="Latitude"),
        gr.Number(label="Longitude"),
    ],
    outputs=[
        gr.Number(label="Score"),
        gr.Number(label="Num Current Banks"),
        gr.Number(label="Num Ideal Banks"),
    ],
    title="Bank Location Scoring Model",
    # Describe the outputs actually produced. There is no normalized-score
    # output.
    description=(
        "Enter latitude and longitude to get the predicted score, the number "
        "of banks currently in the search radius, and the model's predicted "
        "ideal number of banks."
    ),
)
# Start the web server. This runs at import time; it is left unguarded so
# startup behaviour is unchanged.
interface.launch()
|