"""Gradio app that predicts how each UN member state votes on a resolution.

The resolution text is embedded with a SentenceTransformer, and each country
is scored by one of two ``VotePredictor`` models: one for most countries and
one for a fixed list of hard-to-predict ("problem") countries.
"""
import pickle

import gradio as gr
import numpy as np
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer

from model import VotePredictor  # <-- make sure this matches your model file

# Load models.
# weights_only=True limits torch.load to tensors and plain containers (all a
# state_dict needs) instead of unpickling arbitrary objects. Needs torch>=1.13.
main_model = VotePredictor(country_count=193)
main_model.load_state_dict(
    torch.load("vote_predictor_epoch27.pt", map_location="cpu", weights_only=True)
)
main_model.eval()

problem_model = VotePredictor(country_count=46)
problem_model.load_state_dict(
    torch.load("problem_country_model.pt", map_location="cpu", weights_only=True)
)
problem_model.eval()

# Load country encoder.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only ship a country_encoder.pkl that you produced yourself.
with open("country_encoder.pkl", "rb") as f:
    country_encoder = pickle.load(f)

# Sentence-embedding model used to vectorize the resolution text.
vectorizer = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# Define problem countries (same as used during training).
# These 46 names are routed to problem_model instead of main_model.
problem_countries = [
    'SURINAME', 'TURKMENISTAN', 'MARSHALL ISLANDS', 'MYANMAR', 'GABON',
    'CENTRAL AFRICAN REPUBLIC', 'ISRAEL', 'REPUBLIC OF THE CONGO', 'LIBERIA',
    'SOMALIA', 'CANADA', "LAO PEOPLE'S DEMOCRATIC REPUBLIC", 'TUVALU',
    'DEMOCRATIC REPUBLIC OF THE CONGO', 'MONTENEGRO', 'VANUATU',
    'UNITED STATES', 'TÜRKİYE', 'SEYCHELLES', 'SERBIA', 'CABO VERDE',
    'VENEZUELA (BOLIVARIAN REPUBLIC OF)', 'KIRIBATI',
    'IRAN (ISLAMIC REPUBLIC OF)', 'SOUTH SUDAN', 'ALBANIA', 'CZECHIA',
    'DOMINICA', 'SAO TOME AND PRINCIPE', 'ESWATINI', 'CHAD',
    'EQUATORIAL GUINEA', 'GAMBIA', 'LIBYA', "CÔTE D'IVOIRE",
    'SAINT CHRISTOPHER AND NEVIS', 'RWANDA', 'TONGA', 'NIGER',
    'MICRONESIA (FEDERATED STATES OF)', 'SYRIAN ARAB REPUBLIC', 'NAURU',
    'PALAU', 'NORTH MACEDONIA', 'NETHERLANDS',
    'BOLIVIA (PLURINATIONAL STATE OF)',
]

# Built once so the per-country routing check is O(1) instead of a list scan.
_problem_country_set = frozenset(problem_countries)


# Vote function
def predict_votes(resolution_text):
    """Predict every country's vote on a UN resolution from its text.

    The text is embedded once with ``vectorizer``. Then, for every country in
    ``country_encoder.classes_``, ``problem_model`` (countries listed in
    ``problem_countries``) or ``main_model`` (all others) scores the pair
    (text embedding, country id). A sigmoid probability above 0.5 is reported
    as "✅ Yes"; anything else is reported as "❌ Not Yes".

    Args:
        resolution_text: Raw resolution text from the Gradio textbox.

    Returns:
        A DataFrame with columns "Country" and "Vote", sorted by country.
        When the input is empty or whitespace-only, the DataFrame keeps those
        columns but has no rows.
    """
    # Nothing meaningful to predict on blank input.
    if resolution_text is None or not resolution_text.strip():
        return pd.DataFrame({"Country": [], "Vote": []})

    # Vectorize once; the same embedding is paired with every country.
    vec = vectorizer.encode([resolution_text])
    x_tensor = torch.tensor(vec, dtype=torch.float32)

    countries = list(country_encoder.classes_)
    # One vectorized transform call instead of one call per country.
    country_ids = country_encoder.transform(country_encoder.classes_)

    votes = []
    with torch.no_grad():
        for country, country_id in zip(countries, country_ids):
            c_tensor = torch.tensor([country_id], dtype=torch.long)
            # NOTE(review): problem_model is built with country_count=46 (one
            # slot per entry in problem_countries), but it receives the id from
            # the global country_encoder, presumably in 0..len(classes_)-1. If
            # VotePredictor sizes an embedding by country_count, ids >= 46 will
            # fail or select the wrong row. Confirm how ids were assigned when
            # problem_country_model.pt was trained, and remap here to match.
            if country in _problem_country_set:
                model = problem_model
            else:
                model = main_model
            logit = model(x_tensor, c_tensor).squeeze()
            prob = torch.sigmoid(logit).item()
            votes.append("✅ Yes" if prob > 0.5 else "❌ Not Yes")

    return (
        pd.DataFrame({"Country": countries, "Vote": votes})
        .sort_values("Country")
        .reset_index(drop=True)
    )


# Gradio UI
iface = gr.Interface(
    fn=predict_votes,
    inputs=gr.Textbox(lines=15, label="Paste UN Resolution Text Here"),
    outputs=gr.Dataframe(label="Predicted Votes by Country"),
    title="UN Resolution Vote Predictor",
    description="This model predicts how each UN country will vote on a given resolution based on the text. Uses BERT embeddings and two models: one for normal countries, one for chaos monkeys.",
    live=False,
)

# Launch only when run as a script, so importing this module (e.g. for tests)
# does not start a web server.
if __name__ == "__main__":
    iface.launch()