Spaces:
Sleeping
Sleeping
File size: 5,066 Bytes
8af47a3 11368a6 8af47a3 11368a6 8af47a3 11368a6 8af47a3 11368a6 8af47a3 11368a6 8af47a3 11368a6 8af47a3 11368a6 8af47a3 11368a6 8af47a3 11368a6 8af47a3 5dd7eca 8af47a3 11368a6 8af47a3 d8aef83 8af47a3 a5b3750 11368a6 8af47a3 eb170be | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 | import gradio as gr
import torch
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration, BitsAndBytesConfig
from peft import PeftModel
from PIL import Image
import json
import os
DEFAULT_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
DEFAULT_ADAPTER_ID = "hssling/cardioai-adapter"
CONFIG_PATH = "model_config.json"


def load_runtime_config(config_path=None):
    """Resolve which base model and LoRA adapter the app should serve.

    Values are layered in this order (later sources win):

    1. Built-in defaults: ``DEFAULT_MODEL_ID``, ``DEFAULT_ADAPTER_ID`` and
       revision ``"main"``.
    2. Environment variables ``BASE_MODEL_ID``, ``ADAPTER_REPO_ID`` and
       ``ADAPTER_REVISION``.
    3. Keys ``base_model``, ``adapter_repo`` and ``adapter_revision`` from the
       JSON file at ``config_path``, when that file exists.

    Args:
        config_path: Path to the JSON config file. ``None`` (the default)
            means the module-level ``CONFIG_PATH``, looked up at call time.

    Returns:
        A dict with keys ``"base_model"``, ``"adapter_repo"`` and
        ``"adapter_revision"``. Keys absent from the file keep their
        environment/default value; keys present in the file (even with a
        JSON ``null``) replace it.

    A missing file is silently ignored. An unreadable file, invalid JSON, or a
    JSON document that is not an object is reported on stdout. In those cases
    the environment/default values are returned, so a bad config file never
    prevents startup.
    """
    path = CONFIG_PATH if config_path is None else config_path
    config = {
        "base_model": os.environ.get("BASE_MODEL_ID", DEFAULT_MODEL_ID),
        "adapter_repo": os.environ.get("ADAPTER_REPO_ID", DEFAULT_ADAPTER_ID),
        "adapter_revision": os.environ.get("ADAPTER_REVISION", "main"),
    }
    if not os.path.exists(path):
        return config

    try:
        with open(path, "r", encoding="utf-8") as f:
            disk_cfg = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"Failed to read {path}; falling back to defaults. Error: {e}")
        return config

    # A valid JSON list/string/number has no .get(); reject it explicitly
    # instead of relying on a catch-all handler.
    if not isinstance(disk_cfg, dict):
        print(
            f"Failed to read {path}; falling back to defaults. "
            f"Error: expected a JSON object, got {type(disk_cfg).__name__}"
        )
        return config

    for key in config:
        config[key] = disk_cfg.get(key, config[key])
    return config
# Resolve the model / adapter selection once, at import time.
cfg = load_runtime_config()
MODEL_ID = cfg["base_model"]
ADAPTER_ID = cfg["adapter_repo"]
ADAPTER_REV = cfg["adapter_revision"]

print("Starting App Engine...")
# Directory handed to from_pretrained as offload_folder (used together with
# device_map="auto" below).
os.makedirs("/tmp/offload", exist_ok=True)
device = "cuda" if torch.cuda.is_available() else "cpu"

processor = AutoProcessor.from_pretrained(MODEL_ID, use_fast=False)

model_kwargs = {
    "pretrained_model_name_or_path": MODEL_ID,
    "device_map": "auto",
    "low_cpu_mem_usage": True,
    "offload_folder": "/tmp/offload",
}
if device == "cuda":
    # GPU: 4-bit NF4 weights (double quantization) with float16 compute.
    model_kwargs["torch_dtype"] = torch.float16
    model_kwargs["quantization_config"] = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
    )
else:
    # CPU: keep a 16-bit dtype to reduce memory footprint, but use bfloat16.
    # float16 kernels on CPU are slow or missing ("not implemented for 'Half'")
    # depending on the PyTorch version, whereas bfloat16 is the usual
    # low-memory CPU dtype.
    model_kwargs["torch_dtype"] = torch.bfloat16

model = Qwen2VLForConditionalGeneration.from_pretrained(**model_kwargs)

