from flask import Blueprint, request, jsonify
import base64
from PIL import Image
import io
import torch
from transformers import CLIPProcessor, CLIPModel
from mlc_llm import ChatModule
import threading

# Create a Blueprint for API routes (registered in app.py with
# app.register_blueprint(api_bp, url_prefix='/api'))
api_bp = Blueprint('api', __name__)

# Global instances for models (will be initialized in app.py)
clip_processor = None
clip_model = None
mlc_chat_module = None
mlc_lock = threading.Lock()
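
# A minimal sketch of how app.py might wire these globals up; the module name
# (routes.py), the CLIP checkpoint, and the MLC model path are assumptions,
# not part of this file:
#
#     import torch
#     from flask import Flask
#     from transformers import CLIPProcessor, CLIPModel
#     from mlc_llm import ChatModule
#     import routes
#
#     app = Flask(__name__)
#     routes.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
#     routes.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
#     if torch.cuda.is_available():
#         routes.clip_model = routes.clip_model.to("cuda")
#     routes.mlc_chat_module = ChatModule(model="<path-to-compiled-mlc-model>")
#     app.register_blueprint(routes.api_bp, url_prefix='/api')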

@api_bp.route('/health')
def health_check():
    llm_status = "loaded" if mlc_chat_module else "not loaded (check logs)"
    clip_status = "loaded" if clip_model else "not loaded (check logs)"
    return jsonify({
        "status": "Quantum-Enhanced WAN 2.1 Backend is running!",
        "llm_status": llm_status,
        "clip_status": clip_status
    })
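
# Example request, as a sketch (host and port are assumptions):
#
#     import requests
#     print(requests.get("http://localhost:5000/api/health").json())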

@api_bp.route('/embed_image', methods=['POST'])
def embed_image():
    """Handle image embedding requests."""
    if clip_processor is None or clip_model is None:
        return jsonify({"error": "CLIP model not loaded. Check server logs for details."}), 500
    try:
        data = request.get_json()
        if not data:
            return jsonify({"error": "Invalid JSON data"}), 400
        image_data_url = data.get('image') or data.get('image_url') or data.get('image_data')
        if not image_data_url:
            return jsonify({"error": "No image data provided. Expected 'image', 'image_url', or 'image_data' field."}), 400
        # Handle data URL format ("data:image/...;base64,<payload>")
        if ',' in image_data_url:
            header, encoded = image_data_url.split(",", 1)
        else:
            # Assume it's raw base64
            encoded = image_data_url
        # Decode and process the image
        image_bytes = base64.b64decode(encoded)
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        inputs = clip_processor(images=image, return_tensors="pt")
        if torch.cuda.is_available():
            inputs = {k: v.to("cuda") for k, v in inputs.items()}
        with torch.no_grad():
            image_features = clip_model.get_image_features(**inputs)
        # L2-normalize embeddings and convert to a list for JSON serialization
        image_embeddings = image_features / image_features.norm(p=2, dim=-1, keepdim=True)
        embeddings_list = image_embeddings.squeeze().cpu().tolist()
        return jsonify({
            "embeddings": embeddings_list,
            "shape": list(image_embeddings.shape),  # torch.Size is not JSON-serializable
            "success": True
        }), 200
    except ValueError as ve:
        print(f"Value error embedding image: {ve}")
        return jsonify({"error": f"Invalid image data format: {str(ve)}"}), 400
    except Exception as e:
        print(f"Error embedding image: {e}")
        import traceback
        traceback.print_exc()
        return jsonify({"error": f"Failed to embed image: {str(e)}"}), 500

@api_bp.route('/chat/completions', methods=['POST'])
def chat_completions_endpoint():
    if mlc_chat_module is None:
        return jsonify({"error": "LLM model not loaded. Check server logs for details."}), 500
    data = request.get_json()
    prompt = data.get("prompt")
    system_message = data.get("system_message", "You are a creative AI assistant for video generation.")
    if not prompt:
        return jsonify({"error": "Prompt is required"}), 400
    try:
        full_prompt = f"{system_message}\nUser: {prompt}"
        with mlc_lock:
            mlc_chat_module.reset_chat()
            response = mlc_chat_module.generate(full_prompt)
        return jsonify({"completion": response}), 200
    except Exception as e:
        print(f"Error getting chat completion: {e}")
        return jsonify({"error": f"Failed to get chat completion: {str(e)}"}), 500

@api_bp.route('/generate_frame_guidance', methods=['POST'])
def generate_frame_guidance():
    # This endpoint provides LLM guidance for the frontend's quantum diffusion.
    # It does NOT generate the image itself.
    if mlc_chat_module is None or clip_processor is None or clip_model is None:
        return jsonify({"error": "One or more AI models not loaded. Check server logs for details."}), 500
    data = request.get_json()
    image_data_url = data.get('image')  # The current frame from the frontend
    prompt = data.get('prompt', 'Quantum interpolation')
    influence = data.get('influence', 5)  # 0-100
    entanglement_depth = data.get('depth', 16)  # For the LLM to consider
    frame_number = data.get('frame_number', 0)
    if not image_data_url:
        return jsonify({"error": "No image data provided"}), 400
    try:
        # 1. Get CLIP embeddings for the current frame
        header, encoded = image_data_url.split(",", 1)
        image_bytes = base64.b64decode(encoded)
        input_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        clip_inputs = clip_processor(images=input_image, return_tensors="pt")
        if torch.cuda.is_available():
            clip_inputs = {k: v.to("cuda") for k, v in clip_inputs.items()}
        with torch.no_grad():
            image_features = clip_model.get_image_features(**clip_inputs)
        image_embeddings_np = image_features.squeeze().cpu().numpy()
        embedding_snippet = ", ".join([f"{x:.4f}" for x in image_embeddings_np[:10]])
        # 2. Use the LLM to generate guidance for the next quantum diffusion step
        llm_prompt = (
            f"You are an AI video director for a quantum diffusion system. Your task is to guide the transformation "
            f"of a video frame based on quantum principles and user input. "
            f"Given the current frame's visual context (CLIP features: [{embedding_snippet}...]), "
            f"the user's creative prompt: '{prompt}', "
            f"and the quantum settings (Quantum Influence: {influence}%, Entanglement Depth: {entanglement_depth} layers), "
            f"describe *precisely* how the quantum diffusion effect should transform the current frame into frame {frame_number + 1}. "
            f"Think of these transformations as manipulating a quantum state that manifests visually. "
            f"Higher influence and depth should lead to more pronounced, chaotic, or surreal quantum effects. "
            f"Focus on quantifiable visual parameters, including: "
            f"color shifts (e.g., 'shift red by +{round(influence/5)}', 'hue rotate {round(influence*1.5)}deg'), "
            f"blur (e.g., 'apply gaussian blur radius {max(1, round(influence/10))}'), "
            f"glitch/distortion (e.g., 'pixel displacement x-axis random {max(5, round(influence/5))}px', 'chromatic aberration offset {max(1, round(influence/20))}'), "
            f"zoom/pan (e.g., 'zoom in {1.00 + influence/2000}x, pan right {round(influence/10)}px'), "
            f"pattern overlay (e.g., 'overlay subtle static pattern opacity {influence/200}'), "
            f"motion blur (e.g., 'apply motion blur strength {round(entanglement_depth/2)}'), "
            f"bloom (e.g., 'add bloom strength {influence/100}'), "
            f"noise (e.g., 'add noise amount {influence/50}'), "
            f"vignette (e.g., 'add vignette strength {influence/200}'), "
            f"or specific quantum-themed visual cues (e.g., 'ripple effect', 'add subtle scanlines opacity {influence/200}', 'invert colors'). "
            f"Combine these to create a dynamic, quantum-like visual evolution. Ensure the intensity of effects scales with Influence and Depth. "
            f"Be concise and output only the transformation instructions. "
            f"Example: 'shift blue by +{round(influence/5)}, apply motion blur strength {round(entanglement_depth/2)}, zoom {1.00 + influence/2000}x, add subtle scanlines opacity {influence/200}'.\n"
            f"Transformation Instructions for frame {frame_number + 1}:"
        )
        llm_guidance = ""
        try:
            with mlc_lock:
                mlc_chat_module.reset_chat()
                llm_guidance = mlc_chat_module.generate(llm_prompt)
        except Exception as llm_e:
            print(f"LLM guidance generation failed: {llm_e}. Using fallback guidance.")
            llm_guidance = f"apply subtle glitch effect, shift colors slightly based on quantum influence {influence}%."
        print(f"LLM Guidance: {llm_guidance}")
        return jsonify({
            "guidance": llm_guidance,
            "log": (f"Backend provided guidance for frame {frame_number + 1} based on prompt: '{prompt[:50]}...', "
                    f"influence: {influence}, depth: {entanglement_depth}. LLM guidance: '{llm_guidance[:50]}...'.")
        }), 200
    except Exception as e:
        print(f"Error generating frame guidance: {e}")
        return jsonify({"error": f"Failed to generate frame guidance: {str(e)}"}), 500

@api_bp.route('/upload', methods=['POST'])
def upload_file():
    try:
        # Check if it's a multipart form upload (FormData)
        if 'file' in request.files:
            file = request.files['file']
            if file.filename == '':
                return jsonify({"error": "No selected file"}), 400
            # Read the image file
            img_bytes = file.read()
            # Convert to base64 for frontend use
            img_base64 = base64.b64encode(img_bytes).decode('utf-8')
            # Determine the MIME type
            content_type = file.content_type or 'image/jpeg'
            img_data_url = f"data:{content_type};base64,{img_base64}"
            return jsonify({
                "message": "File uploaded successfully",
                "image_url": img_data_url
            }), 200
        # Check if it's JSON with base64 data
        elif request.is_json:
            data = request.get_json()
            image_data = data.get('image') or data.get('image_url') or data.get('image_data')
            if not image_data:
                return jsonify({"error": "No image data provided"}), 400
            # If already a data URL, return as-is
            if image_data.startswith('data:image'):
                return jsonify({
                    "message": "Image data received",
                    "image_url": image_data
                }), 200
            # If base64 without a header, add one
            img_data_url = f"data:image/jpeg;base64,{image_data}"
            return jsonify({
                "message": "Image data processed",
                "image_url": img_data_url
            }), 200
        else:
            return jsonify({"error": "Invalid request format. Send either FormData with 'file' or JSON with 'image' field"}), 400
    except Exception as e:
        print(f"Error uploading file: {e}")
        import traceback
        traceback.print_exc()
        return jsonify({"error": f"Failed to upload file: {str(e)}"}), 500