# models.py
import torch
import logging
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BlipForConditionalGeneration,
    BlipProcessor,
    CLIPModel,
    CLIPProcessor,
)
from peft import PeftModel, PeftConfig

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    filename='inference.log',
)
logger = logging.getLogger(__name__)

# Dummy OmniPhiModel class (replace with your actual implementation)
class OmniPhiModel(torch.nn.Module):
    def __init__(self, phi_model_name, clip_model_name):
        super().__init__()
        # The fine-tuned Phi model is attached after loading (see load_omniphi)
        self.phi = None
        self.phi_model_name = phi_model_name
        self.clip_model_name = clip_model_name

    def forward(self, text_input_ids, attention_mask, image_embedding):
        # Placeholder forward pass: returns zeros shaped like the input ids
        return torch.zeros_like(text_input_ids, dtype=torch.float32)
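
# A minimal smoke-test sketch for the placeholder forward pass; the tensor
# shapes below are illustrative assumptions, not taken from this repo:
#
#   ids = torch.ones(1, 16, dtype=torch.long)   # (batch, seq_len) token ids
#   mask = torch.ones_like(ids)                 # matching attention mask
#   img = torch.randn(1, 512)                   # e.g. a CLIP image embedding
#   out = OmniPhiModel("phi", "clip")(ids, mask, img)
#   assert out.shape == ids.shape               # zeros_like preserves shape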

# Model loading functions
def load_blip(blip_model_name, device, torch_dtype):
    try:
        blip_processor = BlipProcessor.from_pretrained(blip_model_name)
        # Use safetensors to avoid the torch.load pickle vulnerability
        blip_model = BlipForConditionalGeneration.from_pretrained(
            blip_model_name,
            torch_dtype=torch_dtype,
            use_safetensors=True,  # force safetensors loading
        ).to(device).eval()
        logger.info("BLIP model and processor loaded successfully with safetensors")
        return blip_model, blip_processor
    except Exception as e:
        logger.error(f"Failed to load BLIP: {e}")
        raise RuntimeError(f"BLIP failed to load: {e}") from e

def load_clip(clip_model_name, device, torch_dtype):
    try:
        clip_model = CLIPModel.from_pretrained(
            clip_model_name,
            torch_dtype=torch_dtype,  # honor the requested dtype, as in load_blip
            use_safetensors=True,  # also use safetensors for CLIP
        ).to(device).eval()
        clip_processor = CLIPProcessor.from_pretrained(clip_model_name, use_fast=True)
        logger.info("CLIP model and processor loaded successfully with safetensors")
        return clip_model, clip_processor
    except Exception as e:
        logger.error(f"Failed to load CLIP: {e}")
        raise RuntimeError(f"CLIP failed to load: {e}") from e

def load_omniphi(peft_model_id, phi_model_name, clip_model_name, device):
    try:
        config = PeftConfig.from_pretrained(peft_model_id)
        base_model = AutoModelForCausalLM.from_pretrained(
            config.base_model_name_or_path,
            return_dict=True,
            torch_dtype=torch.float32,
            device_map="cpu",
            trust_remote_code=False,
            use_safetensors=True,  # use safetensors for the OmniPhi base model
        )
        tokenizer = AutoTokenizer.from_pretrained(peft_model_id)
        # Register the multimodal markers and a dedicated pad token in one call;
        # the "pad_token" entry supersedes reusing eos_token as padding
        tokenizer.add_special_tokens({"additional_special_tokens": ["[IMG]", "[SEP]"], "pad_token": "<pad>"})
        # Grow the embedding matrix to cover the new tokens before attaching the adapter
        base_model.resize_token_embeddings(len(tokenizer))
        peft_model = PeftModel.from_pretrained(base_model, peft_model_id).to(device).eval()
        model = OmniPhiModel(phi_model_name=config.base_model_name_or_path, clip_model_name=clip_model_name).to(device)
        model.phi = peft_model
        logger.info("OmniPhi model and tokenizer loaded successfully with safetensors")
        return model, tokenizer
    except Exception as e:
        logger.error(f"Failed to load OmniPhi model: {e}")
        # Unlike the BLIP/CLIP loaders, this returns (None, None) on failure,
        # so callers must check the return values
        return None, None
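
if __name__ == "__main__":
    # Minimal wiring sketch for the three loaders. Every checkpoint name below
    # is an assumption for illustration; "your-org/omniphi-peft" in particular
    # is a hypothetical adapter id, not a real repo.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32

    blip_model, blip_processor = load_blip(
        "Salesforce/blip-image-captioning-base", device, dtype
    )
    clip_model, clip_processor = load_clip(
        "openai/clip-vit-base-patch32", device, dtype
    )
    omniphi_model, omniphi_tokenizer = load_omniphi(
        "your-org/omniphi-peft",           # hypothetical PEFT adapter id
        phi_model_name="microsoft/phi-2",  # assumed Phi base checkpoint
        clip_model_name="openai/clip-vit-base-patch32",
        device=device,
    )
    if omniphi_model is None:
        logger.error("OmniPhi failed to load; see inference.log for details")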