import os
import torch
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
from PIL import Image
import gradio as gr
# -----------------------------------------------------------------------------
# Load HF token from environment
# -----------------------------------------------------------------------------
HF_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
if not HF_TOKEN:
    raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable not set")
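# (On Hugging Face Spaces, HUGGINGFACEHUB_API_TOKEN is typically configured as
# a repository secret in the Space's settings rather than hard-coded.)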

# -----------------------------------------------------------------------------
# 1) GPU inference function
# -----------------------------------------------------------------------------
def run_inference_on_gpu(
    model_id: str,
    image: Image.Image,
    prompt: str = "caption",
    max_new_tokens: int = 100,
) -> str:
    # ensure CUDA is available
    assert torch.cuda.is_available(), "CUDA not available; check your PyTorch installation!"
    device = torch.device("cuda")
    dtype = torch.float16
    # load processor + model onto the GPU; `token=` replaces the deprecated
    # `use_auth_token=` argument in recent transformers releases
    processor = AutoProcessor.from_pretrained(model_id, token=HF_TOKEN)
    model = PaliGemmaForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=dtype,
        device_map=None,
        token=HF_TOKEN,
    ).to(device).eval()
    # build the multimodal prompt; the <image> placeholder must precede the text
    image_tokens = "<image>"
    multimodal_prompt = f"{image_tokens} {prompt}"
    # prepare inputs
    inputs = processor(
        text=multimodal_prompt,
        images=[image],
        padding="longest",
        return_tensors="pt",
        do_convert_rgb=True,
    )
    # move tensors to the GPU, casting the floating-point ones (pixel_values)
    # to fp16 so they match the model's dtype
    inputs = {
        k: v.to(device, dtype) if v.dtype.is_floating_point else v.to(device)
        for k, v in inputs.items()
    }
    # generate deterministically with beam search
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            num_beams=3,
            do_sample=False,
        )
    # decode only the newly generated tokens; outputs[0] also contains the
    # prompt, which would otherwise be echoed back in the caption
    input_len = inputs["input_ids"].shape[-1]
    return processor.decode(outputs[0][input_len:], skip_special_tokens=True).strip()
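
# -----------------------------------------------------------------------------
# Optional: cached loading (a sketch, not wired into the app above).
# run_inference_on_gpu reloads the processor and the 3B model on every request;
# caching them once per process avoids that cost. `lru_cache` is standard
# library; the helper name is illustrative, not from the original.
# -----------------------------------------------------------------------------
from functools import lru_cache

@lru_cache(maxsize=1)
def load_model_cached(model_id: str):
    processor = AutoProcessor.from_pretrained(model_id, token=HF_TOKEN)
    model = PaliGemmaForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        token=HF_TOKEN,
    ).to("cuda").eval()
    return processor, model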

# -----------------------------------------------------------------------------
# 2) Gradio UI
# -----------------------------------------------------------------------------
MODEL_ID = "mychen76/paligemma-3b-mix-448-med_30k-ct-brain"

def caption_fn(image, prompt, max_tokens):
    """
    Gradio callback: takes a PIL image, a text prompt, and
    max tokens → returns the generated caption.
    """
    return run_inference_on_gpu(
        model_id=MODEL_ID,
        image=image,
        prompt=prompt,
        # the slider hands back a number; cast to int for generate()
        max_new_tokens=int(max_tokens),
    )

demo = gr.Interface(
    fn=caption_fn,
    inputs=[
        gr.Image(type="pil", label="Upload CT Scan"),
        gr.Textbox(
            value="What do you see in this CT scan?",
            label="Prompt",
        ),
        gr.Slider(
            minimum=10, maximum=300, step=10, value=100,
            label="Max New Tokens",
        ),
    ],
    outputs=gr.Textbox(label="Model Caption"),
    title="PaliGemma CT-Scan Captioning",
    description=(
        "Upload a brain CT scan (or any image), write a short prompt, "
        "and let the PaliGemma model describe what it sees."
    ),
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch(share=False, show_api=False)
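
# NOTE: a plausible requirements.txt for this Space (version pins are
# assumptions, not from the original): torch, transformers>=4.41 (the first
# release with PaliGemma support), gradio, Pillow.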