wolfofbackstreet committed on
Commit
37d69ab
·
verified ·
1 Parent(s): 3fe1cbc

Upload 2 files

Browse files
Files changed (2) hide show
  1. Dockerfile +20 -6
  2. app.py +50 -13
Dockerfile CHANGED
@@ -18,15 +18,29 @@ RUN pip install uv
18
  COPY requirements.txt .
19
  RUN uv pip install --system -r requirements.txt
20
 
21
- # Copy application code
22
- COPY app.py .
 
23
 
24
- # Download pre-converted OpenVINO model
 
 
 
 
 
25
  RUN python -c "from optimum.intel.openvino import OVStableDiffusionPipeline; \
26
- OVStableDiffusionPipeline.from_pretrained('rupeshs/hyper-sd-sdxl-1-step-openvino-int8', ov_config={'CACHE_DIR': ''})"
 
 
 
 
 
 
 
27
 
28
- # Expose port 5000
29
- EXPOSE 7860
 
30
 
31
  # Command to run the Flask app
32
  CMD ["python", "app.py"]
 
18
  COPY requirements.txt .
19
  RUN uv pip install --system -r requirements.txt
20
 
21
+ # Create cache directories with write permissions
22
+ RUN mkdir -p /app/cache/huggingface /app/cache/openvino /app/matplotlib_cache /app/openvino_cache \
23
+ && chmod -R 777 /app/cache /app/matplotlib_cache /app/openvino_cache
24
 
25
+ # Set environment variables for cache directories
26
+ ENV HF_HOME=/app/cache/huggingface
27
+ ENV MPLCONFIGDIR=/app/matplotlib_cache
28
+ ENV OPENVINO_TELEMETRY_DIR=/app/openvino_cache
29
+
30
+ # Pre-download base SDXL model
31
  RUN python -c "from optimum.intel.openvino import OVStableDiffusionPipeline; \
32
+ OVStableDiffusionPipeline.from_pretrained('rupeshs/hyper-sd-sdxl-1-step-openvino-int8', ov_config={'CACHE_DIR': '/app/cache/openvino'})"
33
+
34
+ # Pre-download a default LoRA model
35
+ RUN python -c "from diffusers import LoraLoaderMixin; \
36
+ LoraLoaderMixin.download_lora_weights('latent-consistency/lcm-lora-sdxl', cache_dir='/app/cache/huggingface')"
37
+
38
+ # Copy application code
39
+ COPY app.py .
40
 
41
+ # Expose port (default 7860, configurable via PORT env variable)
42
+ ENV PORT=7860
43
+ EXPOSE $PORT
44
 
45
  # Command to run the Flask app
46
  CMD ["python", "app.py"]
app.py CHANGED
@@ -1,36 +1,71 @@
1
  import os
2
  from flask import Flask, request, jsonify, send_file
3
  from optimum.intel.openvino.modeling_diffusion import OVStableDiffusionPipeline
 
4
  from PIL import Image
5
  import io
6
  import torch
 
 
 
 
 
7
 
8
  app = Flask(__name__)
9
 
10
- # Load the pre-converted OpenVINO SDXL model
11
- model_id = "rupeshs/hyper-sd-sdxl-1-step-openvino-int8"
12
- pipeline = OVStableDiffusionPipeline.from_pretrained(
13
- model_id,
14
- ov_config={"CACHE_DIR": ""},
15
- device="CPU"
16
- )
 
17
 
18
- # Ensure Tiny Auto Encoder is enabled to reduce memory usage
19
- pipeline.enable_tiny_auto_encoder()
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  @app.route('/generate', methods=['POST'])
22
  def generate_image():
23
  try:
24
- # Get prompt from request
25
  data = request.get_json()
26
  prompt = data.get('prompt', 'A futuristic cityscape at sunset, cyberpunk style, 8k')
27
  width = data.get('width', 512)
28
  height = data.get('height', 512)
