# Source: MedSpace/src/generation/llm_wrapper.py (Hugging Face upload by kbsss via huggingface_hub, commit f373e2b, 7.87 kB)
"""
LLM wrapper for medical question answering.
"""
import torch
from typing import List, Optional, Dict, Any
from dataclasses import dataclass
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import gc
# Exception hierarchy: callers can catch one specific failure mode, or every
# wrapper failure at once through the shared LLMError base.
class LLMError(Exception):
    """Common base class for the errors defined by the LLM wrapper."""
class ModelNotFoundError(LLMError):
    """Signals that a model or tokenizer could not be located or downloaded."""
class GPUOutOfMemoryError(LLMError):
    """Signals that the GPU ran out of memory (e.g. while loading model weights)."""
# NOTE(review): MedicalLLM.generate in this file does not raise this; it is
# presumably reserved for callers or future use — confirm before removing.
class GenerationError(LLMError):
    """Signals that text generation failed."""
@dataclass
class GenerationResult:
    """Outcome of one generation call.

    Attributes:
        response: Decoded completion text.
        input_tokens: Number of prompt tokens fed to the model.
        generated_tokens: Number of newly generated tokens.
        probabilities: Optional list with one value per generation step;
            ``None`` unless probabilities were requested.
    """

    response: str
    input_tokens: int
    generated_tokens: int
    probabilities: Optional[List[float]] = None
