Update app.py
Browse files
app.py
CHANGED
|
@@ -58,16 +58,21 @@ problem_countries = [
|
|
| 58 |
# === Prediction Function ===
|
| 59 |
def predict_votes(resolution_text):
|
| 60 |
vec = vectorizer.encode(resolution_text)
|
| 61 |
-
x_tensor = torch.tensor(vec, dtype=torch.float32).unsqueeze(0)
|
| 62 |
|
| 63 |
countries = []
|
| 64 |
votes = []
|
| 65 |
|
| 66 |
for country in country_encoder.classes_:
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
|
| 72 |
with torch.no_grad():
|
| 73 |
logit = model(x_tensor, c_tensor).squeeze()
|
|
|
|
| 58 |
# === Prediction Function ===
|
| 59 |
def predict_votes(resolution_text):
|
| 60 |
vec = vectorizer.encode(resolution_text)
|
| 61 |
+
x_tensor = torch.tensor(vec, dtype=torch.float32).unsqueeze(0)
|
| 62 |
|
| 63 |
countries = []
|
| 64 |
votes = []
|
| 65 |
|
| 66 |
for country in country_encoder.classes_:
|
| 67 |
+
is_problem = country in problem_countries
|
| 68 |
+
model = problem_model if is_problem else main_model
|
| 69 |
+
|
| 70 |
+
if is_problem:
|
| 71 |
+
problem_index = problem_countries.index(country) # 0–45
|
| 72 |
+
c_tensor = torch.tensor([problem_index], dtype=torch.long)
|
| 73 |
+
else:
|
| 74 |
+
country_id = country_encoder.transform([country])[0] # 0–192
|
| 75 |
+
c_tensor = torch.tensor([country_id], dtype=torch.long)
|
| 76 |
|
| 77 |
with torch.no_grad():
|
| 78 |
logit = model(x_tensor, c_tensor).squeeze()
|