Spaces:
Runtime error
Runtime error
kz209
commited on
Commit
·
203771e
1
Parent(s):
68c64e4
update
Browse files- utils/model.py +5 -4
utils/model.py
CHANGED
|
@@ -23,16 +23,17 @@ class Model(torch.nn.Module):
|
|
| 23 |
|
| 24 |
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 25 |
self.name = model_name
|
|
|
|
| 26 |
logging.info(f'start loading model {self.name}')
|
| 27 |
self.pipeline = transformers.pipeline(
|
| 28 |
-
"summarization",
|
| 29 |
model=model_name,
|
| 30 |
tokenizer=self.tokenizer,
|
| 31 |
torch_dtype=torch.bfloat16,
|
| 32 |
device_map="auto",
|
| 33 |
)
|
| 34 |
-
|
| 35 |
logging.info(f'Loaded model {self.name}')
|
|
|
|
| 36 |
self.update()
|
| 37 |
|
| 38 |
@classmethod
|
|
@@ -58,6 +59,7 @@ class Model(torch.nn.Module):
|
|
| 58 |
num_return_sequences=1,
|
| 59 |
eos_token_id=self.tokenizer.eos_token_id,
|
| 60 |
)
|
|
|
|
| 61 |
else:
|
| 62 |
sequences = self.pipeline(
|
| 63 |
content,
|
|
@@ -68,5 +70,4 @@ class Model(torch.nn.Module):
|
|
| 68 |
eos_token_id=self.tokenizer.eos_token_id,
|
| 69 |
return_full_text=False
|
| 70 |
)
|
| 71 |
-
|
| 72 |
-
return sequences[-1]['summary_text']
|
|
|
|
| 23 |
|
| 24 |
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 25 |
self.name = model_name
|
| 26 |
+
|
| 27 |
logging.info(f'start loading model {self.name}')
|
| 28 |
self.pipeline = transformers.pipeline(
|
| 29 |
+
"summarization" if model_name=="google-t5/t5-large" else "text-generation",
|
| 30 |
model=model_name,
|
| 31 |
tokenizer=self.tokenizer,
|
| 32 |
torch_dtype=torch.bfloat16,
|
| 33 |
device_map="auto",
|
| 34 |
)
|
|
|
|
| 35 |
logging.info(f'Loaded model {self.name}')
|
| 36 |
+
|
| 37 |
self.update()
|
| 38 |
|
| 39 |
@classmethod
|
|
|
|
| 59 |
num_return_sequences=1,
|
| 60 |
eos_token_id=self.tokenizer.eos_token_id,
|
| 61 |
)
|
| 62 |
+
return sequences[-1]['summary_text']
|
| 63 |
else:
|
| 64 |
sequences = self.pipeline(
|
| 65 |
content,
|
|
|
|
| 70 |
eos_token_id=self.tokenizer.eos_token_id,
|
| 71 |
return_full_text=False
|
| 72 |
)
|
| 73 |
+
return sequences[-1]['generated_text']
|
|
|