class MedicalLLM:
    """
    Medical domain LLM wrapper around a Hugging Face causal language model.

    Supports: BioMistral, TinyLlama, Mistral, Llama. Short aliases listed in
    ``SUPPORTED_MODELS`` are resolved to Hub repository ids; any other
    ``model_name`` is passed to ``from_pretrained`` unchanged.
    """

    SUPPORTED_MODELS = {
        "biomistral": "BioMistral/BioMistral-7B",
        "tinyllama": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        "mistral": "mistralai/Mistral-7B-Instruct-v0.2",
    }

    def __init__(
        self,
        model_name: str = "tinyllama",
        device: Optional[str] = None,
        load_in_4bit: bool = True,
        max_memory: Optional[Dict] = None,
        adapter_path: Optional[str] = None
    ):
        """Load the tokenizer, the model and (optionally) a PEFT adapter.

        Args:
            model_name: Alias from ``SUPPORTED_MODELS`` or a Hub id / local path.
            device: Target device such as "cuda", "cuda:1" or "cpu". Defaults
                to "cuda" when available, otherwise "cpu".
            load_in_4bit: Load with 4-bit NF4 quantization (CUDA devices only).
            max_memory: Per-device memory budget forwarded to
                ``from_pretrained`` whenever a ``device_map`` is used.
            adapter_path: Optional PEFT adapter applied on top of the base model.

        Raises:
            ModelNotFoundError: The tokenizer or model cannot be found.
            GPUOutOfMemoryError: CUDA ran out of memory while loading.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.adapter_path = adapter_path
        model_path = self.SUPPORTED_MODELS.get(model_name, model_name)
        self.model_name = model_name

        print(f"🔄 Loading LLM: {model_path} on {self.device}")

        quantization_config = self._build_quantization_config(load_in_4bit)
        self.tokenizer = self._load_tokenizer(model_path, adapter_path)
        self.model = self._load_model(
            model_path, quantization_config, max_memory, adapter_path
        )
        print("✅ Model loaded successfully")

    # ------------------------------------------------------------------ #
    # Loading helpers
    # ------------------------------------------------------------------ #
    def _is_cuda_device(self) -> bool:
        """Return True for "cuda" as well as indexed devices like "cuda:1"."""
        return str(self.device).startswith("cuda")

    @staticmethod
    def _is_not_found_error(error: OSError) -> bool:
        """Heuristically recognise a "model does not exist" OSError by its message."""
        message = str(error).lower()
        return "not found" in message or "does not appear" in message

    def _build_quantization_config(self, load_in_4bit: bool):
        """Return a 4-bit NF4 ``BitsAndBytesConfig``, or None to load unquantized.

        Quantization is only attempted on CUDA devices. If the ``bitsandbytes``
        package is not installed we fall back to an unquantized load, rather
        than letting ``from_pretrained`` fail on the missing dependency.
        """
        if not (load_in_4bit and self._is_cuda_device()):
            return None
        import importlib.util  # only needed on the quantization path

        if importlib.util.find_spec("bitsandbytes") is None:
            print("⚠️ BitsAndBytes not available, loading without quantization")
            return None
        try:
            return BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            )
        except Exception:
            # Config construction may itself reject an unusable bitsandbytes install.
            print("⚠️ BitsAndBytes not available, loading without quantization")
            return None

    @classmethod
    def _load_tokenizer(cls, model_path: str, adapter_path: Optional[str]):
        """Load the tokenizer, preferring the adapter directory when it ships one.

        Falls back to the base model's tokenizer when ``adapter_path`` has no
        usable tokenizer files. Ensures a pad token exists, reusing EOS,
        because ``generate`` passes ``pad_token_id`` to the model.

        Raises:
            ModelNotFoundError: The base model's tokenizer cannot be found.
        """
        tokenizer = None
        if adapter_path:
            try:
                tokenizer = AutoTokenizer.from_pretrained(
                    adapter_path, trust_remote_code=True
                )
            except OSError:
                print(f"⚠️ No tokenizer at {adapter_path}, using base model tokenizer")

        if tokenizer is None:
            try:
                tokenizer = AutoTokenizer.from_pretrained(
                    model_path, trust_remote_code=True
                )
            except OSError as e:
                if cls._is_not_found_error(e):
                    raise ModelNotFoundError(
                        f"Model '{model_path}' not found. Check the model name or internet connection."
                    ) from e
                raise

        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        return tokenizer

    def _load_model(
        self,
        model_path: str,
        quantization_config,
        max_memory: Optional[Dict],
        adapter_path: Optional[str],
    ):
        """Load the base model, attach the optional adapter and switch to eval mode.

        Placement: plain "cuda" uses ``device_map="auto"``. An indexed CUDA
        device such as "cuda:1" maps the whole model onto that device. Any
        other device loads without a device map and then moves the model
        explicitly with ``.to(device)``.
        """
        is_cuda = self._is_cuda_device()
        if self.device == "cuda":
            device_map = "auto"
        elif is_cuda:
            device_map = {"": self.device}
        else:
            device_map = None

        load_kwargs: Dict[str, Any] = {
            "device_map": device_map,
            "trust_remote_code": True,
        }
        if quantization_config is not None:
            load_kwargs["quantization_config"] = quantization_config
        else:
            load_kwargs["torch_dtype"] = torch.float16 if is_cuda else torch.float32
        if max_memory is not None and device_map is not None:
            # max_memory only influences accelerate's placement, so it is
            # meaningless without a device_map.
            load_kwargs["max_memory"] = max_memory

        try:
            model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)
            if device_map is None:
                model = model.to(self.device)
            if adapter_path:
                model = self._attach_adapter(model, adapter_path)
            model.eval()
        except torch.cuda.OutOfMemoryError as e:
            raise GPUOutOfMemoryError(
                f"GPU out of memory loading '{model_path}'. "
                "Try: 1) Using load_in_4bit=True, 2) Using a smaller model like 'tinyllama', "
                "3) Reducing batch size, 4) Using device='cpu'"
            ) from e
        except OSError as e:
            if self._is_not_found_error(e):
                raise ModelNotFoundError(f"Model '{model_path}' not found: {e}") from e
            raise
        except Exception as e:
            print(f"❌ Failed to load model: {e}")
            raise
        return model

    @staticmethod
    def _attach_adapter(model, adapter_path: str):
        """Wrap ``model`` with a PEFT adapter. This is best-effort.

        If PEFT is missing or the adapter fails to load, a warning is printed
        and the base model is returned unchanged.
        """
        try:
            from peft import PeftModel
        except ImportError:
            print("⚠️ PEFT not installed, using base model without adapter")
            return model
        print(f"🔄 Loading PEFT adapter from {adapter_path}")
        try:
            adapted = PeftModel.from_pretrained(model, adapter_path)
        except Exception as e:
            print(f"⚠️ Failed to load adapter: {e}, using base model")
            return model
        print("✅ Adapter loaded successfully")
        return adapted

    # ------------------------------------------------------------------ #
    # Generation
    # ------------------------------------------------------------------ #
    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        do_sample: bool = True,
        return_probabilities: bool = False,
        max_input_length: int = 2048,
    ) -> GenerationResult:
        """Generate a response from the LLM.

        Args:
            prompt: Raw prompt text. No chat template is applied.
            max_new_tokens: Upper bound on newly generated tokens.
            temperature: Sampling temperature. A value <= 0 combined with
                ``do_sample=True`` falls back to greedy decoding.
            top_p: Nucleus-sampling threshold. Ignored when decoding greedily.
            do_sample: Sample when True; use greedy decoding otherwise.
            return_probabilities: Also return, for each generated token, the
                probability that the processed next-token distribution gave
                to the token actually emitted.
            max_input_length: The prompt is truncated to this many tokens.

        Returns:
            GenerationResult with the stripped response text and token counts.
        """
        # NOTE(review): truncation keeps the leading tokens under the
        # tokenizer's default truncation side (presumably "right"). An
        # over-long RAG prompt therefore loses its trailing Question/Answer
        # lines. Consider trimming the context before building the prompt.
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=max_input_length,
        ).to(self.model.device)
        input_length = inputs.input_ids.shape[1]

        # transformers rejects a non-positive temperature when sampling, so
        # treat it as a request for deterministic (greedy) decoding.
        if do_sample and temperature <= 0:
            do_sample = False

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature if do_sample else 1.0,
                top_p=top_p if do_sample else 1.0,
                do_sample=do_sample,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                output_scores=return_probabilities,
                return_dict_in_generate=True,
            )

        # The returned sequence starts with the prompt; keep only the new tokens.
        generated_ids = outputs.sequences[0][input_length:]
        response = self.tokenizer.decode(generated_ids, skip_special_tokens=True)

        probabilities = None
        scores = getattr(outputs, "scores", None)
        if return_probabilities and scores is not None:
            probabilities = self._emitted_token_probabilities(scores, generated_ids)

        return GenerationResult(
            response=response.strip(),
            input_tokens=input_length,
            generated_tokens=len(generated_ids),
            probabilities=probabilities,
        )

    @staticmethod
    def _emitted_token_probabilities(scores, generated_ids) -> List[float]:
        """Return the probability of each emitted token at its generation step.

        ``scores`` holds one processed-logits tensor per step; row 0 is used,
        matching the single-prompt batch built in ``generate``. Under greedy
        decoding this equals the per-step maximum probability. When sampling,
        it reflects the token that was actually chosen.
        """
        probabilities = []
        for step_scores, token_id in zip(scores, generated_ids.tolist()):
            step_probs = torch.softmax(step_scores[0].float(), dim=-1)
            probabilities.append(step_probs[token_id].item())
        return probabilities

    def generate_with_context(
        self,
        question: str,
        context: str,
        max_new_tokens: int = 512,
        **generate_kwargs: Any,
    ) -> GenerationResult:
        """Generate a response grounded in ``context`` (for RAG).

        Extra keyword arguments (e.g. ``temperature``, ``do_sample``,
        ``return_probabilities``) are forwarded to ``generate``.
        """
        prompt = (
            "You are a helpful medical assistant. Use the following context to answer the question. If you're unsure, say so.\n"
            "Context:\n"
            f"{context}\n"
            f"Question: {question}\n"
            "Answer:"
        )
        return self.generate(prompt, max_new_tokens=max_new_tokens, **generate_kwargs)

    def cleanup(self):
        """Free the model and tokenizer and release cached GPU memory.

        Safe to call more than once. Afterwards the instance can no longer
        generate.
        """
        for attr in ("model", "tokenizer"):
            if hasattr(self, attr):
                delattr(self, attr)
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()