File size: 6,263 Bytes
8fb7337
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d30672f
8fb7337
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
067b592
 
8fb7337
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
from flask import Flask, request, jsonify
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForVision2Seq
from transformers.image_utils import load_image
import base64
import io
import os
from werkzeug.utils import secure_filename

app = Flask(__name__)

# Global variables for model and processor
# Prefer the GPU when one is visible; the model and all inputs are moved here.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Both stay None until load_model() runs, so importing this module is cheap
# and the (slow) model download happens lazily on first use.
processor = None
model = None

def load_model():
    """Initialize the global processor and model on first call.

    Idempotent: if both globals are already populated the call is a no-op,
    so every request handler can invoke it unconditionally.
    """
    global processor, model
    if processor is not None and model is not None:
        return
    print("Loading model and processor...")
    processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-500M-Instruct")
    # flash-attention is only available on CUDA; fall back to eager on CPU.
    attn_impl = "flash_attention_2" if DEVICE == "cuda" else "eager"
    model = AutoModelForVision2Seq.from_pretrained(
        "HuggingFaceTB/SmolVLM-500M-Instruct",
        torch_dtype=torch.bfloat16,
        _attn_implementation=attn_impl,
    ).to(DEVICE)
    print("Model loaded successfully!")

def process_image(image_input):
    """Resolve an image reference into a PIL image.

    Args:
        image_input: One of
            - an http(s) URL,
            - a ``data:image/...;base64,`` data URL,
            - a local file path.

    Returns:
        The loaded image object.

    Raises:
        ValueError: if the reference cannot be decoded or loaded; the
            original exception is chained as the cause.
    """
    try:
        # Match the scheme explicitly: the old `startswith('http')` check
        # also matched local files whose names merely begin with "http".
        if image_input.startswith(('http://', 'https://')):
            # Load image from URL
            return load_image(image_input)
        elif image_input.startswith('data:image'):
            # Base64 data URL: drop the "data:image/...;base64," prefix.
            # maxsplit=1 keeps the payload intact even if it contains commas.
            image_data = image_input.split(',', 1)[1]
            image_bytes = base64.b64decode(image_data)
            return Image.open(io.BytesIO(image_bytes))
        else:
            # Anything else is treated as a path on the local filesystem.
            return load_image(image_input)
    except Exception as e:
        # Chain the cause so the real failure stays visible in tracebacks.
        raise ValueError(f"Error processing image: {str(e)}") from e

@app.route('/', methods=['GET'])
def health_check():
    """Report liveness, the active compute device, and whether the model is loaded."""
    payload = {
        "status": "healthy",
        "device": DEVICE,
        "model_loaded": model is not None,
    }
    return jsonify(payload)

def _extract_image_and_text(content):
    """Scan a chat-message content array for one image and one text prompt.

    Args:
        content: list of OpenAI-style content items
            (``{"type": "image_url", ...}`` / ``{"type": "text", ...}``).

    Returns:
        ``(image, text)`` — the loaded image (or None if absent) and the
        prompt text (or "" if absent). Later items overwrite earlier ones.
    """
    image = None
    text = ""
    for item in content:
        if item.get('type') == 'image_url' and 'image_url' in item and 'url' in item['image_url']:
            image = process_image(item['image_url']['url'])
        elif item.get('type') == 'text':
            text = item.get('text', '')
    return image, text


def _generate_reply(image, text):
    """Run one round of vision-language inference and return the reply text."""
    model_messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": text},
            ],
        }
    ]
    prompt = processor.apply_chat_template(model_messages, add_generation_prompt=True)
    inputs = processor(text=prompt, images=[image], return_tensors="pt").to(DEVICE)
    generated_ids = model.generate(**inputs, max_new_tokens=500)
    generated_texts = processor.batch_decode(
        generated_ids,
        skip_special_tokens=True,
    )
    response_text = generated_texts[0]
    # The decoded text echoes the whole prompt; keep only what follows the
    # final "Assistant:" marker when it is present.
    if "Assistant:" in response_text:
        response_text = response_text.split("Assistant:")[-1].strip()
    return response_text


@app.route('/chat', methods=['POST'])
def chat():
    """Main chat endpoint that accepts a `messages` array.

    Expects JSON of the form::

        {"messages": [{"role": "user", "content": [
            {"type": "image_url", "image_url": {"url": ...}},
            {"type": "text", "text": ...}]}]}

    Only the last message is processed; it must be a user message carrying
    both an image and a text prompt. Returns the model reply as JSON, or a
    JSON error with status 400 (bad request) / 500 (internal failure).
    """
    try:
        # Lazy-load so the endpoint works even if startup skipped load_model().
        load_model()

        # silent=True: malformed/absent JSON yields None, so the endpoint
        # answers with its own JSON 400 instead of Flask's HTML error page.
        data = request.get_json(silent=True)
        if not data or 'messages' not in data:
            return jsonify({
                "error": "Missing 'messages' field in request body"
            }), 400

        messages = data['messages']
        if not isinstance(messages, list) or len(messages) == 0:
            return jsonify({
                "error": "Messages must be a non-empty array"
            }), 400

        # Only the most recent message drives the generation.
        last_message = messages[-1]
        if last_message.get('role') != 'user':
            return jsonify({
                "error": "Last message must be from user"
            }), 400

        content = last_message.get('content', [])
        if not isinstance(content, list):
            return jsonify({
                "error": "Content must be an array"
            }), 400

        image, text = _extract_image_and_text(content)
        if not image:
            return jsonify({
                "error": "No image found in the message"
            }), 400
        if not text:
            return jsonify({
                "error": "No text found in the message"
            }), 400

        return jsonify({
            "response": _generate_reply(image, text),
            "model": "SmolVLM-500M-Instruct",
            "device": DEVICE
        })

    except Exception as e:
        return jsonify({
            "error": f"An error occurred: {str(e)}"
        }), 500

@app.route('/upload', methods=['POST'])
def upload_image():
    """Accept a multipart image upload and return it as a base64 data URL.

    Expects the file under the multipart field name ``image``. Responds with
    ``{"image_data": "data:<mime>;base64,...", "filename": ...}`` on success,
    or a JSON error with status 400/500.
    """
    try:
        if 'image' not in request.files:
            return jsonify({
                "error": "No image file provided"
            }), 400

        file = request.files['image']
        if file.filename == '':
            return jsonify({
                "error": "No file selected"
            }), 400

        # Sanitized name is returned to the client for reference only.
        filename = secure_filename(file.filename)

        # Encode straight from the upload stream. The previous version wrote
        # a temp file under a hard-coded /tmp and leaked it whenever a later
        # step raised; no file on disk means nothing to clean up.
        img_data = base64.b64encode(file.read()).decode('utf-8')

        # Use the mimetype declared for the upload; the old hard-coded
        # image/jpeg mislabelled PNG/GIF/WebP files. Fall back to jpeg to
        # preserve the previous behavior when no type was supplied.
        mimetype = file.mimetype or 'image/jpeg'

        return jsonify({
            "image_data": f"data:{mimetype};base64,{img_data}",
            "filename": filename
        })

    except Exception as e:
        return jsonify({
            "error": f"An error occurred: {str(e)}"
        }), 500

if __name__ == '__main__':
    # Eager-load at startup so the first request doesn't pay the model
    # download/initialization cost.
    load_model()
    # Bind all interfaces on port 7860; debug off for production-style serving.
    app.run(host='0.0.0.0', port=7860, debug=False)