# An empty or null adapter id (env var or JSON) serves the plain base model.
if ADAPTER_ID:
    print(f"Loading custom fine-tuned LoRA weights: {ADAPTER_ID}@{ADAPTER_REV}")
    try:
        model = PeftModel.from_pretrained(
            model,
            ADAPTER_ID,
            revision=ADAPTER_REV,
            is_trainable=False,
        )
        print("Adapter load successful.")
    except Exception as e:
        # Best-effort: a missing/incompatible adapter must not take the Space down.
        print(f"Failed to load adapter; serving base model instead. Error: {e}")
def diagnose_ecg(image: Image.Image = None, temp: float = 0.4, max_tokens: int = 768):
    """Generate a structured ECG interpretation for a single image.

    Args:
        image: PIL image from the Gradio ``gr.Image(type="pil")`` input, or
            ``None`` when nothing was uploaded.
        temp: Sampling temperature. A value <= 0 selects greedy decoding,
            because ``generate`` rejects a non-positive temperature while
            sampling is enabled (the UI slider goes down to 0.0).
        max_tokens: Maximum number of newly generated tokens.

    Returns:
        The decoded model report as a string. If ``image`` is ``None``, a
        JSON string ``{"error": "No image provided."}`` is returned. Any
        other failure is returned as ``"Error: <message>"`` rather than
        raised, so the UI always receives text; the traceback is printed
        server-side.
    """
    try:
        if image is None:
            return json.dumps({"error": "No image provided."})

        system_prompt = (
            "You are CardioAI, an ECG interpretation engine. "
            "Always analyze the provided ECG image directly. "
            "Do not provide generic AI disclaimers. "
            "Return concise clinical content only."
        )
        user_prompt = (
            "Interpret this ECG image and return exactly these sections: "
            "1) Impression, 2) Rhythm, 3) Rate, 4) ST-T Findings, 5) Urgency. "
            "If image quality is insufficient, write 'Non-diagnostic ECG image quality' in Impression."
        )
        messages = [
            {"role": "system", "content": system_prompt},
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": user_prompt},
                ],
            },
        ]

        # Render the chat template to a prompt string, then let the processor
        # pair it with the image and tensorize both.
        text_input = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = processor(
            text=[text_input],
            images=[image],
            padding=True,
            return_tensors="pt",
        )
        model_device = model.device if hasattr(model, "device") else torch.device(device)
        inputs = {k: v.to(model_device) for k, v in inputs.items()}

        gen_kwargs = {"max_new_tokens": int(max_tokens)}
        temperature = float(temp)
        if temperature > 0:
            gen_kwargs.update(do_sample=True, temperature=temperature, top_p=0.9)
        else:
            # temperature <= 0 with do_sample=True raises inside generate();
            # treat it as a request for deterministic (greedy) output.
            gen_kwargs["do_sample"] = False
        # NOTE(review): top_k is not overridden, so any top_k in the base
        # model's generation_config still applies when sampling. Confirm it
        # does not effectively pin decoding to greedy regardless of `temp`.

        with torch.no_grad():
            generated_ids = model.generate(**inputs, **gen_kwargs)

        # generate() returns prompt + continuation; keep only the new tokens.
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)
        ]
        output_text = processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]
        return output_text
    except Exception as e:
        # UI boundary: report the error to the user, but keep the traceback
        # in the server logs instead of discarding it.
        import traceback

        traceback.print_exc()
        return f"Error: {str(e)}"
def _build_interface():
    """Wire ``diagnose_ecg`` into a Gradio Interface.

    Inputs, in the order ``diagnose_ecg`` expects them: the ECG image, a
    temperature slider and a max-new-tokens slider. Output: the report,
    rendered as Markdown.
    """
    ecg_scan = gr.Image(type="pil", label="ECG Image Scan")
    temperature_control = gr.Slider(minimum=0.0, maximum=1.0, value=0.4, step=0.1, label="Temperature")
    token_budget = gr.Slider(minimum=128, maximum=1536, value=768, step=128, label="Max Tokens")
    report_view = gr.Markdown(label="Clinical Report Output")
    return gr.Interface(
        fn=diagnose_ecg,
        inputs=[ecg_scan, temperature_control, token_budget],
        outputs=report_view,
        title="CardioAI Inference API",
        description="Fine-tuned Medical LLM for Electrocardiogram (ECG) Tracings.",
    )


demo = _build_interface()

if __name__ == "__main__":
    demo.launch()
|