rg089 commited on
Commit
6650c1a
·
1 Parent(s): a1a9d9b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -8
app.py CHANGED
@@ -46,12 +46,12 @@ def generate(model, tokenizer, test_samples, prefix="", max_length=256):
46
  def classify(model, tokenizer, content, title):
47
  model.eval()
48
  with torch.no_grad():
49
- model_inputs = tokenizer(title, content, padding=True, truncation=True, return_tensors="pt").to(device)
50
- outputs = model(**model_inputs)
51
- logits = outputs.logits
52
- selected = logits.argmax(dim=-1).cpu().tolist()
53
- answers = [rev_mapper[sel] for sel in selected]
54
- return answers[0]
55
 
56
  def main(content, classify_source=False):
57
  output = ""
@@ -72,6 +72,6 @@ The current sources supported for classification are: The Times of India, The In
72
  placeholder = "Enter the content of the article here."
73
 
74
  iface = gradio.Interface(fn=main, inputs=[gradio.inputs.Textbox(lines=10, placeholder=placeholder, label='Article Content:'),
75
- gradio.inputs.Checkbox(default=True, label='Classify the Source:')], outputs="textbox", title=title,
76
- description=description)
77
  iface.launch()
 
46
  def classify(model, tokenizer, content, title):
47
  model.eval()
48
  with torch.no_grad():
49
+ model_inputs = tokenizer(title, content, padding=True, truncation=True, return_tensors="pt").to(device)
50
+ outputs = model(**model_inputs)
51
+ logits = outputs.logits
52
+ selected = logits.argmax(dim=-1).cpu().tolist()
53
+ answers = [rev_mapper[sel] for sel in selected]
54
+ return answers[0]
55
 
56
  def main(content, classify_source=False):
57
  output = ""
 
72
  placeholder = "Enter the content of the article here."
73
 
74
  iface = gradio.Interface(fn=main, inputs=[gradio.inputs.Textbox(lines=10, placeholder=placeholder, label='Article Content:'),
75
+ gradio.inputs.Checkbox(default=True, label='Classify the Source:')], outputs="textbox", title=title,
76
+ description=description)
77
  iface.launch()