hackergeek commited on
Commit
fdad5ba
·
verified ·
1 Parent(s): 92bf796

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -34
app.py CHANGED
@@ -1,61 +1,65 @@
1
  import gradio as gr
2
- from transformers import AutoImageProcessor, AutoModelForImageClassification
3
  import torch
4
  from PIL import Image
5
- import numpy as np
6
 
7
- # بارگذاری مدل و processor یک بار موقع شروع
8
- MODEL_ID = "erfanansaghariyan/mobilew-v11-convnext-tiny-6layer-radiimagen"
9
 
10
- processor = AutoImageProcessor.from_pretrained(MODEL_ID)
11
- model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
12
  model.eval()
13
 
14
- # اگر GPU داشتی فعالش کن
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
  model.to(device)
17
 
18
- def predict(image: Image.Image):
19
- # پیش‌پردازش تصویر
20
- inputs = processor(images=image, return_tensors="pt")
21
- inputs = {k: v.to(device) for k, v in inputs.items()}
22
 
23
- # inference
24
  with torch.no_grad():
25
- outputs = model(**inputs)
 
 
 
 
 
 
 
 
26
 
27
- logits = outputs.logits
28
- probabilities = torch.softmax(logits, dim=-1).cpu().numpy()[0]
29
 
30
- # گرفتن top-k پیش‌بینی‌ها
31
- top_k = 5
32
- top_indices = np.argsort(probabilities)[-top_k:][::-1]
33
- labels = [model.config.id2label[idx] for idx in top_indices]
34
- probs = [float(probabilities[idx]) for idx in top_indices]
35
 
36
- # خروجی به صورت دیکشنری برای Gradio
37
- result = {label: prob for label, prob in zip(labels, probs)}
38
-
39
- return result, image # تصویر + احتمالات
40
 
41
- # رابط کاربری Gradio
42
  demo = gr.Interface(
43
- fn=predict,
44
- inputs=gr.Image(type="pil", label="تصویر رادیولوژی آپلود کن"),
 
 
 
45
  outputs=[
46
- gr.Label(label="احتمالات کلاس‌ها (Top 5)"),
47
  gr.Image(label="تصویر ورودی")
48
  ],
49
- title="MobileWev-v11 ConvNeXt Tiny - Radiology Image Analysis",
50
  description=(
51
- "این مدل سبک (ConvNeXt-tiny با 6 لایه) برای تحلیل تصاویر پزشکی/رادیولوژی فاین‌تیون شده است.\n"
52
- "تصویر را آپلود کنید تا احتمالات کلاس‌های تشخیص را ببینید."
 
 
53
  ),
54
  examples=[
55
- # می‌تونی چند تا مثال تصویر بذاری (آپلود کن توی Space)
56
- # مثلاً: ["example_chest_xray.jpg"]
57
  ],
58
- cache_examples=False,
59
  )
60
 
61
  if __name__ == "__main__":
 
1
  import gradio as gr
2
+ from transformers import AutoProcessor, AutoModelForCausalLM # یا AutoModelForVision2Seq اگر vision2seq باشه
3
  import torch
4
  from PIL import Image
 
5
 
6
+ # repo_id دقیق
7
+ MODEL_ID = "erfanasghariyan/mobilew-v11-convnext-tiny-6layer-radimagenet"
8
 
9
+ processor = AutoProcessor.from_pretrained(MODEL_ID)
10
+ model = AutoModelForCausalLM.from_pretrained(MODEL_ID) # اگر causalLM نباشه، عوض کن به AutoModelForVision2Seq
11
  model.eval()
12
 
 
13
  device = "cuda" if torch.cuda.is_available() else "cpu"
14
  model.to(device)
15
 
16
+ def generate_report(image: Image.Image, prompt: str = "Describe this radiology image in detail:"):
17
+ # پردازش تصویر + متن prompt
18
+ inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)
 
19
 
20
+ # generation
21
  with torch.no_grad():
22
+ outputs = model.generate(
23
+ **inputs,
24
+ max_new_tokens=200,
25
+ do_sample=True,
26
+ temperature=0.7,
27
+ top_p=0.9,
28
+ num_beams=4, # اگر deterministic می‌خوای beam search
29
+ repetition_penalty=1.2
30
+ )
31
 
32
+ # decode خروجی
33
+ generated_text = processor.decode(outputs[0], skip_special_tokens=True)
34
 
35
+ # اگر prompt در خروجی تکرار شده، تمیز کن
36
+ if generated_text.startswith(prompt):
37
+ generated_text = generated_text[len(prompt):].strip()
 
 
38
 
39
+ return generated_text, image
 
 
 
40
 
 
41
  demo = gr.Interface(
42
+ fn=generate_report,
43
+ inputs=[
44
+ gr.Image(type="pil", label="تصویر رادیولوژی آپلود کن (X-ray, CT, MRI و ...)"),
45
+ gr.Textbox(label="پرامپت (اختیاری)", value="Generate a detailed radiology report for this image:")
46
+ ],
47
  outputs=[
48
+ gr.Textbox(label="گزارش / توصیف تولید شده"),
49
  gr.Image(label="تصویر ورودی")
50
  ],
51
+ title="MobileW-v11 ConvNeXt Tiny + 6-Layer Decoder for Radiology",
52
  description=(
53
+ "مدل سبک multimodal برای تحلیل تصاویر پزشکی.\n"
54
+ "encoder: ConvNeXt-Tiny (فریز شده)\n"
55
+ "decoder: 6 لایه برای تولید متن بهتر\n"
56
+ "مثال پرامپت‌ها: 'Describe findings', 'Write a radiology report', 'Is there pneumonia?'"
57
  ),
58
  examples=[
59
+ # اگر مثال تصویر داری، آپلود کن به Space و مسیر بده
60
+ # [ "example_chest_xray.jpg", "Write a structured report" ]
61
  ],
62
+ allow_flagging="never"
63
  )
64
 
65
  if __name__ == "__main__":