File size: 10,899 Bytes
4372989
0ba5d24
 
 
 
 
 
 
4372989
0ba5d24
4372989
915a0fa
0ba5d24
4372989
0ba5d24
 
 
4372989
0ba5d24
4372989
0ba5d24
 
 
 
 
 
 
 
 
4372989
0ba5d24
7e6e9e1
0ba5d24
 
 
 
7e6e9e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0ba5d24
 
 
 
 
 
 
 
 
 
 
 
7e6e9e1
 
 
 
 
 
 
0ba5d24
7e6e9e1
 
 
0ba5d24
 
7e6e9e1
 
0ba5d24
 
4372989
0ba5d24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d37156b
0ba5d24
 
 
 
4372989
0ba5d24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d37156b
0ba5d24
 
 
 
 
 
 
 
 
 
d37156b
 
 
 
0ba5d24
d37156b
0ba5d24
 
 
 
 
 
d37156b
0ba5d24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d37156b
0ba5d24
 
 
 
4372989
0ba5d24
 
5045a37
 
 
 
 
 
0ba5d24
 
 
 
5045a37
 
 
 
0ba5d24
 
 
 
d37156b
5045a37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0ba5d24
 
5045a37
 
4372989
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
from flask import Blueprint, request, jsonify
import base64
from PIL import Image
import io
import torch
from transformers import CLIPProcessor, CLIPModel
from mlc_llm import ChatModule
import threading
import os

# Create a Blueprint for API routes; app.py registers it under the /api prefix.
# BUG FIX: the original executed `app.register_blueprint(api_bp, url_prefix='/api')`
# here, but neither `app` nor `api_bp` is defined in this module at import time,
# which raises NameError before any route can be registered. The blueprint must
# be *created* here (every route below decorates @api_bp.route) and *registered*
# by the application in app.py.
api_bp = Blueprint('api', __name__)

# Global model handles, populated by app.py during startup (see /health).
# A value of None means the corresponding model failed to load.
clip_processor = None
clip_model = None
mlc_chat_module = None
# Serializes access to the shared MLC chat module across request threads.
mlc_lock = threading.Lock()

@api_bp.route('/health')
def health_check():
    """Report backend liveness plus the load state of the LLM and CLIP models."""
    return jsonify({
        "status": "Quantum-Enhanced WAN 2.1 Backend is running!",
        "llm_status": "loaded" if mlc_chat_module else "not loaded (check logs)",
        "clip_status": "loaded" if clip_model else "not loaded (check logs)",
    })

@api_bp.route('/embed_image', methods=['POST'])
def embed_image():
    """Return a normalized CLIP embedding for a base64-encoded image.

    Expects a JSON body carrying the image under 'image', 'image_url', or
    'image_data' — either a full data URL or raw base64. Responds with the
    embedding as a flat list of floats, its tensor shape, and a success flag.
    """
    if clip_processor is None or clip_model is None:
        return jsonify({"error": "CLIP model not loaded. Check server logs for details."}), 500

    try:
        data = request.get_json()
        if not data:
            return jsonify({"error": "Invalid JSON data"}), 400

        image_data_url = data.get('image') or data.get('image_url') or data.get('image_data')

        if not image_data_url:
            return jsonify({"error": "No image data provided. Expected 'image', 'image_url', or 'image_data' field."}), 400

        image = _decode_base64_image(image_data_url)

        inputs = clip_processor(images=image, return_tensors="pt")
        if torch.cuda.is_available():
            # NOTE(review): assumes clip_model was moved to CUDA at load time
            # whenever CUDA is available — confirm against app.py.
            inputs = {k: v.to("cuda") for k, v in inputs.items()}

        with torch.no_grad():
            image_features = clip_model.get_image_features(**inputs)

        # L2-normalize so embeddings are directly comparable via dot product.
        image_embeddings = image_features / image_features.norm(p=2, dim=-1, keepdim=True)
        embeddings_list = image_embeddings.squeeze().cpu().tolist()

        return jsonify({
            "embeddings": embeddings_list,
            # torch.Size is a tuple subclass; an explicit list makes the JSON
            # serialization intent unambiguous.
            "shape": list(image_embeddings.shape),
            "success": True
        }), 200

    except ValueError as ve:
        # binascii.Error (raised on malformed base64) subclasses ValueError,
        # so bad image payloads land here as a client error.
        print(f"Value error embedding image: {ve}")
        return jsonify({"error": f"Invalid image data format: {str(ve)}"}), 400
    except Exception as e:
        print(f"Error embedding image: {e}")
        import traceback
        traceback.print_exc()
        return jsonify({"error": f"Failed to embed image: {str(e)}"}), 500


