donsek's picture
Update app.py
9282ee1 verified
raw
history blame
3.48 kB
import torch
import gradio as gr
import numpy as np
import pickle
import pandas as pd
from model import VotePredictor
from transformers import AutoTokenizer, AutoModel
# === Vectorizer wrapper (replaces sentence-transformers) ===
class BertVectorizer:
    """Turn text into a single embedding vector with a Hugging Face encoder.

    The vector is the encoder's final hidden state at the first token position
    (index 0), computed without gradient tracking.

    NOTE(review): sentence-transformers normally mean-pools all-MiniLM-L6-v2
    token states rather than taking the first token. Confirm that the
    VotePredictor checkpoints were trained on embeddings produced this same way.
    """

    def __init__(self, model_name="sentence-transformers/all-MiniLM-L6-v2"):
        # Both tokenizer and encoder come from the same pretrained checkpoint.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        # Inference only: switch off training-mode layers such as dropout.
        self.model.eval()

    def encode(self, text):
        """Return the first-token hidden state for ``text`` as a NumPy array.

        Input longer than 512 tokens is truncated. For a single string the
        batch axis is squeezed away, so the result is 1-D.
        """
        encoded = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )
        with torch.no_grad():
            hidden_states = self.model(**encoded).last_hidden_state
        first_token_state = hidden_states[:, 0, :]
        return first_token_state.squeeze().numpy()
# === Load Models ===
def _load_predictor(checkpoint_path, country_count):
    """Build a VotePredictor, load a saved state dict onto the CPU, and set eval mode.

    Args:
        checkpoint_path: Path to a file holding the model's ``state_dict``.
        country_count: Passed to ``VotePredictor``; must match the checkpoint.

    Returns:
        The loaded model in eval mode.
    """
    predictor = VotePredictor(country_count=country_count)
    # A state dict contains only tensors and plain containers, so restrict
    # unpickling to those (weights_only=True) instead of arbitrary objects.
    state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    predictor.load_state_dict(state_dict)
    predictor.eval()
    return predictor


# Model used for every country not in `problem_countries` (193-country vocabulary).
main_model = _load_predictor("vote_predictor_epoch27.pt", country_count=193)
# Model used for countries in `problem_countries` (46-country vocabulary).
problem_model = _load_predictor("problem_country_model.pt", country_count=46)
# === Load Encoder ===
# Fitted encoder mapping country names to integer IDs; predict_votes iterates
# its `classes_` and calls its `transform`.
# NOTE: pickle.load can execute arbitrary code; only load this trusted, bundled file.
with open("country_encoder.pkl", mode="rb") as f:
    country_encoder = pickle.load(f)
