saharM commited on
Commit
c4081b3
·
1 Parent(s): dc44059

move to cpu

Browse files
Files changed (1) hide show
  1. app.py +2 -3
app.py CHANGED
@@ -1,6 +1,5 @@
1
  import streamlit as st
2
  import re
3
- import torch
4
  import contractions
5
  import pandas as pd
6
  from transformers import BartTokenizer, BartForConditionalGeneration
@@ -77,9 +76,8 @@ st.markdown("""
77
  """, unsafe_allow_html=True)
78
 
79
  # Load model and tokenizer
80
- device = torch.device("cpu")
81
  MODEL_PATH = "./models/fine-tuned_bart_base"
82
- model = BartForConditionalGeneration.from_pretrained(MODEL_PATH).to(device)
83
  tokenizer = BartTokenizer.from_pretrained(MODEL_PATH)
84
 
85
  #Helper functions
@@ -279,6 +277,7 @@ def summarize_text(txt):
279
  txt = preprocess_text(txt)
280
  txt = anonymize_speakers(txt, speaker_1, speaker_2)
281
  inputs = tokenizer(txt, return_tensors="pt")
 
282
  summary_ids = model.generate(**inputs)
283
  summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
284
  summary = deanonymize_speakers(summary, speaker_1, speaker_2)
 
1
  import streamlit as st
2
  import re
 
3
  import contractions
4
  import pandas as pd
5
  from transformers import BartTokenizer, BartForConditionalGeneration
 
76
  """, unsafe_allow_html=True)
77
 
78
  # Load model and tokenizer
 
79
  MODEL_PATH = "./models/fine-tuned_bart_base"
80
+ model = BartForConditionalGeneration.from_pretrained(MODEL_PATH).cpu()
81
  tokenizer = BartTokenizer.from_pretrained(MODEL_PATH)
82
 
83
  #Helper functions
 
277
  txt = preprocess_text(txt)
278
  txt = anonymize_speakers(txt, speaker_1, speaker_2)
279
  inputs = tokenizer(txt, return_tensors="pt")
280
+ inputs = {k: v.cpu() for k, v in inputs.items()}
281
  summary_ids = model.generate(**inputs)
282
  summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
283
  summary = deanonymize_speakers(summary, speaker_1, speaker_2)