# test_aqmodel.py — quick CLI check for a fine-tuned T5 answer-to-question model.
#!/usr/bin/python3
import sys
from transformers import T5Tokenizer, T5ForConditionalGeneration
# Require exactly one command-line argument: the answer text to build a
# question for. Exit non-zero when it is missing.
if len(sys.argv) < 2:
    # Use argv[0] so the usage line names the actual script — the original
    # hard-coded "conversation.py", which does not match this file's name.
    print(f"Usage: python3 {sys.argv[0]} '<your answer here>'")
    sys.exit(1)
# Directory holding the fine-tuned T5 checkpoint and its tokenizer files.
model_path = "./aq_model_b8"  # Make sure this points to your saved directory

# Restore both halves of the pipeline from the same saved directory; the
# two loads are independent of each other.
tokenizer = T5Tokenizer.from_pretrained(model_path)
model = T5ForConditionalGeneration.from_pretrained(model_path)
print("Model loaded successfully!")
def generate_question(answer, max_length=50):
    """Generate a question whose answer is *answer*.

    Args:
        answer: Free-form answer text to condition the model on.
        max_length: Maximum length passed to ``model.generate``; default 50
            preserves the original hard-coded behavior.

    Returns:
        The decoded question string with special tokens stripped.
    """
    # NOTE(review): this prefix presumably matches the prompt format used
    # during fine-tuning — confirm against the training script before changing.
    input_text = f"Generate a question for: {answer}"
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids
    output_ids = model.generate(input_ids, max_length=max_length)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
# Echo the generated question for the answer given on the command line.
answer_text = sys.argv[1]
print(generate_question(answer_text))