# RADIOCAP200 / app.py
# Last commit by hackergeek: c6845ce "Update app.py" (verified)
import os

import gradio as gr
import torch
from PIL import Image
from torch import nn
from torchvision import transforms
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# =======================
# Settings
# =======================
MODEL_NAME = "erfanasghariyan/RADIOCAP200"

# Run on the GPU whenever PyTorch reports one is available.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Dtype the input image tensor is cast to. torch.bfloat16 is an alternative
# if the model is bf16 — the model's own weights must then be in the same dtype.
DTYPE = torch.float32
# =======================
# Load model and tokenizer
# =======================
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
# Cast the weights to DTYPE as well as moving them to DEVICE: the input image
# tensor is converted to DTYPE, so the weights must match. Otherwise a
# non-float32 DTYPE (e.g. bfloat16) would typically fail with a dtype mismatch
# inside generate().
model.to(DEVICE, dtype=DTYPE)
# Inference only: put layers such as dropout into evaluation behaviour.
model.eval()
# =======================
# Image transform
# =======================
# Resize to 224x224, convert the PIL image to a float tensor, then normalize
# each channel with the ImageNet mean/std used by torchvision's pretrained
# models.
# NOTE(review): confirm this matches the preprocessing the checkpoint was
# trained with.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# =======================
# Image loading / preprocessing
# =======================
def load_image(img):
    """Load an image and turn it into a single-item batch tensor for the model.

    Args:
        img: A filesystem path (``str`` or ``os.PathLike``), a PIL image, or an
            array-like object exposing ``__array_interface__`` (e.g. a NumPy
            array).

    Returns:
        The image, processed as follows:
        - converted to RGB;
        - passed through ``transform``;
        - given a leading batch dimension of size 1;
        - moved to ``DEVICE`` and cast to ``DTYPE``.

    Raises:
        TypeError: If ``img`` is none of the supported types.
    """
    if isinstance(img, (str, os.PathLike)):
        # Open inside a context manager so the file handle is released.
        # convert() returns a new, fully loaded image, so it stays valid after
        # the file is closed.
        with Image.open(img) as opened:
            img = opened.convert("RGB")
    elif isinstance(img, Image.Image):
        img = img.convert("RGB")
    elif hasattr(img, "__array_interface__"):
        # Checked after Image.Image, which also exposes __array_interface__.
        img = Image.fromarray(img).convert("RGB")
    else:
        raise TypeError(f"Unexpected type {type(img)}")
    return transform(img).unsqueeze(0).to(DEVICE, dtype=DTYPE)
# =======================
# Prediction
# =======================
def predict(img, max_length=128):
    """Generate a caption/report for one radiology image.

    Args:
        img: Anything ``load_image`` accepts. The Gradio interface passes a
            file path.
        max_length: Maximum generation length forwarded to ``model.generate``.
            It defaults to 128, the value previously hard-coded.

    Returns:
        The decoded text of the first (only) generated sequence, with special
        tokens removed.

    Raises:
        gr.Error: If no image was supplied (``img is None``), e.g. when the form
            is submitted empty. Gradio shows this message to the user.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")
    img_tensor = load_image(img)
    with torch.no_grad():
        # NOTE(review): the image tensor is passed as generate()'s first
        # positional input; confirm the checkpoint expects pixel values there.
        output_ids = model.generate(img_tensor, max_length=max_length)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
# =======================
# Gradio interface
# =======================
iface = gr.Interface(
    title="RADIOCAP200 - Radiology Captioning",
    description="Upload a radiology image and get a generated report/caption.",
    fn=predict,
    # With type="filepath", the uploaded image reaches predict as a file path.
    inputs=gr.Image(type="filepath"),
    outputs="text",
)

# Start the web UI only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()