# === Initialize Vectorizer ===
# Single shared text embedder, created once at startup and reused by predict_votes.
vectorizer = BertVectorizer(model_name="sentence-transformers/all-MiniLM-L6-v2")
# === List of problem countries ===
# Countries that predict_votes routes to `problem_model` instead of `main_model`.
# There are 46 entries, matching problem_model's country_count=46; keep the two in sync.
# Names must match `country_encoder.classes_` exactly (including accented
# characters) or the country silently falls back to main_model.
# Order is kept as-is in case it matters for the problem model (TODO confirm).
problem_countries = [
    'SURINAME',
    'TURKMENISTAN',
    'MARSHALL ISLANDS',
    'MYANMAR',
    'GABON',
    'CENTRAL AFRICAN REPUBLIC',
    'ISRAEL',
    'REPUBLIC OF THE CONGO',
    'LIBERIA',
    'SOMALIA',
    'CANADA',
    "LAO PEOPLE'S DEMOCRATIC REPUBLIC",
    'TUVALU',
    'DEMOCRATIC REPUBLIC OF THE CONGO',
    'MONTENEGRO',
    'VANUATU',
    'UNITED STATES',
    'TÜRKİYE',
    'SEYCHELLES',
    'SERBIA',
    'CABO VERDE',
    'VENEZUELA (BOLIVARIAN REPUBLIC OF)',
    'KIRIBATI',
    'IRAN (ISLAMIC REPUBLIC OF)',
    'SOUTH SUDAN',
    'ALBANIA',
    'CZECHIA',
    'DOMINICA',
    'SAO TOME AND PRINCIPE',
    'ESWATINI',
    'CHAD',
    'EQUATORIAL GUINEA',
    'GAMBIA',
    'LIBYA',
    "CÔTE D'IVOIRE",
    'SAINT CHRISTOPHER AND NEVIS',
    'RWANDA',
    'TONGA',
    'NIGER',
    'MICRONESIA (FEDERATED STATES OF)',
    'SYRIAN ARAB REPUBLIC',
    'NAURU',
    'PALAU',
    'NORTH MACEDONIA',
    'NETHERLANDS',
    'BOLIVIA (PLURINATIONAL STATE OF)',
]
# === Prediction Function ===
def predict_votes(resolution_text):
    """Predict every country's vote on a resolution text.

    The text is embedded once with ``vectorizer``. Each country in
    ``country_encoder.classes_`` is then scored with ``problem_model`` (if it
    appears in ``problem_countries``) or ``main_model`` (otherwise), using the
    embedding plus the country's encoded ID. A sigmoid probability above 0.5
    is shown as "✅ Yes"; anything else is "❌ Not Yes".

    Args:
        resolution_text: Resolution text pasted into the UI textbox.

    Returns:
        A ``pandas.DataFrame`` with columns ``Country`` and ``Vote``, sorted by
        country. When the input is ``None`` or only whitespace, the frame has
        those columns but no rows.
    """
    # Blank input carries no signal; don't report predictions for empty text.
    if resolution_text is None or (
        isinstance(resolution_text, str) and not resolution_text.strip()
    ):
        return pd.DataFrame({"Country": [], "Vote": []})

    vec = vectorizer.encode(resolution_text)
    x_tensor = torch.tensor(vec, dtype=torch.float32).unsqueeze(0)  # batch of 1

    countries = list(country_encoder.classes_)
    # One transform call for all countries instead of one call per country.
    country_ids = country_encoder.transform(countries)
    # Set membership is O(1); a list scan would run once per country.
    problem_set = set(problem_countries)

    votes = []
    with torch.no_grad():
        for country, country_id in zip(countries, country_ids):
            # NOTE(review): problem_model was built with country_count=46, but
            # country_id is the global encoder index (0..len(classes_)-1). If
            # VotePredictor embeds country IDs with size country_count, IDs >= 46
            # are out of range. Confirm how problem_model's IDs were assigned
            # during training.
            model = problem_model if country in problem_set else main_model
            c_tensor = torch.tensor([country_id], dtype=torch.long)
            logit = model(x_tensor, c_tensor).squeeze()
            prob = torch.sigmoid(logit).item()
            votes.append("✅ Yes" if prob > 0.5 else "❌ Not Yes")

    return pd.DataFrame({"Country": countries, "Vote": votes}).sort_values("Country")
# === Interface ===
# Single-textbox Gradio app: the pasted text is passed to predict_votes and the
# returned DataFrame is shown as a table.
iface = gr.Interface(
    fn=predict_votes,
    inputs=gr.Textbox(lines=15, label="Paste UN Resolution Text Here"),
    outputs=gr.Dataframe(label="Predicted Votes by Country"),
    title="UN Resolution Vote Predictor",
    # The model split is by a fixed country list (which includes e.g. the
    # United States, Canada and the Netherlands), not by regime type.
    description=(
        "Predicts how each UN country might vote on your custom resolution text. "
        "Two models: a dedicated model for a fixed list of 46 countries, "
        "and a general model for all the others."
    ),
    # Run only on submit; each run embeds the text and scores every country.
    live=False,
)
# Start the Gradio app (module-level, as in the original script).
iface.launch()