hackergeek commited on
Commit
c6845ce
·
verified ·
1 Parent(s): 8bbc742

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -61
app.py CHANGED
@@ -3,82 +3,68 @@ from torch import nn
3
  from torchvision import transforms
4
  from PIL import Image
5
  import gradio as gr
6
- from transformers import AutoTokenizer
7
 
8
- # ===========================
9
- # تنظیمات دستگاه و dtype
10
- # ===========================
11
- DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
- DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
 
13
 
14
- # ===========================
15
- # مسیر مدل و tokenizer
16
- # ===========================
17
- CHECKPOINT_PATH = "checkpoints/epoch_04/model.pt" # مسیر دانلود شده در Space
18
- TOKENIZER_NAME = "bert-base-uncased" # یا مدل tokenizer مناسب شما
19
-
20
- tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
21
-
22
- # ===========================
23
- # تعریف مدل (مثال ساده)
24
- # ===========================
25
- # توجه: مدل واقعی خودت را اینجا قرار بده
26
- class DummyCaptionModel(nn.Module):
27
- def __init__(self):
28
- super().__init__()
29
- self.dummy = nn.Linear(10, 10)
30
-
31
- def forward(self, x, question=None):
32
- # خروجی فرضی
33
- if question:
34
- return "Answer to question: " + question
35
- return "Generated caption for the image"
36
-
37
- model = DummyCaptionModel()
38
- if torch.cuda.is_available():
39
- model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=DEVICE))
40
  model.to(DEVICE)
41
  model.eval()
42
 
43
- # ===========================
44
- # Transform تصویر
45
- # ===========================
46
  transform = transforms.Compose([
47
  transforms.Resize((224, 224)),
48
  transforms.ToTensor(),
49
- # transforms.Normalize(mean=[0.485, 0.456, 0.406],
50
- # std=[0.229, 0.224, 0.225])
51
  ])
52
 
53
- # ===========================
54
- # تابع بارگذاری تصویر
55
- # ===========================
56
- def load_image(img: Image.Image):
57
- """تبدیل PIL image به Tensor"""
58
- return transform(img).unsqueeze(0).to(DEVICE, dtype=DTYPE)
 
 
 
 
 
 
 
 
59
 
60
- # ===========================
61
- # تابع اصلی پیش‌بینی
62
- # ===========================
63
- def predict(img: Image.Image, question: str = ""):
64
  img_tensor = load_image(img)
65
- # اگر سوال خالی بود کپشن تولید کن، وگرنه VQA
66
- output_text = model(img_tensor, question.strip() or None)
67
- return output_text
 
68
 
69
- # ===========================
70
- # Interface گریدیو
71
- # ===========================
72
  iface = gr.Interface(
73
  fn=predict,
74
- inputs=[
75
- gr.Image(type="pil", label="Upload Radiology Image"),
76
- gr.Textbox(label="Optional Question (for VQA)", placeholder="Ask a question or leave empty for caption")
77
- ],
78
- outputs=gr.Textbox(label="Output"),
79
- title="RADIOCAP200: Radiology Caption + VQA",
80
- description="Upload a radiology image and optionally ask a question. If the question is empty, model generates a caption. Otherwise, it answers the question."
81
  )
82
 
83
  if __name__ == "__main__":
84
- iface.launch(server_name="0.0.0.0", server_port=7860, share=True)
 
3
  from torchvision import transforms
4
  from PIL import Image
5
  import gradio as gr
6
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
7
 
8
+ # =======================
9
+ # تنظیمات
10
+ # =======================
11
+ MODEL_NAME = "erfanasghariyan/RADIOCAP200"
12
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
13
+ DTYPE = torch.float32 # یا torch.bfloat16 اگر مدل bf16 است
14
 
15
+ # =======================
16
+ # بارگذاری مدل و توکنایزر
17
+ # =======================
18
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
19
+ model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  model.to(DEVICE)
21
  model.eval()
22
 
23
+ # =======================
24
+ # ترنسفورم تصویر
25
+ # =======================
26
  transform = transforms.Compose([
27
  transforms.Resize((224, 224)),
28
  transforms.ToTensor(),
29
+ transforms.Normalize([0.485, 0.456, 0.406],
30
+ [0.229, 0.224, 0.225])
31
  ])
32
 
33
+ # =======================
34
+ # تابع پردازش تصویر
35
+ # =======================
36
+ def load_image(img):
37
+ # اگر ورودی مسیر فایل بود، با PIL باز کن
38
+ if isinstance(img, str):
39
+ img = Image.open(img).convert("RGB")
40
+ elif isinstance(img, Image.Image):
41
+ img = img.convert("RGB")
42
+ else:
43
+ raise TypeError(f"Unexpected type {type(img)}")
44
+
45
+ img_tensor = transform(img).unsqueeze(0).to(DEVICE, dtype=DTYPE)
46
+ return img_tensor
47
 
48
+ # =======================
49
+ # تابع پیش‌بینی
50
+ # =======================
51
+ def predict(img):
52
  img_tensor = load_image(img)
53
+ with torch.no_grad():
54
+ output_ids = model.generate(img_tensor, max_length=128)
55
+ caption = tokenizer.decode(output_ids[0], skip_special_tokens=True)
56
+ return caption
57
 
58
+ # =======================
59
+ # رابط Gradio
60
+ # =======================
61
  iface = gr.Interface(
62
  fn=predict,
63
+ inputs=gr.Image(type="filepath"), # مسیر فایل به تابع داده می‌شود
64
+ outputs="text",
65
+ title="RADIOCAP200 - Radiology Captioning",
66
+ description="Upload a radiology image and get a generated report/caption."
 
 
 
67
  )
68
 
69
  if __name__ == "__main__":
70
+ iface.launch()