"""Load the MedGemma tokenizer and model for a Gradio chat app.

``google/medgemma-2b-it`` is a gated checkpoint on the Hugging Face Hub:
access must be granted to the account whose token is supplied, either via
the ``HF_TOKEN`` environment variable or a prior ``huggingface-cli login``
(in which case ``HF_TOKEN`` may be unset and ``token=None`` falls back to
the cached credentials).
"""
import os

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Gated model id; from_pretrained() raises an authentication error if the
# token (or cached login) has no access grant for it.
MODEL_ID = "google/medgemma-2b-it"

# May be None — transformers then falls back to locally cached credentials.
HF_TOKEN = os.getenv("HF_TOKEN")

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN,
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    # NOTE(review): float32 is the safe default for CPU-only hosts (e.g. free
    # Spaces); on a GPU, torch.bfloat16 would halve memory — confirm target
    # hardware before changing.
    torch_dtype=torch.float32,
    # Let accelerate place the weights on whatever device(s) are available.
    device_map="auto",
    token=HF_TOKEN,
)