admin08077 committed on
Commit
aab002d
·
verified ·
1 Parent(s): 4060206

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +173 -0
app.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+ from flask import Flask, request, jsonify
4
+ from transformers import pipeline
5
+ from PIL import Image
6
+ import io
7
+ import base64
8
+
9
+ # Import for image generation
10
+ from diffusers import AutoPipelineForText2Image
11
+
12
# The one Flask application object; the @app.route decorators in this file
# register their handlers on it.
app = Flask(import_name=__name__)
13
+
14
# --- Configuration ---
GEMMA_MODEL_ID = "google/gemma-4-E2B-it"
IMAGE_GEN_MODEL_ID = "stabilityai/sd-turbo"  # A fast, small Stable Diffusion model for demonstration
MAX_NEW_TOKENS = 200  # Passed as max_new_tokens to the Gemma pipeline; caps response length
IMAGE_SIZE = (512, 512)  # Intended size for generated images (not referenced by the code visible here)

# Device selection.
# The GPU is chosen only when USE_GPU is "true" (case-insensitive) AND
# CUDA_VISIBLE_DEVICES is set to a non-empty value; every other environment
# (including a CPU-focused container) falls back to the CPU.
#   device            -> index handed to transformers.pipeline (0 = first GPU, -1 = CPU)
#   torch_device_name -> device string handed to the diffusers pipeline's .to()
device, torch_device_name = (
    (0, "cuda")
    if os.getenv("USE_GPU", "false").lower() == "true"
    and os.getenv("CUDA_VISIBLE_DEVICES", "")
    else (-1, "cpu")
)
28
+
29
# --- Model Loading ---
# Both pipelines are loaded once, at import time. A failed load is logged and
# leaves the corresponding global as None, so the server still starts and the
# endpoints can report the missing model instead of crashing.
gemma_pipeline = None
image_gen_pipeline = None

try:
    print(f"Loading Gemma 4 multimodal model: {GEMMA_MODEL_ID} on device {torch_device_name} (pipeline device {device})...")
    gemma_pipeline = pipeline("any-to-any", model=GEMMA_MODEL_ID, device=device)
    print("Gemma 4 model loaded successfully.")
except Exception as e:
    print(f"Error loading Gemma 4 model: {e}")

try:
    print(f"Loading Image Generation model: {IMAGE_GEN_MODEL_ID} on device {torch_device_name}...")
    image_gen_pipeline = AutoPipelineForText2Image.from_pretrained(IMAGE_GEN_MODEL_ID).to(torch_device_name)
    # Only enable xformers if on GPU
    if torch_device_name == "cuda":
        try:
            # Note: xformers might require a specific CUDA version or manual installation.
            image_gen_pipeline.enable_xformers_memory_efficient_attention()  # Optional: for memory efficiency on GPU
            print("xFormers enabled for image generation.")
        except Exception as xformers_error:
            # xFormers is a purely optional optimisation. Catch every failure
            # here (not just ImportError) so that a problem enabling it is not
            # reported by the outer handler as a model load failure while the
            # pipeline is in fact assigned and usable.
            print(
                "xFormers not installed or not available, skipping memory efficient attention. "
                f"({xformers_error})"
            )
    print("Image Generation model loaded successfully.")
except Exception as e:
    print(f"Error loading Image Generation model: {e}")
55
+
56
+ # --- Helper Functions ---
57
def encode_image_to_base64(image: Image.Image) -> str:
    """Return *image* saved as PNG and base64-encoded into a UTF-8 string."""
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    encoded = base64.b64encode(png_bytes)
    return encoded.decode('utf-8')
61
+
62
+ # --- API Endpoints ---
63
@app.route('/')
def home():
    """Return a plain-text banner that names the two model endpoints."""
    banner = "Multimodal AI (Gemma 4) and Image Generation API is running. Use /gemma-predict or /generate-image."
    return banner
