donsek commited on
Commit
0de37c8
·
verified ·
1 Parent(s): 9282ee1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -5
app.py CHANGED
@@ -58,16 +58,21 @@ problem_countries = [
58
  # === Prediction Function ===
59
  def predict_votes(resolution_text):
60
  vec = vectorizer.encode(resolution_text)
61
- x_tensor = torch.tensor(vec, dtype=torch.float32).unsqueeze(0) # batchify
62
 
63
  countries = []
64
  votes = []
65
 
66
  for country in country_encoder.classes_:
67
- country_id = country_encoder.transform([country])[0]
68
- c_tensor = torch.tensor([country_id], dtype=torch.long)
69
-
70
- model = problem_model if country in problem_countries else main_model
 
 
 
 
 
71
 
72
  with torch.no_grad():
73
  logit = model(x_tensor, c_tensor).squeeze()
 
58
  # === Prediction Function ===
59
  def predict_votes(resolution_text):
60
  vec = vectorizer.encode(resolution_text)
61
+ x_tensor = torch.tensor(vec, dtype=torch.float32).unsqueeze(0)
62
 
63
  countries = []
64
  votes = []
65
 
66
  for country in country_encoder.classes_:
67
+ is_problem = country in problem_countries
68
+ model = problem_model if is_problem else main_model
69
+
70
+ if is_problem:
71
+ problem_index = problem_countries.index(country) # 0–45
72
+ c_tensor = torch.tensor([problem_index], dtype=torch.long)
73
+ else:
74
+ country_id = country_encoder.transform([country])[0] # 0–192
75
+ c_tensor = torch.tensor([country_id], dtype=torch.long)
76
 
77
  with torch.no_grad():
78
  logit = model(x_tensor, c_tensor).squeeze()