hssling committed on
Commit
ccb8f07
·
0 Parent(s):

Initial commit of model training and inference backend

Browse files
Files changed (3) hide show
  1. app.py +91 -0
  2. requirements.txt +9 -0
  3. train_multimodal.py +115 -0
app.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoProcessor, AutoModelForVision2Seq
4
+ from PIL import Image
5
+
6
# --- HuggingFace Space deployment configuration ---
# Base vision-language model to serve.
MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
# Optional PEFT/LoRA adapter repo id. To serve your Kaggle fine-tune:
#   1. model.push_to_hub("your-name/med-qwen-vl-adapter")
#   2. set ADAPTER_ID to that repo id so the adapter is loaded below.
ADAPTER_ID = None
12
+
13
# Load the processor and model once at import time so every request reuses
# the same instances instead of re-loading per call.
print("Starting App Engine...")
print(f"Loading {MODEL_ID}...")

# Prefer GPU when available: fp16 on CUDA, fp32 on CPU.
if torch.cuda.is_available():
    device = "cuda"
    _dtype = torch.float16
else:
    device = "cpu"
    _dtype = torch.float32

processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForVision2Seq.from_pretrained(
    MODEL_ID,
    torch_dtype=_dtype,
    device_map="auto",
)

# Attach fine-tuned LoRA weights when an adapter repo is configured.
if ADAPTER_ID:
    print(f"Loading custom fine-tuned LoRA weights: {ADAPTER_ID}")
    model.load_adapter(ADAPTER_ID)
28
+
29
# 2. Main API Function called by our Next App
def diagnose_api(history: str, examination: str, image: Image.Image = None, audio_path: str = None):
    """Generate a structured clinical report from patient text and a scan image.

    Args:
        history: Free-text patient history (age, symptoms, past medical history).
        examination: Free-text examination findings (vitals, systemic exam).
        image: Diagnostic scan; required — the VL model needs a visual input.
        audio_path: Accepted for interface compatibility but ignored here
            (Qwen-VL has no native audio support).

    Returns:
        Markdown report text on success, or an error string beginning with
        "Error:" / "Model Error:" — callers must not assume success.
    """
    try:
        if image is None:
            # Fallback if no image is passed
            return "Error: Qwen-VL requires at least one image/diagnostic input to function accurately."

        # Re-construct the specific structured prompt our diagnostic copilot uses
        system_prompt = "You are a highly advanced Multi-Modal Diagnostic Co-Pilot Medical AI. Provide ## Integrated Analysis, ## Decision Making, and ## Management & Treatment Plan."
        user_prompt = f"History: {history}\nExamination: {examination}\nAnalyze the provided scan and history."

        messages = [
            {"role": "system", "content": system_prompt},
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": user_prompt}
                ]
            }
        ]

        text_input = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

        inputs = processor(
            text=[text_input],
            images=[image],
            padding=True,
            return_tensors="pt"
        ).to(device)

        with torch.no_grad():
            # BUG FIX: `temperature` only takes effect when do_sample=True;
            # previously generate() ran greedy decoding and silently ignored
            # temperature=0.2 (transformers emits a warning for this).
            generated_ids = model.generate(
                **inputs,
                max_new_tokens=1024,
                do_sample=True,
                temperature=0.2,
            )

        # Strip the prompt tokens so only newly generated text is decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]

        output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

        return output_text

    except Exception as e:
        # API boundary: surface failures as a string so the frontend can
        # display them instead of receiving an opaque 500.
        return f"Model Error: {str(e)}"
73
+
74
# 3. Gradio interface: serves as the visual UI for the HF Space and, more
# importantly, exposes an API endpoint via `/api/predict` that our React
# app can connect to securely.
_api_inputs = [
    gr.Textbox(lines=5, label="Patient History (String)", placeholder="Age, symptoms, past medical history..."),
    gr.Textbox(lines=5, label="Examination Findings (String)", placeholder="Vitals, systemic exam..."),
    gr.Image(type="pil", label="Diagnostic Scan / Image"),
    # Qwen-VL does not naturally support audio; kept hidden for API-shape
    # compatibility and handled externally or ignored by the backend.
    gr.Audio(type="filepath", label="Optional Dictation Audio", visible=False),
]

demo = gr.Interface(
    fn=diagnose_api,
    inputs=_api_inputs,
    outputs=gr.Markdown(label="Clinical Report Output"),
    title="Multi-Modal Diagnostic Co-Pilot API (Trained via Kaggle)",
    description="This Space hosts the fine-tuned medical vision-language model for the Diagnostic Co-Pilot ecosystem.",
)

