bombshelll committed on
Commit
227593e
·
1 Parent(s): 6f25734

Add hierarchical classification and captioning app

Browse files
Files changed (2) hide show
  1. app.py +83 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
from PIL import Image
import torch
from transformers import VisionEncoderDecoderModel, AutoTokenizer, ViTFeatureExtractor, AutoImageProcessor, AutoModelForImageClassification

# Run everything on GPU when one is available; all tensors/models below
# are expected to live on this device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load image captioning model
# ViT encoder + BioMedBERT decoder checkpoint (presumably fine-tuned on the
# ROCO radiology dataset, judging by the model id — TODO confirm).
# Downloaded from the Hub at import time, so first startup is slow.
caption_model = VisionEncoderDecoderModel.from_pretrained("bombshelll/ViT_BioMedBert_Captioning_ROCO").to(device)
tokenizer = AutoTokenizer.from_pretrained("bombshelll/ViT_BioMedBert_Captioning_ROCO")
feature_extractor = ViTFeatureExtractor.from_pretrained("bombshelll/ViT_BioMedBert_Captioning_ROCO")
# Load classification models
def load_classifier(model_id):
    """Load an image-classification checkpoint and its preprocessor.

    Args:
        model_id: Hugging Face Hub model identifier.

    Returns:
        ``(processor, model)`` tuple; the model is moved to the global
        ``device`` and put into eval mode, ready for inference.
    """
    processor = AutoImageProcessor.from_pretrained(model_id)
    # BUG FIX: classify_image() moves its inputs to `device`, but the model
    # was left on CPU — on a CUDA machine every forward pass would crash
    # with a device-mismatch error. Keep model and inputs on the same device.
    model = AutoModelForImageClassification.from_pretrained(model_id).to(device)
    model.eval()  # disable dropout/batch-norm updates for inference
    return processor, model
# Hierarchical classifier bank, keyed by task name. Insertion order matters:
# classify_image() gates "tumor_type" on the earlier "abnormality" result,
# so "abnormality" must precede "tumor_type" in this dict.
classifiers = {
    "plane": load_classifier("bombshelll/swin-brain-plane-classification"),
    "modality": load_classifier("bombshelll/swin-brain-modality-classification"),
    "abnormality": load_classifier("bombshelll/swin-brain-abnormalities-classification"),
    "tumor_type": load_classifier("bombshelll/swin-brain-tumor-type-classification")
}
# Inference functions
def classify_image(image):
    """Run the hierarchical classifier bank on a PIL image.

    Args:
        image: input image (PIL.Image as supplied by Gradio).

    Returns:
        dict mapping classifier name -> predicted label string. The
        "tumor_type" key is present only when the "abnormality" classifier
        predicted "tumor" (hierarchical gating).
    """
    results = {}
    for name, (processor, model) in classifiers.items():
        # Hierarchical gate: the original code ran the tumor-type model on
        # every image and then discarded the prediction; skipping the
        # forward pass gives the identical result dict and saves compute.
        # Relies on "abnormality" preceding "tumor_type" in `classifiers`.
        if name == "tumor_type" and results.get("abnormality") != "tumor":
            continue
        inputs = processor(image, return_tensors="pt").to(device)
        with torch.no_grad():
            logits = model(**inputs).logits
        # argmax over the class dimension -> human-readable label
        results[name] = model.config.id2label[logits.argmax(-1).item()]
    return results
def generate_captions(image, keywords):
    """Produce two captions for *image*: one plain, one keyword-guided.

    Args:
        image: input image (PIL.Image).
        keywords: list of label strings used to seed the decoder for the
            second caption.

    Returns:
        ``(plain_caption, guided_caption)`` tuple of decoded strings.
    """
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values.to(device)

    caption_model.eval()

    # Caption 1: unconditional generation (default greedy decoding).
    with torch.no_grad():
        plain_ids = caption_model.generate(pixel_values, max_length=80)
    plain_caption = tokenizer.decode(plain_ids[0], skip_special_tokens=True)

    # Caption 2: seed the decoder with the classification keywords so the
    # generated text stays anchored to the predicted findings.
    keyword_prompt = " ".join(keywords)
    seed_ids = tokenizer(keyword_prompt, return_tensors="pt").input_ids.to(device)
    with torch.no_grad():
        guided_ids = caption_model.generate(
            pixel_values,
            decoder_input_ids=seed_ids,
            max_length=80,
            num_beams=4,
            no_repeat_ngram_size=3,
            length_penalty=2.0,
        )
    guided_caption = tokenizer.decode(guided_ids[0], skip_special_tokens=True)

    return plain_caption, guided_caption
# Main app logic
def run_pipeline(image):
    """End-to-end pipeline: classify the image, then caption it twice.

    Returns the classification dict, the plain caption, and the
    keyword-guided caption, in that order (matches the Gradio outputs).
    """
    labels = classify_image(image)
    plain, guided = generate_captions(image, list(labels.values()))
    return labels, plain, guided
# Gradio Interface
# Single-image input; three outputs wired 1:1 to run_pipeline's return tuple.
interface = gr.Interface(
    fn=run_pipeline,
    inputs=gr.Image(type="pil"),  # "pil" so classifiers/extractor get a PIL.Image
    outputs=[
        gr.JSON(label="Classification Result"),
        gr.Textbox(label="Caption without Keywords"),
        gr.Textbox(label="Caption with Keywords")
    ],
    title="🧠 Brain Hierarchical Classification + Captioning",
    description="Upload an MRI/CT brain image. The system will classify (plane, modality, abnormality, tumor) and generate two captions: one plain and one guided by the classification keywords."
)

# Launch the app (blocking; default host/port, suitable for HF Spaces).
interface.launch()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch
2
+ transformers
3
+ Pillow
4
+ gradio