donsek's picture
Upload 2 files
99a265f verified
raw
history blame
3.2 kB
import torch
import gradio as gr
import numpy as np
import pickle
import pandas as pd
from sentence_transformers import SentenceTransformer
from model import VotePredictor # <-- make sure this matches your model file
# Load models
def _load_vote_predictor(checkpoint_path, country_count):
    """Build a VotePredictor, load its checkpoint onto the CPU, and put it in eval mode.

    Args:
        checkpoint_path: Path to a file holding the model's ``state_dict``.
        country_count: Passed straight through to ``VotePredictor``.

    Returns:
        The loaded ``VotePredictor`` in eval mode.

    ``weights_only=True`` restricts ``torch.load`` to tensors and plain
    containers instead of arbitrary pickled objects. That is all a
    ``state_dict`` needs, and it is the default from PyTorch 2.6 onward.
    """
    predictor = VotePredictor(country_count=country_count)
    state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    predictor.load_state_dict(state_dict)
    predictor.eval()
    return predictor


main_model = _load_vote_predictor("vote_predictor_epoch27.pt", country_count=193)
# Separate model used for the countries listed in problem_countries.
# NOTE(review): country_count=46 presumably matches the 46-entry
# problem_countries list — confirm against the training code.
problem_model = _load_vote_predictor("problem_country_model.pt", country_count=46)

# Load country encoder (an object exposing .classes_ and .transform).
# pickle.load can execute arbitrary code: only ever ship a trusted, locally
# built file under this name.
with open("country_encoder.pkl", "rb") as f:
    country_encoder = pickle.load(f)

# Vectorizer: sentence-transformer that embeds the resolution text.
vectorizer = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
# Countries routed to problem_model instead of main_model in predict_votes.
# This must stay identical to the list used during training (46 entries).
problem_countries = [
    "SURINAME",
    "TURKMENISTAN",
    "MARSHALL ISLANDS",
    "MYANMAR",
    "GABON",
    "CENTRAL AFRICAN REPUBLIC",
    "ISRAEL",
    "REPUBLIC OF THE CONGO",
    "LIBERIA",
    "SOMALIA",
    "CANADA",
    "LAO PEOPLE'S DEMOCRATIC REPUBLIC",
    "TUVALU",
    "DEMOCRATIC REPUBLIC OF THE CONGO",
    "MONTENEGRO",
    "VANUATU",
    "UNITED STATES",
    "TÜRKİYE",
    "SEYCHELLES",
    "SERBIA",
    "CABO VERDE",
    "VENEZUELA (BOLIVARIAN REPUBLIC OF)",
    "KIRIBATI",
    "IRAN (ISLAMIC REPUBLIC OF)",
    "SOUTH SUDAN",
    "ALBANIA",
    "CZECHIA",
    "DOMINICA",
    "SAO TOME AND PRINCIPE",
    "ESWATINI",
    "CHAD",
    "EQUATORIAL GUINEA",
    "GAMBIA",
    "LIBYA",
    "CÔTE D'IVOIRE",
    "SAINT CHRISTOPHER AND NEVIS",
    "RWANDA",
    "TONGA",
    "NIGER",
    "MICRONESIA (FEDERATED STATES OF)",
    "SYRIAN ARAB REPUBLIC",
    "NAURU",
    "PALAU",
    "NORTH MACEDONIA",
    "NETHERLANDS",
    "BOLIVIA (PLURINATIONAL STATE OF)",
]
# Vote function
def predict_votes(resolution_text):
    """Predict a "Yes" / "Not Yes" vote for every country known to the encoder.

    The resolution text is embedded once and then scored per country.
    Countries listed in ``problem_countries`` are scored by ``problem_model``;
    every other country is scored by ``main_model``. A sigmoid probability
    above 0.5 is reported as a Yes.

    Args:
        resolution_text: Raw resolution text pasted into the UI.

    Returns:
        pandas.DataFrame with columns "Country" and "Vote", sorted by "Country".
    """
    # Embed the text a single time; the same embedding is reused for every country.
    embedding = vectorizer.encode([resolution_text])
    x_tensor = torch.tensor(embedding, dtype=torch.float32)

    countries = list(country_encoder.classes_)
    # One vectorized transform instead of one transform call per country.
    country_ids = country_encoder.transform(country_encoder.classes_)
    # Set membership is O(1) per lookup; problem_countries itself is a list.
    problem_set = set(problem_countries)

    votes = []
    # Inference only: disable autograd once for the whole loop.
    with torch.no_grad():
        for country, country_id in zip(countries, country_ids):
            c_tensor = torch.tensor([country_id], dtype=torch.long)
            # NOTE(review): country_id comes from the full country encoder, but
            # problem_model is constructed with country_count=46. If that
            # parameter sizes a per-country table, ids >= 46 would be out of
            # range or map to the wrong country — confirm how problem-model ids
            # were assigned during training.
            model = problem_model if country in problem_set else main_model
            logit = model(x_tensor, c_tensor).squeeze()
            prob = torch.sigmoid(logit).item()
            votes.append("✅ Yes" if prob > 0.5 else "❌ Not Yes")

    df = pd.DataFrame({
        "Country": countries,
        "Vote": votes,
    }).sort_values("Country")
    return df
# Gradio UI: a single text box in, a per-country vote table out.
resolution_input = gr.Textbox(lines=15, label="Paste UN Resolution Text Here")
votes_output = gr.Dataframe(label="Predicted Votes by Country")

iface = gr.Interface(
    fn=predict_votes,
    inputs=resolution_input,
    outputs=votes_output,
    title="UN Resolution Vote Predictor",
    description=(
        "This model predicts how each UN country will vote on a given resolution "
        "based on the text. Uses BERT embeddings and two models: one for normal "
        "countries, one for chaos monkeys."
    ),
    live=False,
)
iface.launch()