File size: 2,446 Bytes
300f197
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
import torch

class LocalQuestionGenerator:
    def __init__(self):
        print("Loading local question generation model...")
        self.device = 0 if torch.cuda.is_available() else -1  # Use GPU if available
        
        # Load the model and tokenizer
        model_name = "valhalla/t5-base-qa-qg-hl"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        
        # Initialize the pipeline
        self.generator = pipeline(
            "text2text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            device=self.device
        )
        print("Model loaded successfully!")
    
    def generate_questions(self, text, num_questions=5, max_length=64):
        """Generate questions from the given text."""
        if not text.strip():
            return []
            
        try:
            # Prepare the input text
            input_text = f"generate questions: {text}"
            
            # Generate questions
            results = self.generator(
                input_text,
                max_length=max_length,
                num_return_sequences=num_questions,
                num_beams=5,
                early_stopping=True
            )
            
            # Extract and clean the generated questions
            questions = [result['generated_text'].strip() for result in results]
            return questions
            
        except Exception as e:
            print(f"Error generating questions: {str(e)}")
            return []

# Example usage
if __name__ == "__main__":
    # Initialize the generator
    qg = LocalQuestionGenerator()
    
    # Sample text
    sample_text = """
    Machine learning is a branch of artificial intelligence that focuses on building systems 
    that learn from data. These systems can improve their performance over time without being 
    explicitly programmed. There are three main types of machine learning: supervised learning, 
    unsupervised learning, and reinforcement learning.
    """
    
    # Generate questions
    print("\nGenerating questions...")
    questions = qg.generate_questions(sample_text, num_questions=3)
    
    # Print the results
    print("\nGenerated Questions:")
    for i, q in enumerate(questions, 1):
        print(f"{i}. {q}")