Spaces:
Sleeping
Sleeping
File size: 5,852 Bytes
3736c33 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 | import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from typing import List, Dict, Optional
import config
class ModelInference:
    """Handle model loading and inference for text generation."""

    # System prompts keyed by use case; unknown use cases fall back to "explanation".
    _SYSTEM_PROMPTS = {
        "explanation": "You are an expert tutor. Provide detailed, clear explanations of concepts based on the given context.",
        "summary": "You are a summarization expert. Create concise, well-structured summaries of the provided content.",
        "qa": "You are a knowledgeable assistant. Answer questions accurately based on the given context.",
        "notes": "You are a study notes specialist. Create well-organized, structured study notes from the content.",
    }

    def __init__(self, model_name: Optional[str] = None, use_4bit: bool = True):
        """
        Initialize the model for inference.

        RAG Mode: Uses pre-trained model directly (no training needed!).

        Args:
            model_name: Name or path of the model. When None, config.MODEL_NAME
                is used if config.USE_PRETRAINED is set or config.MODEL_PATH does
                not exist on disk; otherwise config.MODEL_PATH is used.
            use_4bit: Whether to use 4-bit NF4 quantization. Only applied when
                CUDA is available.
        """
        # Local import keeps this block self-contained (Path is needed for the
        # fine-tuned-model existence check below).
        from pathlib import Path

        # Use pre-trained model if specified, otherwise check for fine-tuned model
        if config.USE_PRETRAINED or not Path(config.MODEL_PATH).exists():
            self.model_name = model_name or config.MODEL_NAME
        else:
            self.model_name = model_name or config.MODEL_PATH

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Loading model: {self.model_name}")
        print(f"Device: {self.device}")

        # NOTE(review): trust_remote_code=True runs Python code shipped with the
        # model repository; only point model_name at repositories you trust.
        if use_4bit and self.device == "cuda":
            # Configure quantization for efficiency
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                quantization_config=bnb_config,
                device_map="auto",
                trust_remote_code=True,
            )
        else:
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                device_map="auto" if self.device == "cuda" else None,
                trust_remote_code=True,
            )

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name,
            trust_remote_code=True,
        )
        # Some causal-LM tokenizers ship without a pad token; reuse EOS so that
        # generate() has a valid pad_token_id.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model.eval()

    def generate_response(
        self,
        prompt: str,
        context: str = "",
        use_case: str = "explanation",
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        max_input_tokens: int = 2048,
    ) -> str:
        """
        Generate a response based on the prompt and context.

        Args:
            prompt: User query
            context: Retrieved context from documents
            use_case: Type of response (explanation, summary, qa, notes);
                unknown values fall back to "explanation".
            temperature: Sampling temperature. None uses config.TEMPERATURE;
                a value <= 0 selects greedy (non-sampling) decoding.
            max_tokens: Maximum number of new tokens to generate. None uses
                config.MAX_TOKENS.
            max_input_tokens: Maximum number of prompt tokens fed to the model;
                longer prompts are truncated by the tokenizer.

        Returns:
            Generated text response (only the newly generated part, stripped).
        """
        # `is None` rather than `or` so an explicit 0 is honoured instead of
        # being silently replaced by the config default.
        if temperature is None:
            temperature = config.TEMPERATURE
        if max_tokens is None:
            max_tokens = config.MAX_TOKENS

        system_prompt = self._SYSTEM_PROMPTS.get(
            use_case, self._SYSTEM_PROMPTS["explanation"]
        )
        full_prompt = self._format_prompt(system_prompt, context, prompt)

        # NOTE(review): tokenizers truncate on the right by default, so a very
        # long context can cut off the trailing "Query: ... Response:" section.
        # Move inputs to wherever the model actually lives (device_map may have
        # placed it), rather than assuming the "cuda"/"cpu" string.
        inputs = self.tokenizer(
            full_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=max_input_tokens,
        ).to(self.model.device)
        prompt_length = inputs["input_ids"].shape[1]

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                repetition_penalty=1.1,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                **self._sampling_kwargs(temperature),
            )

        # Decode only the tokens produced after the prompt. Slicing the decoded
        # string by len(full_prompt) is unreliable: decoding does not always
        # reproduce the prompt text exactly, and truncation may have shortened it.
        new_tokens = outputs[0][prompt_length:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    @staticmethod
    def _sampling_kwargs(temperature: float) -> Dict[str, object]:
        """Return the sampling-related generate() arguments for *temperature*.

        A non-positive temperature selects greedy decoding, because sampling
        with temperature <= 0 is rejected by transformers.
        """
        if temperature <= 0:
            return {"do_sample": False}
        return {
            "do_sample": True,
            "temperature": temperature,
            "top_p": 0.95,
            "top_k": 50,
        }

    def _format_prompt(self, system_prompt: str, context: str, query: str) -> str:
        """Format the prompt with system instructions, optional context, and query.

        The context section is omitted entirely when *context* is empty.
        """
        prompt = f"{system_prompt}\n\n"
        if context:
            prompt += f"Context from your study materials:\n{context}\n\n"
        prompt += f"Query: {query}\n\nResponse:"
        return prompt

    def batch_generate(self, prompts: List[str], **kwargs) -> List[str]:
        """
        Generate responses for multiple prompts, one at a time.

        Args:
            prompts: List of prompts
            **kwargs: Additional arguments forwarded to generate_response

        Returns:
            List of generated responses, in the same order as *prompts*.
        """
        return [self.generate_response(prompt, **kwargs) for prompt in prompts]
|