def _decode_base64_image(image_data_url):
    """Decode a data URL or raw base64 string into an RGB PIL image."""
    if ',' in image_data_url:
        # Data URL: drop the "data:<mime>;base64" header before decoding.
        _, encoded = image_data_url.split(",", 1)
    else:
        # Assume it's raw base64.
        encoded = image_data_url
    image_bytes = base64.b64decode(encoded)
    return Image.open(io.BytesIO(image_bytes)).convert("RGB")

@api_bp.route('/chat/completions', methods=['POST'])
def chat_completions_endpoint():
    """Generate one LLM completion for a user prompt.

    JSON body: 'prompt' (required) and optional 'system_message'. Returns the
    raw model output under 'completion'.
    """
    if mlc_chat_module is None:
        return jsonify({"error": "LLM model not loaded. Check server logs for details."}), 500

    # BUG FIX: get_json() without silent=True raises an unhandled 400/415 when
    # the body is missing or not JSON; with silent=True it returns None, which
    # we coerce to {} so the endpoint emits its own consistent error payload.
    data = request.get_json(silent=True) or {}
    prompt = data.get("prompt")
    system_message = data.get("system_message", "You are a creative AI assistant for video generation.")

    if not prompt:
        return jsonify({"error": "Prompt is required"}), 400

    try:
        full_prompt = f"{system_message}\nUser: {prompt}"

        # The chat module is shared across request threads; serialize access
        # via mlc_lock and reset history so each request is stateless.
        with mlc_lock:
            mlc_chat_module.reset_chat()
            response = mlc_chat_module.generate(full_prompt)

        return jsonify({"completion": response}), 200
    except Exception as e:
        print(f"Error getting chat completion: {e}")
        return jsonify({"error": f"Failed to get chat completion: {str(e)}"}), 500

