import torch
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from typing import List, Dict, Optional

import config


class ModelInference:
    """Handle model loading and inference for text generation."""

    def __init__(self, model_name: Optional[str] = None, use_4bit: bool = True):
        """
        Initialize the model for inference.

        RAG Mode: Uses the pre-trained model directly (no training needed!).

        Args:
            model_name: Name or path of the model (uses pre-trained by default)
            use_4bit: Whether to use 4-bit quantization for efficiency
        """
        # Use the pre-trained model if specified, otherwise check for a fine-tuned model
        if config.USE_PRETRAINED or not Path(config.MODEL_PATH).exists():
            self.model_name = model_name or config.MODEL_NAME
        else:
            self.model_name = model_name or config.MODEL_PATH

        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        print(f"Loading model: {self.model_name}")
        print(f"Device: {self.device}")

        # Configure 4-bit (NF4) quantization to reduce GPU memory usage
        if use_4bit and self.device == "cuda":
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                quantization_config=bnb_config,
                device_map="auto",
                trust_remote_code=True,
            )
        else:
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                device_map="auto" if self.device == "cuda" else None,
                trust_remote_code=True,
            )

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name,
            trust_remote_code=True,
        )
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model.eval()

    def generate_response(
        self,
        prompt: str,
        context: str = "",
        use_case: str = "explanation",
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
    ) -> str:
        """
        Generate a response based on the prompt and context.

        Args:
            prompt: User query
            context: Retrieved context from documents
            use_case: Type of response (explanation, summary, qa, notes)
            temperature: Sampling temperature
            max_tokens: Maximum number of tokens to generate

        Returns:
            Generated text response
        """
        temperature = temperature or config.TEMPERATURE
        max_tokens = max_tokens or config.MAX_TOKENS

        # Select a system prompt based on the use case
        system_prompts = {
            "explanation": "You are an expert tutor. Provide detailed, clear explanations of concepts based on the given context.",
            "summary": "You are a summarization expert. Create concise, well-structured summaries of the provided content.",
            "qa": "You are a knowledgeable assistant. Answer questions accurately based on the given context.",
            "notes": "You are a study notes specialist. Create well-organized, structured study notes from the content.",
        }
        system_prompt = system_prompts.get(use_case, system_prompts["explanation"])

        # Format the full prompt
        full_prompt = self._format_prompt(system_prompt, context, prompt)

        # Tokenize
        inputs = self.tokenizer(
            full_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=2048,
        ).to(self.device)

        # Generate
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=temperature,
                do_sample=True,
                top_p=0.95,
                top_k=50,
                repetition_penalty=1.1,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens. Slicing the decoded string by
        # len(full_prompt) is unreliable because detokenization does not always
        # reproduce the prompt text exactly.
        new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        response = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        return response

    def _format_prompt(self, system_prompt: str, context: str, query: str) -> str:
        """Format the prompt with system instructions, context, and query."""
        prompt = f"{system_prompt}\n\n"
        if context:
            prompt += f"Context from your study materials:\n{context}\n\n"
        prompt += f"Query: {query}\n\nResponse:"
        return prompt

    def batch_generate(self, prompts: List[str], **kwargs) -> List[str]:
        """
        Generate responses for multiple prompts.

        Args:
            prompts: List of prompts
            **kwargs: Additional arguments for generate_response

        Returns:
            List of generated responses
        """
        responses = []
        for prompt in prompts:
            response = self.generate_response(prompt, **kwargs)
            responses.append(response)
        return responses
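

# Example usage (illustrative sketch only): assumes config defines MODEL_NAME,
# MODEL_PATH, USE_PRETRAINED, TEMPERATURE, and MAX_TOKENS, and that a suitable
# causal LM checkpoint is available locally or on the Hugging Face Hub. The
# prompt and context strings below are hypothetical.
if __name__ == "__main__":
    inference = ModelInference(use_4bit=True)
    answer = inference.generate_response(
        prompt="What is gradient descent?",
        context="Gradient descent iteratively updates parameters in the direction of the negative gradient.",
        use_case="explanation",
    )
    print(answer)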