import os
import gradio as gr
import torch
from transformers import MllamaForConditionalGeneration, AutoTokenizer, AutoProcessor
from PIL import Image

# Model ID
MODEL_ID = "0llheaven/Llama-3.2-11B-Vision-Radiology-mini"

# Load tokenizer and processor
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
processor = AutoProcessor.from_pretrained(MODEL_ID)
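
# Note: for this checkpoint the AutoProcessor already bundles the tokenizer
# and the image processor; the standalone tokenizer above is kept only for
# the text-only path in generate_response below.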

# Load the model with reduced precision and memory optimizations.
# Llama 3.2 Vision is an Mllama checkpoint, so it needs the
# vision-conditioned generation class rather than AutoModelForCausalLM.
print("Loading model with memory optimizations...")
model = MllamaForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,  # Use half precision
    device_map="auto",          # Let the library decide how to map the model
    low_cpu_mem_usage=True,     # Optimize CPU memory usage
    offload_folder="offload",   # Offload weights to disk if needed
    offload_state_dict=True,    # Enable state dict offloading
    trust_remote_code=True,
)
print("Model loaded!")

# Clear CUDA cache after loading
if torch.cuda.is_available():
    torch.cuda.empty_cache()
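
# Optional sanity check: report the model's approximate memory footprint
# (get_memory_footprint is a standard transformers helper) to confirm that
# fp16 loading stayed within budget.
print(f"Model memory footprint: {model.get_memory_footprint() / 1e9:.2f} GB")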

def generate_response(image_file, prompt, max_new_tokens=256, temperature=0.7, top_p=0.9):
    try:
        # Process image if provided
        if image_file is not None:
            image = Image.open(image_file).convert('RGB')
            
            # Mllama expects one <|image|> token in the text for each image;
            # the chat template inserts it (plus the generation prompt) for us
            messages = [{
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": prompt},
                ],
            }]
            input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
            
            # Process inputs and move every tensor to the model's device
            # (add_special_tokens=False because the template already adds them)
            inputs = processor(
                images=image,
                text=input_text,
                add_special_tokens=False,
                return_tensors="pt",
            ).to(model.device)
            
            # Generate response with conservative memory settings
            with torch.no_grad():
                # Clear cache before generation
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                
                outputs = model.generate(
                    **inputs,  # includes pixel_values and the other vision tensors
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=True
                )
            
            # Decode only the newly generated tokens, not the echoed prompt
            new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
            response = processor.decode(new_tokens, skip_special_tokens=True).strip()
            
        else:
            # Text-only input
            inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
            
            # Generate response with conservative memory settings
            with torch.no_grad():
                # Clear cache before generation
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=True
                )
            
            # Decode only the newly generated tokens
            new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
            response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        
        return response
    
    except Exception as e:
        return f"Error: {str(e)}"

# Define the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Llama-3.2-11B Vision Radiology Model")
    gr.Markdown("Upload a radiology image (X-ray, CT, MRI, etc.) and ask questions about it.")
    
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="filepath", label="Upload Radiology Image")
            prompt_input = gr.Textbox(label="Question or Prompt", placeholder="Describe what you see in this image and identify any abnormalities.")
            
            with gr.Row():
                max_tokens = gr.Slider(minimum=16, maximum=512, value=256, step=8, label="Max New Tokens")
                temperature = gr.Slider(minimum=0.1, maximum=1.5, value=0.7, step=0.1, label="Temperature")
                top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p")
                
            submit_btn = gr.Button("Generate Response")
        
        with gr.Column():
            output = gr.Textbox(label="Model Response", lines=15)
    
    submit_btn.click(
        generate_response,
        inputs=[image_input, prompt_input, max_tokens, temperature, top_p],
        outputs=[output]
    )
    
    gr.Examples(
        [
            ["sample_xray.jpg", "What abnormalities do you see in this X-ray?"],
            ["sample_ct.jpg", "Describe this image and any findings."],
        ],
        inputs=[image_input, prompt_input],
    )

# Limit Gradio to a single worker thread so only one generation runs at a
# time, keeping peak GPU/CPU memory low
demo.launch(max_threads=1)