29
- num_inference_steps = data.get('num_inference_steps', 1)
30
  guidance_scale = data.get('guidance_scale', 1.0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  # Generate image
33
- image = pipeline(
34
  prompt=prompt,
35
  width=width,
36
  height=height,
@@ -50,7 +85,9 @@ def generate_image():
50
  download_name='generated_image.png'
51
  )
52
  except Exception as e:
 
53
  return jsonify({'error': str(e)}), 500
54
 
55
  if __name__ == '__main__':
56
- app.run(host='0.0.0.0', port=7860)
 
 
1
  import os
2
  from flask import Flask, request, jsonify, send_file
3
  from optimum.intel.openvino.modeling_diffusion import OVStableDiffusionPipeline
4
+ from diffusers import LoraLoaderMixin
5
  from PIL import Image
6
  import io
7
  import torch
8
+ import logging
9
+
10
+ # Set up logging
11
+ logging.basicConfig(level=logging.INFO)
12
+ logger = logging.getLogger(__name__)
13
 
14
  app = Flask(__name__)
15
 
16
+ # Set cache directories
17
+ os.environ["HF_HOME"] = "/app/cache/huggingface"
18
+ os.environ["MPLCONFIGDIR"] = "/app/matplotlib_cache"
19
+ os.environ["OPENVINO_TELEMETRY_DIR"] = "/app/openvino_cache"
20
+
21
+ # Ensure cache directories exist
22
+ for cache_dir in ["/app/cache/huggingface", "/app/matplotlib_cache", "/app/openvino_cache"]:
23
+ os.makedirs(cache_dir, exist_ok=True)
24
 
25
+ # Load the base pre-converted OpenVINO SDXL model
26
+ base_model_id = "rupeshs/hyper-sd-sdxl-1-step-openvino-int8"
27
+ try:
28
+ pipeline = OVStableDiffusionPipeline.from_pretrained(
29
+ base_model_id,
30
+ ov_config={"CACHE_DIR": "/app/cache/openvino"},
31
+ device="CPU"
32
+ )
33
+ pipeline.enable_tiny_auto_encoder()
34
+ logger.info("Base model loaded successfully")
35
+ except Exception as e:
36
+ logger.error(f"Failed to load base model: {str(e)}")
37
+ raise
38
 
39
  @app.route('/generate', methods=['POST'])
40
  def generate_image():
41
  try:
42
+ # Get parameters from request
43
  data = request.get_json()
44
  prompt = data.get('prompt', 'A futuristic cityscape at sunset, cyberpunk style, 8k')
45
  width = data.get('width', 512)
46
  height = data.get('height', 512)
47
+ num_inference_steps = data.get('num_inference_steps', 4)
48
  guidance_scale = data.get('guidance_scale', 1.0)
49
+ lora_model_id = data.get('lora_model_id', None)
50
+ lora_weight = data.get('lora_weight', 0.8)
51
+
52
+ # Load LoRA weights if specified
53
+ local_pipeline = pipeline
54
+ if lora_model_id:
55
+ try:
56
+ local_pipeline = LoraLoaderMixin.load_lora_weights(
57
+ local_pipeline,
58
+ lora_model_id,
59
+ lora_scale=lora_weight,
60
+ cache_dir="/app/cache/huggingface"
61
+ )
62
+ logger.info(f"LoRA model {lora_model_id} loaded successfully")
63
+ except Exception as e:
64
+ logger.error(f"Failed to load LoRA model: {str(e)}")
65
+ return jsonify({'error': f"Failed to load LoRA model: {str(e)}"}), 400
66
 
67
  # Generate image
68
+ image = local_pipeline(
69
  prompt=prompt,
70
  width=width,
71
  height=height,
 
85
  download_name='generated_image.png'
86
  )
87
  except Exception as e:
88
+ logger.error(f"Image generation failed: {str(e)}")
89
  return jsonify({'error': str(e)}), 500
90
 
91
  if __name__ == '__main__':
92
+ port = int(os.getenv('PORT', 7860))
93
+ app.run(host='0.0.0.0', port=port)