| |
| import torch |
| import logging |
| from transformers import BlipProcessor, BlipForConditionalGeneration, CLIPProcessor, CLIPModel |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
| from peft import PeftModel, PeftConfig |
|
|
| |
# Configure root logging once at import time; all records for this process go to inference.log.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', filename='inference.log')
logger = logging.getLogger(__name__)
|
|
| |
class OmniPhiModel(torch.nn.Module):
    """Container module pairing a lazily attached Phi LM with a CLIP model name.

    The ``phi`` attribute starts as ``None`` and is populated externally
    (see ``load_omniphi``). The forward pass is a placeholder: it ignores
    the attention mask and image embedding entirely.
    """

    def __init__(self, phi_model_name, clip_model_name):
        super().__init__()
        # Language model gets attached by the loader after construction.
        self.phi = None
        self.clip_model_name = clip_model_name

    def forward(self, text_input_ids, attention_mask, image_embedding):
        # Placeholder output: float32 zeros shaped like the token ids,
        # allocated on the same device as the input.
        return text_input_ids.new_zeros(text_input_ids.shape, dtype=torch.float32)
|
|
| |
def load_blip(blip_model_name, device, torch_dtype):
    """Load a BLIP captioning model and its processor.

    Args:
        blip_model_name: HF hub id or local path of the BLIP checkpoint.
        device: torch device the model is moved onto.
        torch_dtype: dtype used when materializing the model weights.

    Returns:
        Tuple of (blip_model, blip_processor) with the model in eval mode
        on ``device``.

    Raises:
        RuntimeError: if either the processor or the model fails to load;
            the original exception is chained as the cause.
    """
    try:
        blip_processor = BlipProcessor.from_pretrained(blip_model_name)
        blip_model = BlipForConditionalGeneration.from_pretrained(
            blip_model_name,
            torch_dtype=torch_dtype,
            use_safetensors=True,
        ).to(device).eval()
        logger.info("BLIP model and processor loaded successfully with safetensors")
        return blip_model, blip_processor
    except Exception as e:
        # logger.exception records the full traceback; `from e` preserves
        # the original failure as the cause of the re-raised error.
        logger.exception(f"Failed to load BLIP: {e}")
        raise RuntimeError(f"BLIP failed to load: {e}") from e
|
|
def load_clip(clip_model_name, device, torch_dtype):
    """Load a CLIP model and its (fast) processor.

    Args:
        clip_model_name: HF hub id or local path of the CLIP checkpoint.
        device: torch device the model is moved onto.
        torch_dtype: dtype used when materializing the model weights.

    Returns:
        Tuple of (clip_model, clip_processor) with the model in eval mode
        on ``device``.

    Raises:
        RuntimeError: if the model or processor fails to load; the original
            exception is chained as the cause.
    """
    try:
        clip_model = CLIPModel.from_pretrained(
            clip_model_name,
            # Previously the torch_dtype argument was accepted but never
            # forwarded, so CLIP always loaded in its default dtype;
            # pass it through for parity with load_blip.
            torch_dtype=torch_dtype,
            use_safetensors=True,
        ).to(device).eval()
        clip_processor = CLIPProcessor.from_pretrained(clip_model_name, use_fast=True)
        logger.info("CLIP model and processor loaded successfully with safetensors")
        return clip_model, clip_processor
    except Exception as e:
        # logger.exception records the full traceback; `from e` preserves
        # the original failure as the cause of the re-raised error.
        logger.exception(f"Failed to load CLIP: {e}")
        raise RuntimeError(f"CLIP failed to load: {e}") from e
|
|
def load_omniphi(peft_model_id, phi_model_name, clip_model_name, device):
    """Load the OmniPhi wrapper: a PEFT-adapted causal LM plus its tokenizer.

    Args:
        peft_model_id: HF hub id or local path of the PEFT adapter; its config
            supplies the base model path (note: ``phi_model_name`` itself is
            only recorded on the wrapper, the base model comes from the
            adapter config).
        phi_model_name: name passed through to ``OmniPhiModel`` — presumably
            informational only; verify against callers.
        clip_model_name: CLIP model name stored on the wrapper.
        device: torch device the models are moved onto.

    Returns:
        Tuple of (model, tokenizer), or (None, None) on any failure —
        callers are expected to check (unlike load_blip/load_clip, which
        raise).
    """
    try:
        config = PeftConfig.from_pretrained(peft_model_id)
        # Base weights are materialized on CPU first, then moved to `device`
        # after the adapter is attached.
        base_model = AutoModelForCausalLM.from_pretrained(
            config.base_model_name_or_path,
            return_dict=True,
            torch_dtype=torch.float32,
            device_map="cpu",
            trust_remote_code=False,
            use_safetensors=True,
        )
        peft_model = PeftModel.from_pretrained(base_model, peft_model_id).to(device).eval()
        tokenizer = AutoTokenizer.from_pretrained(peft_model_id)
        # "<pad>" is registered here as the pad token, so no eos fallback is
        # needed beforehand (the old `pad_token = eos_token` assignment was
        # dead — immediately overwritten by this call).
        tokenizer.add_special_tokens({"additional_special_tokens": ["[IMG]", "[SEP]"], "pad_token": "<pad>"})
        # The added special tokens grow the vocabulary; without this resize
        # the new ids ([IMG], [SEP], <pad>) index past the embedding matrix.
        # NOTE(review): assumes the LoRA adapter does not target the embedding
        # layer — confirm against the adapter config.
        base_model.resize_token_embeddings(len(tokenizer))
        model = OmniPhiModel(phi_model_name=config.base_model_name_or_path, clip_model_name=clip_model_name).to(device)
        model.phi = peft_model
        logger.info("OmniPhi model and tokenizer loaded successfully with safetensors")
        return model, tokenizer
    except Exception as e:
        # Best-effort contract: signal failure with (None, None) rather than
        # raising; logger.exception keeps the traceback in the log.
        logger.exception(f"Failed to load OmniPhi model: {e}")
        return None, None