66
+
67
@app.route('/gemma-predict', methods=['POST'])
def gemma_predict():
    """
    Endpoint for Gemma 4 multimodal text generation (image + text -> text).

    Expects a JSON object with the optional keys ``image_base64`` (a
    base64-encoded image) and ``text_prompt`` (a string). At least one of the
    two must be non-empty.

    Returns:
        200 with ``{"prediction": <generated text>}`` on success.
        400 when the body is missing, is not a JSON object, or an input is invalid.
        503 when the Gemma pipeline was not loaded at startup.
        500 when inference raises or yields no ``generated_text``.
    """
    if gemma_pipeline is None:
        return jsonify({"error": "Gemma 4 model not loaded. Please check server logs."}), 503

    # silent=True returns None for a missing, malformed or non-JSON body
    # instead of raising, so client mistakes get a 400 rather than being
    # swallowed by a generic handler and reported as a 500.
    data = request.get_json(silent=True)
    if not data:
        return jsonify({"error": "No JSON data provided"}), 400
    if not isinstance(data, dict):
        # A JSON array/string/number has no .get(); reject it explicitly.
        return jsonify({"error": "JSON body must be an object"}), 400

    image_base64 = data.get('image_base64')
    text_prompt = data.get('text_prompt', '')

    if not image_base64 and not text_prompt:
        return jsonify({"error": "At least 'image_base64' or 'text_prompt' must be provided"}), 400
    if text_prompt and not isinstance(text_prompt, str):
        return jsonify({"error": "'text_prompt' must be a string"}), 400

    # Content parts of the single user turn, image first and then text.
    content = []

    if image_base64:
        try:
            image_bytes = base64.b64decode(image_base64)
            image = Image.open(io.BytesIO(image_bytes))
            # Image.open only identifies the file; force a full decode now so
            # corrupt pixel data is rejected as a 400 here rather than
            # failing later inside the model call.
            image.load()
        except Exception as e:
            return jsonify({"error": f"Invalid image_base64 provided: {e}"}), 400
        content.append({
            "type": "image",
            "image": image,
        })

    if text_prompt:
        content.append({
            "type": "text",
            "text": text_prompt,
        })

    # The checks above guarantee at least one content part at this point.
    full_messages = [
        {
            "role": "user",
            "content": content,
        }
    ]

    try:
        output = gemma_pipeline(full_messages, max_new_tokens=MAX_NEW_TOKENS, return_full_text=False)

        if output and "generated_text" in output[0]:
            return jsonify({"prediction": output[0]["generated_text"]})
        return jsonify({"error": "Gemma 4 model did not return generated text."}), 500
    except Exception as e:
        # Top-level boundary for inference failures: log and answer 500.
        print(f"Error during Gemma 4 prediction: {e}")
        return jsonify({"error": f"An error occurred during Gemma 4 prediction: {str(e)}"}), 500
125
+
126
@app.route('/generate-image', methods=['POST'])
def generate_image():
    """
    Endpoint for text-to-image generation.

    Expects a JSON object with a non-empty string ``prompt``.

    Returns:
        200 with ``{"image_base64": <PNG as base64>, "prompt": <prompt>}`` on success.
        400 when the body is missing, is not a JSON object, or ``prompt`` is
        missing or not a string.
        503 when the image generation pipeline was not loaded at startup.
        500 when generation or encoding raises.
    """
    if image_gen_pipeline is None:
        return jsonify({"error": "Image generation model not loaded. Please check server logs."}), 503

    # silent=True returns None for a missing, malformed or non-JSON body
    # instead of raising, so client mistakes get a 400 rather than a 500.
    data = request.get_json(silent=True)
    if not data:
        return jsonify({"error": "No JSON data provided"}), 400
    if not isinstance(data, dict):
        # A JSON array/string/number has no .get(); reject it explicitly.
        return jsonify({"error": "JSON body must be an object"}), 400

    prompt = data.get('prompt')
    if not prompt:
        return jsonify({"error": "Missing 'prompt' for image generation."}), 400
    if not isinstance(prompt, str):
        return jsonify({"error": "'prompt' must be a string"}), 400

    try:
        # Generate image with the pipeline's default sampling parameters.
        # NOTE(review): more parameters (num_inference_steps, guidance_scale)
        # can be passed here; the best values depend on the configured model
        # -- confirm against its model card.
        generated_image = image_gen_pipeline(prompt).images[0]

        # Encode the generated image to base64
        image_base64 = encode_image_to_base64(generated_image)

        return jsonify({"image_base64": image_base64, "prompt": prompt})
    except Exception as e:
        # Top-level boundary for generation failures: log and answer 500.
        print(f"Error during image generation: {e}")
        return jsonify({"error": f"An error occurred during image generation: {str(e)}"}), 500
155
+
156
@app.route('/status', methods=['GET'])
def status():
    """
    Report each model's id and load state, plus the device name in use.
    """
    def _state(model) -> str:
        # A truthy pipeline object counts as loaded.
        if model:
            return "ready"
        return "not_loaded"

    report = {
        "gemma_4_model_id": GEMMA_MODEL_ID,
        "gemma_4_status": _state(gemma_pipeline),
        "image_gen_model_id": IMAGE_GEN_MODEL_ID,
        "image_gen_status": _state(image_gen_pipeline),
        "device_used": torch_device_name,
    }
    return jsonify(report)
170
+
171
+ # --- Main Execution ---
172
if __name__ == '__main__':
    # The interactive Werkzeug debugger lets anyone who can reach the server
    # execute arbitrary Python, so it must not be on by default while binding
    # to all interfaces. Debug mode also turns on the reloader, which
    # re-imports this module and would load both models a second time.
    # Opt in explicitly for local development with FLASK_DEBUG=1 (or "true").
    debug_enabled = os.environ.get("FLASK_DEBUG", "false").lower() in ("1", "true")
    app.run(host='0.0.0.0', port=5000, debug=debug_enabled)