"""Gradio demo that generates a caption/report for an uploaded radiology image.

The script loads the checkpoint named by ``MODEL_NAME`` once at import time,
preprocesses each uploaded image with a fixed torchvision pipeline, and decodes
the generated token ids into text.
"""

from typing import Optional, Union

import gradio as gr
import torch
from PIL import Image
from torch import nn
from torchvision import transforms
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# =======================
# Settings
# =======================
MODEL_NAME = "erfanasghariyan/RADIOCAP200"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float32  # or torch.bfloat16 if the model is bf16
MAX_CAPTION_LENGTH = 128  # upper bound passed to generate() as max_length

# =======================
# Load model and tokenizer
# =======================
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
# Cast the weights to DTYPE as well, so the model and the input tensor built in
# load_image() always agree when DTYPE is changed away from float32.
model.to(DEVICE, dtype=DTYPE)
model.eval()

# =======================
# Image transform
# =======================
# NOTE(review): the mean/std values match the commonly used ImageNet channel
# statistics; confirm they are what this checkpoint was trained with.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)


# =======================
# Image preprocessing
# =======================
def load_image(img: Union[str, Image.Image]) -> torch.Tensor:
    """Convert a file path or PIL image into a batched, normalized tensor.

    Args:
        img: Path to an image file, or an already opened ``PIL.Image.Image``.

    Returns:
        The output of ``transform`` with a leading batch dimension of size 1,
        moved to ``DEVICE`` and cast to ``DTYPE``.

    Raises:
        TypeError: If ``img`` is neither a ``str`` nor a ``PIL.Image.Image``.
    """
    if isinstance(img, str):
        # Image.open keeps the file handle open until the image is closed;
        # convert() yields an independent image, so the handle can be
        # released as soon as the conversion is done.
        with Image.open(img) as opened:
            rgb_image = opened.convert("RGB")
    elif isinstance(img, Image.Image):
        rgb_image = img.convert("RGB")
    else:
        raise TypeError(f"Unexpected type {type(img)}")

    batch = transform(rgb_image).unsqueeze(0)
    return batch.to(DEVICE, dtype=DTYPE)


# =======================
# Prediction
# =======================
def predict(img: Optional[Union[str, Image.Image]]) -> str:
    """Generate a caption for one image.

    Args:
        img: File path (what the Gradio input below supplies) or PIL image.
            ``None`` is what Gradio passes when nothing was uploaded.

    Returns:
        The decoded caption with special tokens stripped.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable
            message instead of a raw ``TypeError``.
        TypeError: Propagated from ``load_image`` for unsupported input types.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")

    img_tensor = load_image(img)
    with torch.no_grad():
        # NOTE(review): the tensor is passed positionally as generate()'s main
        # input; this assumes the checkpoint accepts image tensors there.
        output_ids = model.generate(img_tensor, max_length=MAX_CAPTION_LENGTH)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)


# =======================
# Gradio interface
# =======================
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="filepath"),  # the function receives a file path
    outputs="text",
    title="RADIOCAP200 - Radiology Captioning",
    description="Upload a radiology image and get a generated report/caption.",
)

if __name__ == "__main__":
    iface.launch()