MindVR commited on
Commit
77464a5
·
verified ·
1 Parent(s): fad09d8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -2
app.py CHANGED
@@ -11,7 +11,11 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
11
  # Summarization model
12
  summarizer_model_id = "facebook/bart-large-cnn"
13
  summarizer_tokenizer = SummarizerTokenizer.from_pretrained(summarizer_model_id)
14
- summarizer_model = AutoModelForSeq2SeqLM.from_pretrained(summarizer_model_id)
 
 
 
 
15
  summarizer_model.to(device)
16
 
17
  def summarize_text(text: str, max_length=150) -> str:
@@ -33,6 +37,7 @@ model_id = "MindVR/JohnTran_Fine-tune"
33
  tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
34
  model = AutoModelForCausalLM.from_pretrained(
35
  model_id,
 
36
  device_map="auto",
37
  low_cpu_mem_usage=True,
38
  token=HF_TOKEN
@@ -61,7 +66,7 @@ def chat(
61
  with torch.no_grad():
62
  output = model.generate(
63
  input_ids,
64
- max_new_tokens=1000,
65
  do_sample=True,
66
  top_p=0.95,
67
  temperature=0.7,
 
11
  # Summarization model
12
  summarizer_model_id = "facebook/bart-large-cnn"
13
  summarizer_tokenizer = SummarizerTokenizer.from_pretrained(summarizer_model_id)
14
+ summarizer_model = AutoModelForSeq2SeqLM.from_pretrained(
15
+ summarizer_model_id,
16
+ torch_dtype=torch.float16,
17
+ device_map="auto"
18
+ )
19
  summarizer_model.to(device)
20
 
21
  def summarize_text(text: str, max_length=150) -> str:
 
37
  tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
38
  model = AutoModelForCausalLM.from_pretrained(
39
  model_id,
40
+ torch_dtype=torch.float16,
41
  device_map="auto",
42
  low_cpu_mem_usage=True,
43
  token=HF_TOKEN
 
66
  with torch.no_grad():
67
  output = model.generate(
68
  input_ids,
69
+ max_new_tokens=256,
70
  do_sample=True,
71
  top_p=0.95,
72
  temperature=0.7,