Spaces:
Runtime error
Runtime error
import os

import PIL.Image  # provides the `PIL.Image.Image` type used in PaliGemmaModel.infer's annotation
import spaces
import torch
from huggingface_hub import login
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
# Hugging Face access token, read from the environment (e.g. a Space secret).
hf_token = os.getenv("HF_TOKEN")
# Authenticate only when a token is actually configured: login(token=None)
# falls back to prompting the user interactively, which cannot be answered
# in a non-interactive deployment and would stall or fail at import time.
if hf_token:
    login(token=hf_token, add_to_git_credential=True)
class PaliGemmaModel:
    """Load the PaliGemma ``google/paligemma-3b-mix-448`` checkpoint and run prompts on it.

    The model and its processor are loaded once in ``__init__``; ``infer``
    performs greedy (non-sampled) generation for one image/prompt pair.
    """

    def __init__(self):
        self.model_id = "google/paligemma-3b-mix-448"
        # Use the GPU when one is visible, otherwise fall back to the CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # eval() switches the model to inference mode before moving it to the device.
        self.model = (
            PaliGemmaForConditionalGeneration.from_pretrained(self.model_id)
            .eval()
            .to(self.device)
        )
        self.processor = PaliGemmaProcessor.from_pretrained(self.model_id)

    # The image annotation is quoted so it is not evaluated when the method is
    # defined; an unquoted ``PIL.Image.Image`` raises NameError unless ``PIL``
    # has been imported into this module.
    def infer(self, image: "PIL.Image.Image", text: str, max_new_tokens: int) -> str:
        """Generate text for ``image`` conditioned on the prompt ``text``.

        Args:
            image: The input image handed to the processor.
            text: The text prompt handed to the processor.
            max_new_tokens: Upper bound on the number of tokens to generate.

        Returns:
            The decoded continuation only (the prompt is excluded), with any
            leading newlines removed.
        """
        inputs = self.processor(text=text, images=image, return_tensors="pt").to(self.device)
        # generate() returns prompt tokens followed by the new tokens, so remember
        # how many prompt tokens there are and drop them by token position.
        # Slicing the decoded string by len(text) is wrong whenever the prompt
        # does not decode back to exactly `text` (e.g. a "<image>" prefix is
        # removed by skip_special_tokens).
        prompt_len = inputs["input_ids"].shape[-1]
        with torch.inference_mode():
            generated_ids = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,  # greedy decoding
            )
        new_token_ids = generated_ids[:, prompt_len:]
        decoded = self.processor.batch_decode(new_token_ids, skip_special_tokens=True)
        return decoded[0].lstrip("\n")