File size: 795 Bytes
260a941
33558f8
 
 
 
 
 
 
 
a8fa556
33558f8
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
#!/usr/bin/python3
"""Generate a question for a given answer using a fine-tuned T5 model."""
import sys

from transformers import T5Tokenizer, T5ForConditionalGeneration

if len(sys.argv) < 2:
    # Usage/diagnostic text belongs on stderr so stdout stays clean for
    # the generated question (the script exits non-zero here).
    print("Usage: python3 conversation.py '<your answer here>'", file=sys.stderr)
    sys.exit(1)

# Define model path
model_path = "./aq_model_b8"  # Make sure this points to your saved directory

# Load model and tokenizer from the local checkpoint directory.
# Both are module-level globals because generate_question() reads them.
model = T5ForConditionalGeneration.from_pretrained(model_path)
tokenizer = T5Tokenizer.from_pretrained(model_path)

print("Model loaded successfully!")

def generate_question(answer, max_length=50):
    """Generate a question whose answer is *answer*.

    Parameters
    ----------
    answer : str
        The answer text to condition generation on.
    max_length : int, optional
        Maximum total length, in tokens, of the generated sequence.
        Defaults to 50 (the original hard-coded value).

    Returns
    -------
    str
        The generated question, with special tokens stripped.
    """
    # Prompt prefix — presumably matches the format the model was
    # fine-tuned on; verify against the training script.
    input_text = "Generate a question for: " + answer
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids
    output_ids = model.generate(input_ids, max_length=max_length)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

print(generate_question(sys.argv[1]))