@api_bp.route('/generate_frame_guidance', methods=['POST'])
def generate_frame_guidance():
    """Produce LLM transformation instructions for the frontend's quantum diffusion.

    This endpoint provides LLM guidance only — it does NOT generate the image
    itself. It embeds the current frame with CLIP, folds the embedding snippet
    and the user's quantum settings into a director prompt, and returns the
    LLM's textual transformation instructions (with a static fallback if the
    LLM call fails).
    """
    if mlc_chat_module is None or clip_processor is None or clip_model is None:
        return jsonify({"error": "One or more AI models not loaded. Check server logs for details."}), 500

    # BUG FIX: guard against missing/invalid JSON bodies (silent=True returns
    # None instead of raising) so we reach our own validation below.
    data = request.get_json(silent=True) or {}
    image_data_url = data.get('image')  # The current frame from the frontend
    prompt = data.get('prompt', 'Quantum interpolation')
    influence = data.get('influence', 5)  # 0-100
    entanglement_depth = data.get('depth', 16)  # For LLM to consider
    frame_number = data.get('frame_number', 0)

    if not image_data_url:
        return jsonify({"error": "No image data provided"}), 400

    try:
        # 1. Get CLIP embeddings for the current frame.
        # NOTE(review): unlike /embed_image, this assumes a full data URL
        # (with a comma); raw base64 would raise ValueError here.
        _, encoded = image_data_url.split(",", 1)
        image_bytes = base64.b64decode(encoded)
        input_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        clip_inputs = clip_processor(images=input_image, return_tensors="pt")
        if torch.cuda.is_available():
            clip_inputs = {k: v.to("cuda") for k, v in clip_inputs.items()}

        with torch.no_grad():
            image_features = clip_model.get_image_features(**clip_inputs)
        image_embeddings_np = image_features.squeeze().cpu().numpy()
        # Only the first 10 components go into the prompt, as a compact visual fingerprint.
        embedding_snippet = ", ".join([f"{x:.4f}" for x in image_embeddings_np[:10]])

        # 2. Use LLM to generate guidance for the next quantum diffusion step.
        # (Removed a dead `import math` — only the builtins round/max are used.)
        llm_prompt = (
            f"You are an AI video director for a quantum diffusion system. Your task is to guide the transformation "
            f"of a video frame based on quantum principles and user input. "
            f"Given the current frame's visual context (CLIP features: [{embedding_snippet}...]), "
            f"the user's creative prompt: '{prompt}', "
            f"and the quantum settings (Quantum Influence: {influence}%, Entanglement Depth: {entanglement_depth} layers), "
            f"describe *precisely* how the quantum diffusion effect should transform the current frame into frame {frame_number + 1}. "
            f"Think of these transformations as manipulating a quantum state that manifests visually. "
            f"Higher influence and depth should lead to more pronounced, chaotic, or surreal quantum effects. "
            f"Focus on quantifiable visual parameters, including: "
            f"color shifts (e.g., 'shift red by +{round(influence/5)}', 'hue rotate {round(influence*1.5)}deg'), "
            f"blur (e.g., 'apply gaussian blur radius {max(1, round(influence/10))}'), "
            f"glitch/distortion (e.g., 'pixel displacement x-axis random {max(5, round(influence/5))}px', 'chromatic aberration offset {max(1, round(influence/20))}'), "
            f"zoom/pan (e.g., 'zoom in {1.00 + influence/2000}x, pan right {round(influence/10)}px'), "
            f"pattern overlay (e.g., 'overlay subtle static pattern opacity {influence/200}'), "
            f"motion blur (e.g., 'apply motion blur strength {round(entanglement_depth/2)}'), "
            f"bloom (e.g., 'add bloom strength {influence/100}'), "
            f"noise (e.g., 'add noise amount {influence/50}'), "
            f"vignette (e.g., 'add vignette strength {influence/200}'), "
            f"or specific quantum-themed visual cues (e.g., 'ripple effect', 'add subtle scanlines opacity {influence/200}', 'invert colors'). "
            f"Combine these to create a dynamic, quantum-like visual evolution. Ensure the intensity of effects scales with Influence and Depth. "
            f"Be concise and output only the transformation instructions. "
            f"Example: 'shift blue by +{round(influence/5)}, apply motion blur strength {round(entanglement_depth/2)}, zoom {1.00 + influence/2000}x, add subtle scanlines opacity {influence/200}'.\n"
            f"Transformation Instructions for frame {frame_number + 1}:"
        )

        llm_guidance = ""
        try:
            # Shared chat module: serialize access and reset per-request history.
            with mlc_lock:
                mlc_chat_module.reset_chat()
                llm_guidance = mlc_chat_module.generate(llm_prompt)
        except Exception as llm_e:
            # Best-effort: fall back to a generic instruction so the frontend
            # animation keeps running even when the LLM call fails.
            print(f"LLM guidance generation failed: {llm_e}. Using fallback guidance.")
            llm_guidance = f"apply subtle glitch effect, shift colors slightly based on quantum influence {influence}%."

        print(f"LLM Guidance: {llm_guidance}")

        return jsonify({
            "guidance": llm_guidance,
            "log": (f"Backend provided guidance for frame {frame_number + 1} based on prompt: '{prompt[:50]}...', "
                    f"influence: {influence}, depth: {entanglement_depth}. LLM guidance: '{llm_guidance[:50]}...'.")
        }), 200
    except Exception as e:
        print(f"Error generating frame guidance: {e}")
        return jsonify({"error": f"Failed to generate frame guidance: {str(e)}"}), 500

@api_bp.route('/upload', methods=['POST'])
def upload_file():
    """Accept an image upload and hand it back as a data URL for the frontend.

    Supports two request shapes: multipart FormData carrying a 'file' part,
    or a JSON body with base64 data under 'image' / 'image_url' / 'image_data'.
    """
    try:
        # Multipart form upload (FormData) takes precedence.
        uploaded = request.files.get('file')
        if uploaded is not None:
            if uploaded.filename == '':
                return jsonify({"error": "No selected file"}), 400

            # Re-encode the raw bytes as a data URL the browser can display.
            encoded = base64.b64encode(uploaded.read()).decode('utf-8')
            mime = uploaded.content_type or 'image/jpeg'
            return jsonify({
                "message": "File uploaded successfully",
                "image_url": f"data:{mime};base64,{encoded}"
            }), 200

        # JSON body carrying base64 data.
        if request.is_json:
            payload = request.get_json()
            image_data = payload.get('image') or payload.get('image_url') or payload.get('image_data')

            if not image_data:
                return jsonify({"error": "No image data provided"}), 400

            # Already a data URL: pass it straight through unchanged.
            if image_data.startswith('data:image'):
                return jsonify({
                    "message": "Image data received",
                    "image_url": image_data
                }), 200

            # Bare base64: assume JPEG and prepend the data-URL header.
            return jsonify({
                "message": "Image data processed",
                "image_url": f"data:image/jpeg;base64,{image_data}"
            }), 200

        return jsonify({"error": "Invalid request format. Send either FormData with 'file' or JSON with 'image' field"}), 400

    except Exception as e:
        print(f"Error uploading file: {e}")
        import traceback
        traceback.print_exc()
        return jsonify({"error": f"Failed to upload file: {str(e)}"}), 500