import sys

from transformers import T5Tokenizer, T5ForConditionalGeneration

# Require the answer text as a single command-line argument.
if len(sys.argv) < 2:
    print("Usage: python3 conversation.py '<your answer here>'")
    sys.exit(1)

# Directory holding the fine-tuned answer-to-question T5 checkpoint.
model_path = "./aq_model_b8"

# Load the model and its matching tokenizer from the local checkpoint.
model = T5ForConditionalGeneration.from_pretrained(model_path)
tokenizer = T5Tokenizer.from_pretrained(model_path)

print("Model loaded successfully!")


def generate_question(answer):
    # Prefix the answer with the task instruction used as the prompt.
    input_text = "Generate a question for: " + answer
    # Tokenize the prompt into a PyTorch tensor of token IDs.
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids
    # Generate the question, capped at 50 tokens (greedy decoding by default).
    output_ids = model.generate(input_ids, max_length=50)
    # Decode the generated IDs back to text, dropping special tokens.
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
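
# Beam search is a possible variant that may yield better-formed questions.
# The kwargs below are standard transformers generate() options, not
# settings taken from this project:
#   output_ids = model.generate(input_ids, max_length=50, num_beams=4,
#                               early_stopping=True)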


print(generate_question(sys.argv[1]))
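
# Example run (output is illustrative; the actual question depends on the
# fine-tuned weights in ./aq_model_b8):
#   $ python3 conversation.py "Paris is the capital of France."
#   Model loaded successfully!
#   What is the capital of France?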