File size: 5,852 Bytes
3736c33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
from pathlib import Path
from typing import List, Dict, Optional

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

import config


class ModelInference:
    """Load a causal language model once and generate text from prompts.

    The model and tokenizer are created in ``__init__`` and reused by
    ``generate_response`` and ``batch_generate``.
    """

    # Instruction prefix for each supported use case.
    _SYSTEM_PROMPTS = {
        "explanation": "You are an expert tutor. Provide detailed, clear explanations of concepts based on the given context.",
        "summary": "You are a summarization expert. Create concise, well-structured summaries of the provided content.",
        "qa": "You are a knowledgeable assistant. Answer questions accurately based on the given context.",
        "notes": "You are a study notes specialist. Create well-organized, structured study notes from the content.",
    }
    # Use case applied when the requested one is not in _SYSTEM_PROMPTS.
    _DEFAULT_USE_CASE = "explanation"
    # Upper bound on prompt tokens handed to the model.
    _MAX_PROMPT_TOKENS = 2048

    def __init__(self, model_name: Optional[str] = None, use_4bit: bool = True):
        """Load the model and tokenizer.

        An explicit ``model_name`` always wins. Otherwise the pre-trained
        ``config.MODEL_NAME`` is used when ``config.USE_PRETRAINED`` is set
        or ``config.MODEL_PATH`` does not exist on disk; in the remaining
        case the model at ``config.MODEL_PATH`` is loaded.

        Args:
            model_name: Name or path of the model to load.
            use_4bit: Whether to use 4-bit quantization. Only applied when
                CUDA is available.
        """
        if model_name:
            self.model_name = model_name
        elif config.USE_PRETRAINED or not Path(config.MODEL_PATH).exists():
            self.model_name = config.MODEL_NAME
        else:
            self.model_name = config.MODEL_PATH

        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        print(f"Loading model: {self.model_name}")
        print(f"Device: {self.device}")

        # NOTE(review): trust_remote_code=True lets the model repository run
        # its own Python code at load time; only point this at trusted models.
        load_kwargs = {"trust_remote_code": True}
        if self.device == "cuda":
            load_kwargs["device_map"] = "auto"
            if use_4bit:
                load_kwargs["quantization_config"] = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_compute_dtype=torch.float16,
                    bnb_4bit_use_double_quant=True,
                )
            else:
                load_kwargs["torch_dtype"] = torch.float16
        else:
            # Quantization is skipped on CPU; load in full precision.
            load_kwargs["torch_dtype"] = torch.float32
            load_kwargs["device_map"] = None

        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            **load_kwargs,
        )

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name,
            trust_remote_code=True,
        )

        # generate() is given a pad token id below, so fall back to EOS when
        # the tokenizer does not define one.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model.eval()

    @staticmethod
    def _sampling_kwargs(temperature: float) -> dict:
        """Return the decoding-strategy keyword arguments for ``generate``.

        A positive temperature selects sampling with the fixed ``top_p`` /
        ``top_k`` settings. A zero or negative temperature selects greedy
        decoding, because sampling requires a strictly positive temperature.
        """
        if temperature > 0:
            return {
                "do_sample": True,
                "temperature": temperature,
                "top_p": 0.95,
                "top_k": 50,
            }
        return {"do_sample": False}

    def generate_response(
        self,
        prompt: str,
        context: str = "",
        use_case: str = "explanation",
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
    ) -> str:
        """Generate a response based on the prompt and context.

        Args:
            prompt: User query.
            context: Retrieved context from documents; omitted from the
                prompt when empty.
            use_case: Type of response (explanation, summary, qa, notes).
                Unknown values fall back to "explanation".
            temperature: Sampling temperature. ``None`` uses
                ``config.TEMPERATURE``; a value of 0 or below selects greedy
                decoding.
            max_tokens: Maximum number of new tokens to generate. ``None``
                or 0 uses ``config.MAX_TOKENS``.

        Returns:
            The newly generated text, without the prompt, stripped of
            surrounding whitespace.
        """
        # Compare against None so an explicit 0.0 is honoured, not replaced.
        if temperature is None:
            temperature = config.TEMPERATURE
        max_tokens = max_tokens or config.MAX_TOKENS

        system_prompt = self._SYSTEM_PROMPTS.get(
            use_case, self._SYSTEM_PROMPTS[self._DEFAULT_USE_CASE]
        )
        full_prompt = self._format_prompt(system_prompt, context, prompt)

        # NOTE(review): truncation presumably drops tokens from the end of
        # the prompt (tokenizer default), so a very long context could cut
        # off the trailing query - confirm tokenizer.truncation_side.
        inputs = self.tokenizer(
            full_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=self._MAX_PROMPT_TOKENS,
        ).to(self.device)
        prompt_length = inputs["input_ids"].shape[-1]

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                repetition_penalty=1.1,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                **self._sampling_kwargs(temperature),
            )

        # Drop the prompt by token count rather than by character count of
        # the decoded string: the decoded text is not guaranteed to match
        # full_prompt exactly (truncation, tokenizer normalisation).
        generated_ids = outputs[0][prompt_length:]
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=True
        ).strip()

    def _format_prompt(self, system_prompt: str, context: str, query: str) -> str:
        """Format the prompt with system instructions, context, and query.

        The context section is included only when ``context`` is non-empty.
        """
        sections = [f"{system_prompt}\n\n"]
        if context:
            sections.append(f"Context from your study materials:\n{context}\n\n")
        sections.append(f"Query: {query}\n\nResponse:")
        return "".join(sections)

    def batch_generate(self, prompts: List[str], **kwargs) -> List[str]:
        """Generate responses for multiple prompts, one at a time.

        Args:
            prompts: List of prompts.
            **kwargs: Additional arguments forwarded to
                ``generate_response``.

        Returns:
            List of generated responses, in the same order as ``prompts``.
        """
        return [self.generate_response(prompt, **kwargs) for prompt in prompts]