import os
import io
import base64

from flask import Flask, request, jsonify
from transformers import pipeline
from PIL import Image

# Import for image generation
from diffusers import AutoPipelineForText2Image

app = Flask(__name__)

# --- Configuration ---
GEMMA_MODEL_ID = "google/gemma-4-E2B-it"
IMAGE_GEN_MODEL_ID = "stabilityai/sd-turbo"  # A fast, small Stable Diffusion model for demonstration
MAX_NEW_TOKENS = 200  # Adjust as needed for Gemma 4 response length
IMAGE_SIZE = (512, 512)  # Width/height for generated images

# Determine the device for both models.
# With a CPU-focused Dockerfile this defaults to CPU (-1 for the
# transformers pipeline, "cpu" for torch/diffusers).
if os.environ.get("USE_GPU", "false").lower() == "true" and os.getenv("CUDA_VISIBLE_DEVICES", "") != "":
    device = 0  # Use the first GPU
    torch_device_name = "cuda"
else:
    device = -1  # Use CPU
    torch_device_name = "cpu"

# --- Model Loading ---
gemma_pipeline = None
image_gen_pipeline = None

try:
    print(f"Loading Gemma 4 multimodal model: {GEMMA_MODEL_ID} on device {torch_device_name} (pipeline device {device})...")
    # "image-text-to-text" is the transformers pipeline task for
    # image + text -> text generation. ("any-to-any" is a Hub model tag,
    # not a pipeline() task, and would fail here.)
    gemma_pipeline = pipeline("image-text-to-text", model=GEMMA_MODEL_ID, device=device)
    print("Gemma 4 model loaded successfully.")
except Exception as e:
    print(f"Error loading Gemma 4 model: {e}")

try:
    print(f"Loading Image Generation model: {IMAGE_GEN_MODEL_ID} on device {torch_device_name}...")
    image_gen_pipeline = AutoPipelineForText2Image.from_pretrained(IMAGE_GEN_MODEL_ID).to(torch_device_name)
    # Only enable xFormers on GPU.
    if torch_device_name == "cuda":
        try:
            # Note: xFormers may require a specific CUDA version or manual
            # installation. If this line causes issues, comment it out.
            image_gen_pipeline.enable_xformers_memory_efficient_attention()  # Optional: reduces GPU memory use
            print("xFormers enabled for image generation.")
        except ImportError:
            print("xFormers not installed or not available, skipping memory efficient attention.")
    print("Image Generation model loaded successfully.")
except Exception as e:
    print(f"Error loading Image Generation model: {e}")

# --- Helper Functions ---
def encode_image_to_base64(image: Image.Image) -> str:
    """Serialize a PIL image to a base64-encoded PNG string."""
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")

# --- API Endpoints ---
@app.route('/')
def home():
    return "Multimodal AI (Gemma 4) and Image Generation API is running. Use /gemma-predict, /generate-image, or /status."

@app.route('/gemma-predict', methods=['POST'])
def gemma_predict():
    """Endpoint for Gemma 4 multimodal text generation (image + text -> text)."""
    if gemma_pipeline is None:
        return jsonify({"error": "Gemma 4 model not loaded. Please check server logs."}), 503
    try:
        data = request.get_json(silent=True)
        if not data:
            return jsonify({"error": "No JSON data provided"}), 400

        image_base64 = data.get('image_base64')
        text_prompt = data.get('text_prompt', '')

        if not image_base64 and not text_prompt:
            return jsonify({"error": "At least 'image_base64' or 'text_prompt' must be provided"}), 400

        # Build the chat-style content list the multimodal pipeline expects.
        content = []
        if image_base64:
            try:
                image_bytes = base64.b64decode(image_base64)
                image = Image.open(io.BytesIO(image_bytes))
                content.append({"type": "image", "image": image})
            except Exception as e:
                return jsonify({"error": f"Invalid image_base64 provided: {e}"}), 400

        if text_prompt:
            content.append({"type": "text", "text": text_prompt})

        if not content:
            return jsonify({"error": "No valid input (image or text) provided for Gemma."}), 400

        full_messages = [
            {
                "role": "user",
                "content": content,
            }
        ]

        output = gemma_pipeline(text=full_messages, max_new_tokens=MAX_NEW_TOKENS, return_full_text=False)

        if output and len(output) > 0 and "generated_text" in output[0]:
            return jsonify({"prediction": output[0]["generated_text"]})
        else:
            return jsonify({"error": "Gemma 4 model did not return generated text."}), 500

    except Exception as e:
        print(f"Error during Gemma 4 prediction: {e}")
        return jsonify({"error": f"An error occurred during Gemma 4 prediction: {str(e)}"}), 500

@app.route('/generate-image', methods=['POST'])
def generate_image():
    """Endpoint for text-to-image generation."""
    if image_gen_pipeline is None:
        return jsonify({"error": "Image generation model not loaded. Please check server logs."}), 503

    try:
        data = request.get_json(silent=True)
        if not data:
            return jsonify({"error": "No JSON data provided"}), 400

        prompt = data.get('prompt')
        if not prompt:
            return jsonify({"error": "Missing 'prompt' for image generation."}), 400

        # Generate the image. sd-turbo is distilled for very few denoising
        # steps, so one step with guidance disabled is its intended usage;
        # raise num_inference_steps / guidance_scale for other models.
        generated_image = image_gen_pipeline(
            prompt,
            num_inference_steps=1,
            guidance_scale=0.0,
            width=IMAGE_SIZE[0],
            height=IMAGE_SIZE[1],
        ).images[0]

        # Encode the generated image to base64.
        image_base64 = encode_image_to_base64(generated_image)

        return jsonify({"image_base64": image_base64, "prompt": prompt})

    except Exception as e:
        print(f"Error during image generation: {e}")
        return jsonify({"error": f"An error occurred during image generation: {str(e)}"}), 500

@app.route('/status', methods=['GET'])
def status():
    """Reports the load status of both AI models."""
    gemma_status = "ready" if gemma_pipeline else "not_loaded"
    image_gen_status = "ready" if image_gen_pipeline else "not_loaded"
    return jsonify({
        "gemma_4_model_id": GEMMA_MODEL_ID,
        "gemma_4_status": gemma_status,
        "image_gen_model_id": IMAGE_GEN_MODEL_ID,
        "image_gen_status": image_gen_status,
        "device_used": torch_device_name,
    })

# --- Main Execution ---
if __name__ == '__main__':
    # debug=True is convenient for local development; disable it and use a
    # production WSGI server (e.g. gunicorn) when deploying.
    app.run(host='0.0.0.0', port=5000, debug=True)
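
# --- Example Client Usage (illustrative sketch, not part of the server) ---
# A minimal client for the two POST endpoints, assuming the server above is
# running on localhost:5000. The file names ("photo.png", "generated.png")
# are placeholders. Run this from a separate script, since app.run() blocks:
#
#   import base64
#   import requests
#
#   # Multimodal prediction: image + text prompt -> generated text.
#   with open("photo.png", "rb") as f:
#       img_b64 = base64.b64encode(f.read()).decode("utf-8")
#   resp = requests.post(
#       "http://localhost:5000/gemma-predict",
#       json={"image_base64": img_b64, "text_prompt": "Describe this image."},
#   )
#   print(resp.json().get("prediction"))
#
#   # Text-to-image: prompt -> base64 PNG, decoded and saved to disk.
#   resp = requests.post(
#       "http://localhost:5000/generate-image",
#       json={"prompt": "a watercolor painting of a lighthouse"},
#   )
#   with open("generated.png", "wb") as f:
#       f.write(base64.b64decode(resp.json()["image_base64"]))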