---
license: apache-2.0
datasets:
- jeffrey423/ToothXpert.MM-OPG-Annotations
language:
- en
tags:
- dental
- medical
- multimodal
- vision-language
- llava
- clip
- sam
- lora
- orthopantomography
- opg
- x-ray
- diagnosis
base_model: liuhaotian/llava-v1.5-7b
pipeline_tag: image-text-to-text
library_name: transformers
---

# ToothXpert Model

ToothXpert is a multimodal AI model for comprehensive dental X-ray (OPG) analysis, combining vision and language understanding for automatic diagnosis and condition detection.

## Quick Start

### Installation

```bash
pip install torch torchvision transformers
pip install opencv-python einops peft medpy
pip install "numpy<2.0"  # Important for compatibility
```
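
Because the `numpy<2.0` pin is flagged as important, it can help to confirm which NumPy version actually ended up in the environment before loading the model. The following is a minimal sketch; the helper names `major_version` and `numpy_is_compatible` are illustrative and not part of ToothXpert.

```python
def major_version(version: str) -> int:
    """Return the leading integer of a version string such as '1.26.4'."""
    return int(version.split(".")[0])


def numpy_is_compatible(version: str) -> bool:
    """True when the version satisfies the `numpy<2.0` pin from the install step."""
    return major_version(version) < 2


if __name__ == "__main__":
    try:
        import numpy
    except ImportError:
        print("NumPy is not installed")
    else:
        ok = numpy_is_compatible(numpy.__version__)
        status = "OK" if ok else 'too new, reinstall with pip install "numpy<2.0"'
        print(f"NumPy {numpy.__version__}: {status}")
```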

### Download Model

```python
from huggingface_hub import snapshot_download

model_path = snapshot_download(
    repo_id='jeffrey423/ToothXpert',
    local_dir='./ToothXpert_pretrained'
)
```

### Simple Inference

The `model.*` and `utils.*` imports below are not part of `transformers`; they appear to come from the [ToothXpert GitHub repository](https://github.com/CUHK-AIM-Group/ToothXpert), so run the script from a clone of that repository.

```python
import cv2
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, CLIPImageProcessor
from model.ToothXpert_MOE import ToothXpertForCausalLMMOE
from model.llava import conversation as conversation_lib
from model.llava.mm_utils import tokenizer_image_token
from model.segment_anything.utils.transforms import ResizeLongestSide
from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
                         DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX)

# Preprocessing function: normalize pixel values, then pad right/bottom to a square
def preprocess(x, pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
               pixel_std=torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1), img_size=1024):
    x = (x - pixel_mean) / pixel_std
    h, w = x.shape[-2:]
    padh = img_size - h
    padw = img_size - w
    x = F.pad(x, (0, padw, 0, padh))
    return x

# Load model
model_path = "./ToothXpert_pretrained"
device = "cuda:0"

tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    model_max_length=512,
    padding_side="right",
    use_fast=False,
)
tokenizer.pad_token = tokenizer.unk_token
tokenizer.add_tokens("[SEG]")
seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)

moe_lora_args = {
    "lora_r": 8,
    "lora_alpha": 16,
    "lora_dropout": 0.05,
    "lora_target_modules": "q_proj,v_proj",
    "moe_lora": False,
    "expert_num": 3,
    "guide": True,
    "guide_mode": "smmulsm",
    "vocab_size": len(tokenizer),
}

model = ToothXpertForCausalLMMOE.from_pretrained(
    model_path,
    low_cpu_mem_usage=True,
    vision_tower="openai/clip-vit-large-patch14",
    seg_token_idx=seg_token_idx,
    torch_dtype=torch.bfloat16,
    train_mask_decoder=True,
    out_dim=256,
    moe_lora_args=moe_lora_args,
)

model.config.eos_token_id = tokenizer.eos_token_id
model.config.bos_token_id = tokenizer.bos_token_id
model.config.pad_token_id = tokenizer.pad_token_id

model.get_model().initialize_vision_modules(model.get_model().config)
vision_tower = model.get_model().get_vision_tower()
vision_tower.to(dtype=torch.bfloat16, device=device)

model = model.bfloat16().to(device)
model.eval()

# Load and process image
image_path = "your_dental_xray.png"
image_np = cv2.imread(image_path)
if image_np is None:  # cv2.imread returns None instead of raising on an unreadable path
    raise FileNotFoundError(f"Could not read image: {image_path}")
image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
original_size_list = [image_np.shape[:2]]

clip_image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
transform = ResizeLongestSide(1024)

# Input for the CLIP vision encoder
image_clip = (
    clip_image_processor.preprocess(image_np, return_tensors="pt")["pixel_values"][0]
    .unsqueeze(0).to(device).bfloat16()
)

# Input for the SAM branch: longest side resized to 1024, then normalized and padded
image = transform.apply_image(image_np)
resize_list = [image.shape[:2]]
image = (
    preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())
    .unsqueeze(0).to(device).bfloat16()
)

# Prepare prompt
question = "Can you describe the image for me?"
conv = conversation_lib.conv_templates["llava_v1"].copy()
conv.messages = []
prompt = DEFAULT_IMAGE_TOKEN + "\n" + question
prompt = prompt.replace(DEFAULT_IMAGE_TOKEN,
                        DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN)

conv.append_message(conv.roles[0], prompt)
conv.append_message(conv.roles[1], "")
prompt = conv.get_prompt()

input_ids = tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
input_ids = input_ids.unsqueeze(0).to(device)

# Run inference
with torch.no_grad():
    output_ids, pred_masks = model.evaluate(
        image_clip,
        image,
        input_ids,
        resize_list,
        original_size_list,
        max_new_tokens=512,
        tokenizer=tokenizer,
    )

output_ids = output_ids[0][output_ids[0] != IMAGE_TOKEN_INDEX]
text_output = tokenizer.decode(output_ids, skip_special_tokens=False)
text_output = text_output.split('ASSISTANT:')[-1].replace('</s>', '').strip()

print(f"Question: {question}")
print(f"Answer: {text_output}")
```
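
The script prepares the image twice: once through the CLIP processor, and once resized so that its longest side is 1024 pixels and then padded on the right and bottom to 1024x1024 by `preprocess`. The sketch below reproduces only that resize-and-pad arithmetic in plain Python so the resulting shapes are easy to check. The helper names `resized_shape` and `padding_for` are illustrative, the 1500x3000 input size is a made-up example, and the rounding rule is an assumption based on the upstream Segment Anything `ResizeLongestSide`; the copy bundled under `model.segment_anything` may differ.

```python
from typing import Tuple


def resized_shape(height: int, width: int, long_side: int = 1024) -> Tuple[int, int]:
    """Scale (height, width) so the longer side equals `long_side`, rounding to nearest."""
    scale = long_side / max(height, width)
    return int(height * scale + 0.5), int(width * scale + 0.5)


def padding_for(height: int, width: int, img_size: int = 1024) -> Tuple[int, int]:
    """Rows and columns of padding added at the bottom and right to reach a square."""
    return img_size - height, img_size - width


if __name__ == "__main__":
    # Hypothetical panoramic radiograph, twice as wide as it is tall.
    new_h, new_w = resized_shape(1500, 3000)
    pad_h, pad_w = padding_for(new_h, new_w)
    print(f"resized to {new_h}x{new_w}, padded by {pad_h} rows and {pad_w} columns")
```

For wide panoramic images most of the padded square is empty rows at the bottom, which is why `resize_list` and `original_size_list` are passed to `model.evaluate` alongside the tensors.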

## Example Questions

**General Description:**
- "Can you describe the image for me?"

**Specific Conditions:**
- "Is there any amalgam restorations in the image?"
- "Any R/L suggestive of caries present?"
- "Is there any dental implant present?"
- "Is there any root canal treated teeth?"
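
To put several of these questions to the same image, the prompt-building, `model.evaluate`, and decoding steps from Simple Inference can be wrapped in one function and called in a loop. The sketch below shows only the loop. `answer_fn` is a hypothetical callable, not part of the ToothXpert code, that takes a question string and returns the decoded answer; a stub stands in for it here so the example runs without a GPU.

```python
from typing import Callable, Dict, Iterable

# The example questions from this model card, copied verbatim.
EXAMPLE_QUESTIONS = [
    "Can you describe the image for me?",
    "Is there any amalgam restorations in the image?",
    "Any R/L suggestive of caries present?",
    "Is there any dental implant present?",
    "Is there any root canal treated teeth?",
]


def ask_all(answer_fn: Callable[[str], str],
            questions: Iterable[str] = EXAMPLE_QUESTIONS) -> Dict[str, str]:
    """Call `answer_fn` once per question and collect the answers in order."""
    return {question: answer_fn(question) for question in questions}


def stub_answer(question: str) -> str:
    """Placeholder for a real wrapper around the inference steps (assumption)."""
    return f"(model answer to: {question})"


if __name__ == "__main__":
    for question, answer in ask_all(stub_answer).items():
        print(f"Q: {question}\nA: {answer}\n")
```

In a real wrapper the image tensors only need to be computed once; only the prompt and `input_ids` change per question.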

## Supported Conditions

ToothXpert can detect 11 dental conditions:
1. Amalgam restorations
2. Caries (R/L)
3. Crestal bone loss (mandible)
4. Crestal bone loss (maxillary)
5. Implant-supported bridge
6. Dental implant
7. Metallic/non-metallic post
8. Non-metallic restorations
9. Periapical radiolucency
10. Root canal treated teeth
11. Tooth-supported bridge

## Requirements

- **GPU**: NVIDIA GPU with at least 16GB VRAM
- **Python**: 3.11 (recommended)
- **CUDA**: 12.1 or compatible
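
The GPU requirement can be checked from Python before downloading the weights. This is a rough sketch: `has_enough_vram` and its `tolerance` parameter are illustrative, not part of ToothXpert. The tolerance is an arbitrary assumption that accounts for cards marketed as 16GB reporting somewhat less usable memory to PyTorch.

```python
REQUIRED_VRAM_GB = 16  # taken from the requirements list


def has_enough_vram(total_bytes: int,
                    required_gb: float = REQUIRED_VRAM_GB,
                    tolerance: float = 0.9) -> bool:
    """True if the reported memory reaches `tolerance` times the nominal requirement."""
    return total_bytes >= required_gb * 1e9 * tolerance


if __name__ == "__main__":
    try:
        import torch
    except ImportError:
        print("PyTorch is not installed")
    else:
        if not torch.cuda.is_available():
            print("No CUDA device is visible to PyTorch")
        else:
            props = torch.cuda.get_device_properties(0)
            verdict = "OK" if has_enough_vram(props.total_memory) else "below the stated requirement"
            print(f"{props.name}: {props.total_memory / 1e9:.1f} GB ({verdict})")
```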

## Model Details

- **Base Model**: LLaVA-1.5-7B
- **Vision Encoder**: CLIP ViT-L/14
- **Segmentation**: SAM (Segment Anything Model) ViT-H
- **Adaptation**: Guided Mixture of LoRA Experts (G-MoLE)
- **Model Size**: ~15GB

## Citation

If you use ToothXpert in your research, please cite:

```bibtex
@article{liu2026toothxpert,
  title={Developing and Evaluating Multimodal Large Language Model for Orthopantomography Analysis to Support Clinical Dentistry},
  author={Liu, Xinyu and Hung, Kuo Feng and Yu, Weihao and Ng, Ray Anthony W T and Li, Wuyang and Niu, Tianye and Chen, Hui and Yuan, Yixuan},
  journal={Cell Reports Medicine},
  year={2026}
}
```

## Links

- **GitHub Repository**: [CUHK-AIM-Group/ToothXpert](https://github.com/CUHK-AIM-Group/ToothXpert)
- **Dataset**: [jeffrey423/ToothXpert.MM-OPG-Annotations](https://huggingface.co/datasets/jeffrey423/ToothXpert.MM-OPG-Annotations)

## License

Apache License 2.0