rg089 commited on
Commit
a1a9d9b
·
1 Parent(s): f84e8cd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -7
app.py CHANGED
@@ -43,15 +43,15 @@ def generate(model, tokenizer, test_samples, prefix="", max_length=256):
43
  output_str = tokenizer.batch_decode(outputs, skip_special_tokens=True)
44
  return output_str[0]
45
 
46
- def classify(model, tokenizer, content, title):
47
  model.eval()
48
  with torch.no_grad():
49
- model_inputs = tokenizer(title, content, padding=True, truncation=True, return_tensors="pt").to(device)
50
- outputs = model(**model_inputs)
51
- logits = outputs.logits
52
- selected = logits.argmax(dim=-1).cpu().tolist()
53
- answers = [rev_mapper[sel] for sel in selected]
54
- return answers[0]
55
 
56
  def main(content, classify_source=False):
57
  output = ""
 
43
  output_str = tokenizer.batch_decode(outputs, skip_special_tokens=True)
44
  return output_str[0]
45
 
46
+ def classify(model, tokenizer, content, title):
47
  model.eval()
48
  with torch.no_grad():
49
+ model_inputs = tokenizer(title, content, padding=True, truncation=True, return_tensors="pt").to(device)
50
+ outputs = model(**model_inputs)
51
+ logits = outputs.logits
52
+ selected = logits.argmax(dim=-1).cpu().tolist()
53
+ answers = [rev_mapper[sel] for sel in selected]
54
+ return answers[0]
55
 
56
  def main(content, classify_source=False):
57
  output = ""