if __name__ == "__main__":
    # HF Spaces launches the app itself, so share=True is unnecessary.
    demo.launch(share=False)
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ torch>=2.0
2
+ transformers>=4.40.0
3
+ accelerate
4
+ peft
5
+ bitsandbytes
6
+ trl
7
+ datasets
8
+ gradio>=4.0.0
9
+ Pillow
train_multimodal.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoProcessor, AutoModelForVision2Seq, TrainingArguments
3
+ from peft import LoraConfig, get_peft_model
4
+ from datasets import load_dataset
5
+ from trl import SFTTrainer
6
+
7
# 1. Configuration for Kaggle/HuggingFace fine-tuning runs.
# Small, highly capable multimodal model — a good fit for medical VQA.
MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
# Example medical VQA dataset (radiology question/answer pairs).
DATASET_ID = "flaviagiammarino/vqa-rad"
# Where the trained LoRA adapter and processor files are written.
OUTPUT_DIR = "./med-qwen-vl-adapter"
11
+
12
def main():
    """Fine-tune Qwen2-VL on a medical VQA dataset with LoRA and save the adapter.

    Loads the base model in fp16, attaches LoRA adapters to the attention
    projections, formats the VQA-RAD dataset into Qwen2-VL chat turns, runs a
    short SFT pass, and writes the adapter + processor to OUTPUT_DIR.
    """
    print(f"Loading processor and model: {MODEL_ID}")

    # Load processor and model; fp16 + low_cpu_mem_usage keeps the footprint
    # within Kaggle GPU limits.
    processor = AutoProcessor.from_pretrained(MODEL_ID)

    model = AutoModelForVision2Seq.from_pretrained(
        MODEL_ID,
        device_map="auto",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    )

    # Apply LoRA (Low-Rank Adaptation) — only the attention projections are
    # trained, a small fraction of the total parameters.
    print("Applying LoRA parameters...")
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # Attention layers
        bias="none",
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    # Load and format the dataset (subset keeps the demo run short).
    print(f"Loading dataset: {DATASET_ID}")
    dataset = load_dataset(DATASET_ID, split="train[:50%]")

    def format_data(example):
        # Build the Qwen2-VL chat structure: the user turn carries the image
        # placeholder plus the question, the assistant turn the gold answer.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": example["question"]}
                ]
            },
            {
                "role": "assistant",
                "content": [
                    {"type": "text", "text": example["answer"]}
                ]
            }
        ]
        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
        return {"text": text, "image": example["image"]}

    formatted_dataset = dataset.map(format_data, remove_columns=dataset.column_names)

    # Setup Training Arguments
    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        learning_rate=2e-4,
        logging_steps=10,
        max_steps=100,  # Set low for quick Kaggle demonstration
        save_strategy="steps",
        save_steps=50,
        fp16=True,
        optim="paged_adamw_8bit",
        remove_unused_columns=False,  # keep "image" column for the collator
        report_to="none"  # Disable wandb for seamless Kaggle runs
    )

    # Hoisted out of the collator: pad token id used for label masking.
    pad_token_id = processor.tokenizer.pad_token_id

    # Custom data collator for vision-language batches.
    def collate_fn(examples):
        texts = [ex["text"] for ex in examples]
        images = [ex["image"] for ex in examples]

        batch = processor(
            text=texts,
            images=images,
            padding=True,
            return_tensors="pt"
        )
        # Causal-LM labels mirror input_ids, BUT padding positions must be
        # set to -100 (the cross-entropy ignore index) so the loss is not
        # computed on pad tokens. The original code trained on padding,
        # which skews the loss toward predicting pad tokens.
        labels = batch["input_ids"].clone()
        if pad_token_id is not None:
            labels[labels == pad_token_id] = -100
        batch["labels"] = labels
        return batch

    # Train using TRL's SFT Trainer
    print("Starting fine-tuning...")
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=formatted_dataset,
        data_collator=collate_fn,
        dataset_text_field="text"  # SFTTrainer requires this, though we use a custom collator
    )

    trainer.train()

    # Save the adapter
    print(f"Saving fine-tuned adapter to {OUTPUT_DIR}")
    trainer.save_model(OUTPUT_DIR)
    processor.save_pretrained(OUTPUT_DIR)

    print("Done! You can now merge this adapter or upload it directly to the Hugging Face Hub (e.g. via model.push_to_hub())")

if __name__ == "__main__":
    main()