File size: 3,199 Bytes
99a265f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import torch
import gradio as gr
import numpy as np
import pickle
import pandas as pd
from sentence_transformers import SentenceTransformer
from model import VotePredictor  # <-- make sure this matches your model file

# Load models
def _load_vote_predictor(weights_path, country_count):
    """Return a VotePredictor with weights from *weights_path*, in eval mode.

    The checkpoint is mapped onto the CPU while loading.
    """
    predictor = VotePredictor(country_count=country_count)
    state_dict = torch.load(weights_path, map_location="cpu")
    predictor.load_state_dict(state_dict)
    predictor.eval()
    return predictor


# General model and the separate model used for the "problem" countries.
main_model = _load_vote_predictor("vote_predictor_epoch27.pt", country_count=193)
problem_model = _load_vote_predictor("problem_country_model.pt", country_count=46)

# Load country encoder
with open("country_encoder.pkl", "rb") as encoder_file:
    country_encoder = pickle.load(encoder_file)

# Vectorizer
vectorizer = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# Countries that predict_votes routes to problem_model instead of main_model
# (per the original note: the same set as used during training).
problem_countries = [
    "SURINAME",
    "TURKMENISTAN",
    "MARSHALL ISLANDS",
    "MYANMAR",
    "GABON",
    "CENTRAL AFRICAN REPUBLIC",
    "ISRAEL",
    "REPUBLIC OF THE CONGO",
    "LIBERIA",
    "SOMALIA",
    "CANADA",
    "LAO PEOPLE'S DEMOCRATIC REPUBLIC",
    "TUVALU",
    "DEMOCRATIC REPUBLIC OF THE CONGO",
    "MONTENEGRO",
    "VANUATU",
    "UNITED STATES",
    "TÜRKİYE",
    "SEYCHELLES",
    "SERBIA",
    "CABO VERDE",
    "VENEZUELA (BOLIVARIAN REPUBLIC OF)",
    "KIRIBATI",
    "IRAN (ISLAMIC REPUBLIC OF)",
    "SOUTH SUDAN",
    "ALBANIA",
    "CZECHIA",
    "DOMINICA",
    "SAO TOME AND PRINCIPE",
    "ESWATINI",
    "CHAD",
    "EQUATORIAL GUINEA",
    "GAMBIA",
    "LIBYA",
    "CÔTE D'IVOIRE",
    "SAINT CHRISTOPHER AND NEVIS",
    "RWANDA",
    "TONGA",
    "NIGER",
    "MICRONESIA (FEDERATED STATES OF)",
    "SYRIAN ARAB REPUBLIC",
    "NAURU",
    "PALAU",
    "NORTH MACEDONIA",
    "NETHERLANDS",
    "BOLIVIA (PLURINATIONAL STATE OF)",
]

# Vote function
def predict_votes(resolution_text):
    """Predict a Yes / Not-Yes vote for every country known to the encoder.

    The text is embedded once with the module-level ``vectorizer``. Each
    country in ``country_encoder.classes_`` is then scored by
    ``problem_model`` when it is listed in ``problem_countries`` and by
    ``main_model`` otherwise.

    Args:
        resolution_text: Text of the resolution to score. ``None``, an empty
            string or whitespace-only text yields an empty result table
            instead of predictions for meaningless input.

    Returns:
        A ``pandas.DataFrame`` with columns ``"Country"`` and ``"Vote"``,
        sorted by country. ``"Vote"`` is ``"✅ Yes"`` when the sigmoid of the
        model output is above 0.5 and ``"❌ Not Yes"`` otherwise.
    """
    if not resolution_text or not resolution_text.strip():
        return pd.DataFrame(columns=["Country", "Vote"])

    # Vectorize once; the same embedding is reused for every country.
    vec = vectorizer.encode([resolution_text])
    x_tensor = torch.tensor(vec, dtype=torch.float32)

    countries = [str(name) for name in country_encoder.classes_]
    # One batched transform instead of one encoder call per loop iteration.
    country_ids = country_encoder.transform(countries)
    # Constant-time membership test; the module-level list would be scanned
    # once per country.
    problem_set = frozenset(problem_countries)

    votes = []
    with torch.no_grad():
        for country, country_id in zip(countries, country_ids):
            c_tensor = torch.tensor([int(country_id)], dtype=torch.long)

            # NOTE(review): problem_model is constructed with country_count=46
            # but is given ids produced by the full country encoder here --
            # confirm these ids are valid for that model.
            model = problem_model if country in problem_set else main_model

            logit = model(x_tensor, c_tensor).squeeze()
            prob = torch.sigmoid(logit).item()
            votes.append("✅ Yes" if prob > 0.5 else "❌ Not Yes")

    df = pd.DataFrame({
        "Country": countries,
        "Vote": votes
    }).sort_values("Country")

    return df

# Gradio UI: one multi-line text input, one table output, wired to predict_votes.
resolution_input = gr.Textbox(lines=15, label="Paste UN Resolution Text Here")
votes_output = gr.Dataframe(label="Predicted Votes by Country")

iface = gr.Interface(
    fn=predict_votes,
    inputs=resolution_input,
    outputs=votes_output,
    title="UN Resolution Vote Predictor",
    description="This model predicts how each UN country will vote on a given resolution based on the text. Uses BERT embeddings and two models: one for normal countries, one for chaos monkeys.",
    live=False,
)

# Start the app as soon as the module is executed.
iface.launch()