tuan2308 committed on
Commit
e26a2a9
·
verified ·
1 Parent(s): 66686bd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -3
app.py CHANGED
@@ -1,10 +1,10 @@
1
  import os
2
-
3
- # Bỏ qua yêu cầu torchvision (ZeroGPU không cài sẵn).
4
- os.environ.setdefault("TRANSFORMERS_NO_TORCHVISION", "1")
5
 
6
  import spaces
7
  import torch
 
8
  from diffusers import QwenImageEditPlusPipeline
9
  from PIL import Image
10
  import gradio as gr
@@ -24,7 +24,32 @@ FORCE_CPU = bool(int(os.getenv("FORCE_CPU", "1")))
24
  PREFER_GPU = bool(int(os.getenv("PREFER_GPU", "0")))
25
 
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  def _build_pipeline(device: str, dtype: torch.dtype):
 
 
28
  pipe = QwenImageEditPlusPipeline.from_pretrained(
29
  HF_BASE_MODEL,
30
  torch_dtype=dtype,
@@ -47,6 +72,35 @@ def _build_pipeline(device: str, dtype: torch.dtype):
47
  return pipe
48
 
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  @spaces.GPU # bắt buộc cho ZeroGPU
51
  def load_pipeline():
52
  global EXEC_DEVICE
@@ -59,6 +113,7 @@ def load_pipeline():
59
  try:
60
  pipe = _build_pipeline(device, dtype)
61
  EXEC_DEVICE = device
 
62
  return pipe
63
  except Exception as exc:
64
  # GPU worker thường abort vì OOM. Fallback CPU để không crash app.
 
1
  import os
2
+ import subprocess
3
+ import sys
 
4
 
5
  import spaces
6
  import torch
7
+ from optimization import optimize_pipeline_
8
  from diffusers import QwenImageEditPlusPipeline
9
  from PIL import Image
10
  import gradio as gr
 
24
  PREFER_GPU = bool(int(os.getenv("PREFER_GPU", "0")))
25
 
26
 
27
def ensure_torchvision():
    """
    Ensure torchvision is importable for Qwen2VLProcessor.

    Tries to import torchvision; if it is missing, installs the version
    matching the installed torch (version pinned from ``torch.__version__``
    with any local build suffix such as ``+cu121`` stripped).

    Raises:
        ImportError: if torchvision cannot be installed or imported.
    """
    try:
        import torchvision  # noqa: F401
        return
    except ImportError:
        # Match torchvision to the installed torch; drop the "+cuXXX" suffix.
        torch_version = torch.__version__.split("+")[0]
        try:
            subprocess.check_call(
                [sys.executable, "-m", "pip", "install", f"torchvision=={torch_version}"],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            )
            # A package installed while the interpreter is already running is
            # not guaranteed to be visible until the import caches are
            # refreshed (see importlib docs).
            import importlib
            importlib.invalidate_caches()
            import torchvision  # noqa: F401
        except Exception as exc:  # pragma: no cover - only runs on Spaces infra
            raise ImportError(
                "Torchvision is required for Qwen2VLProcessor. "
                "Please add a matching torchvision to requirements (e.g. pip install torchvision==torch_version)."
            ) from exc
48
+
49
+
50
  def _build_pipeline(device: str, dtype: torch.dtype):
51
+ ensure_torchvision()
52
+
53
  pipe = QwenImageEditPlusPipeline.from_pretrained(
54
  HF_BASE_MODEL,
55
  torch_dtype=dtype,
 
72
  return pipe
73
 
74
 
75
def maybe_optimize_pipeline(pipe):
    """
    Apply AOTI optimization to the pipeline when running on CUDA.

    A tiny 256x256 dummy input keeps the warm-up pass cheap on VRAM.
    If optimization fails (typically out of memory), the original,
    unoptimized pipeline is returned so the app keeps working.
    """
    # Only worth attempting on the GPU path.
    if EXEC_DEVICE != "cuda":
        return pipe

    try:
        placeholder = Image.new("RGB", (256, 256))
        rng = torch.Generator(device="cuda").manual_seed(0)
        warmup_kwargs = dict(
            image=[placeholder, placeholder],
            prompt="warmup",
            negative_prompt=" ",
            num_inference_steps=1,
            true_cfg_scale=1.0,
            guidance_scale=1.0,
            num_images_per_prompt=1,
            generator=rng,
            width=256,
            height=256,
        )
        optimize_pipeline_(pipe, **warmup_kwargs)
    except Exception:
        # Deliberate best-effort: on failure keep the original pipeline
        # rather than crashing the app.
        pass
    return pipe
102
+
103
+
104
  @spaces.GPU # bắt buộc cho ZeroGPU
105
  def load_pipeline():
106
  global EXEC_DEVICE
 
113
  try:
114
  pipe = _build_pipeline(device, dtype)
115
  EXEC_DEVICE = device
116
+ pipe = maybe_optimize_pipeline(pipe)
117
  return pipe
118
  except Exception as exc:
119
  # GPU worker thường abort vì OOM. Fallback CPU để không crash app.