Update modeling_ostlm.py
Browse files- modeling_ostlm.py +0 -24
modeling_ostlm.py
CHANGED
|
@@ -105,27 +105,3 @@ class OSTLMModel(PreTrainedModel, GenerationMixin):
|
|
| 105 |
# Expose the custom "ostlm" architecture to the transformers Auto* factories,
# so AutoConfig / AutoModelForSeq2SeqLM can resolve it from a model_type string.
AutoConfig.register("ostlm", OSTLMConfig)
AutoModelForSeq2SeqLM.register(OSTLMConfig, OSTLMModel)
|
| 107 |
|
| 108 |
-
def translate(text, model_path="./ostlm_v1_final"):
    """Translate *text* with the OSTLM seq2seq model stored at *model_path*.

    Loads the tokenizer and model fresh on every call (simple, but slow if
    called repeatedly), runs beam-search generation on GPU when available,
    and returns the decoded output string.

    Args:
        text: Source string to translate.
        model_path: Directory containing the saved tokenizer and model.

    Returns:
        The generated translation with special tokens stripped.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
    model.to("cuda" if torch.cuda.is_available() else "cpu").eval()

    inputs = tokenizer(text, return_tensors="pt", padding=True).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            # Fix: padding=True above produces a mask; not forwarding it lets
            # the encoder attend to pad tokens and can corrupt the output.
            attention_mask=inputs.get("attention_mask"),
            max_length=4,
            num_beams=3,  # narrower beam search for more stable output
            decoder_start_token_id=tokenizer.cls_token_id,
            eos_token_id=tokenizer.sep_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 125 |
-
|
| 126 |
-
# Smoke-test the pipeline at import/run time; print the full traceback on
# failure instead of crashing, so the registration above still completes.
print("--- 馃 OSTLM v1: Structure Fixed ---")
try:
    print(f"Result: {translate('Who?')}")
except Exception:  # fix: the bound name `e` was never used
    import traceback

    traceback.print_exc()
|
|
|
|
| 105 |
# Register the "ostlm" model type with the Auto* machinery so that
# from_pretrained calls can dispatch to OSTLMConfig / OSTLMModel.
AutoConfig.register("ostlm", OSTLMConfig)
AutoModelForSeq2SeqLM.register(OSTLMConfig, OSTLMModel)
|
| 107 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|