Spaces:
Runtime error
Runtime error
| import torch | |
| from torch import nn | |
| from torchvision import transforms | |
| from PIL import Image | |
| import gradio as gr | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
# =======================
# Settings
# =======================
MODEL_NAME = "erfanasghariyan/RADIOCAP200"

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Or torch.bfloat16 if the model is bf16.
DTYPE = torch.float32
# =======================
# Load the model and tokenizer
# =======================
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

# Move the weights to DEVICE and cast them to DTYPE in one step, so the model
# always matches the dtype that load_image() gives its input tensors. Without
# the cast, changing DTYPE would only affect the inputs and cause a dtype
# mismatch at inference time.
model.to(DEVICE, dtype=DTYPE)

# Inference only: switch the module to evaluation mode.
model.eval()
# =======================
# Image transform
# =======================
# Preprocessing steps, applied in order: resize to 224x224, convert to a
# tensor, then normalize each channel with the given mean / std.
# NOTE(review): the mean/std values match the commonly used ImageNet channel
# statistics; confirm they are what this checkpoint was trained with.
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(_preprocess_steps)
# =======================
# Image preprocessing function
# =======================
def load_image(img):
    """Turn an image path or PIL image into a batched input tensor.

    Args:
        img: Either a filesystem path (``str``) to an image file, or a
            ``PIL.Image.Image`` instance.

    Returns:
        The output of ``transform`` for the RGB version of the image, with a
        leading batch dimension added, moved to ``DEVICE`` and cast to
        ``DTYPE``.

    Raises:
        TypeError: If ``img`` is neither a ``str`` nor a ``PIL.Image.Image``.
    """
    if isinstance(img, str):
        # If the input is a file path, open it with PIL. convert() returns a
        # new image, so the underlying file can be closed as soon as the
        # conversion is done instead of lingering until garbage collection.
        with Image.open(img) as opened:
            rgb_image = opened.convert("RGB")
    elif isinstance(img, Image.Image):
        rgb_image = img.convert("RGB")
    else:
        raise TypeError(f"Unexpected type {type(img)}")

    # unsqueeze(0) adds the batch dimension in front of the transformed image.
    return transform(rgb_image).unsqueeze(0).to(DEVICE, dtype=DTYPE)
# =======================
# Prediction function
# =======================
def predict(img):
    """Generate a caption for a single image.

    Args:
        img: The image to caption, in any form ``load_image`` accepts, or
            ``None`` when the interface was submitted without an image.

    Returns:
        The decoded caption string for the first generated sequence, with
        special tokens stripped.

    Raises:
        gr.Error: If ``img`` is ``None``; the message is shown to the user
            in the Gradio UI instead of an opaque type error.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")

    img_tensor = load_image(img)

    # Gradients are not needed for generation.
    with torch.no_grad():
        # NOTE(review): the image tensor is passed positionally to generate();
        # whether this checkpoint accepts pixel input through
        # AutoModelForSeq2SeqLM is not visible from this file - confirm.
        output_ids = model.generate(img_tensor, max_length=128)

    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
# =======================
# Gradio interface
# =======================
# Single-function UI: one image input, one text output.
iface = gr.Interface(
    title="RADIOCAP200 - Radiology Captioning",
    description="Upload a radiology image and get a generated report/caption.",
    fn=predict,
    # The uploaded file's path is passed to the function.
    inputs=gr.Image(type="filepath"),
    outputs="text",
)

# Start the app only when this file is run as a script.
if __name__ == "__main__":
    iface.launch()