Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -43,15 +43,15 @@ def generate(model, tokenizer, test_samples, prefix="", max_length=256):
|
|
| 43 |
output_str = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
| 44 |
return output_str[0]
|
| 45 |
|
| 46 |
-
|
| 47 |
model.eval()
|
| 48 |
with torch.no_grad():
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
|
| 56 |
def main(content, classify_source=False):
|
| 57 |
output = ""
|
|
|
|
| 43 |
output_str = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
| 44 |
return output_str[0]
|
| 45 |
|
| 46 |
+
def classify(model, tokenizer, content, title):
    """Classify a (title, content) pair with a sequence-classification model.

    Runs the model in eval mode under ``torch.no_grad()`` and maps the
    argmax logit index back to a label via the module-level ``rev_mapper``
    table. Tensors are moved to the module-level ``device``.

    Returns the predicted label for the single input pair.
    """
    model.eval()
    with torch.no_grad():
        # Encode the pair (title, content) and move every tensor onto
        # the same device as the model.
        encoded = tokenizer(
            title,
            content,
            padding=True,
            truncation=True,
            return_tensors="pt",
        ).to(device)
        result = model(**encoded)
        # Highest-scoring class index per example, pulled back to CPU.
        indices = result.logits.argmax(dim=-1).cpu().tolist()
        labels = [rev_mapper[idx] for idx in indices]
    # Single input pair -> single prediction.
    return labels[0]
|
| 55 |
|
| 56 |
def main(content, classify_source=False):
|
| 57 |
output = ""
|