JasperHaozhe committed on
Commit
5ffc27e
·
verified ·
1 Parent(s): 8a89c69

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -12
app.py CHANGED
@@ -25,7 +25,7 @@ VLM_QUANTIZATION_4BIT = False # Load VLM in 4-bit to save memory
25
  VLM_QUANTIZATION_8BIT = False # Load VLM in 8-bit to save memory (mutually exclusive with 4-bit)
26
 
27
  MODEL_ID = "JasperHaozhe/RationalRewards-Both-Demo"
28
- FLUX_MODEL_ID = "AlekseyCalvin/Flux_Kontext_Dev_fp8_scaled_diffusers" # "black-forest-labs/FLUX.1-Kontext-dev"
29
 
30
  processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
31
 
@@ -52,10 +52,10 @@ model_kwargs = {
52
  }
53
 
54
  # If VLM_MAX_MEMORY is set or using quantization, use device_map
55
- if VLM_MAX_MEMORY or VLM_QUANTIZATION_4BIT or VLM_QUANTIZATION_8BIT:
56
- model_kwargs["device_map"] = "auto"
57
- if VLM_MAX_MEMORY:
58
- model_kwargs["max_memory"] = VLM_MAX_MEMORY
59
 
60
  model = AutoModelForImageTextToText.from_pretrained(
61
  MODEL_ID,
@@ -63,10 +63,10 @@ model = AutoModelForImageTextToText.from_pretrained(
63
  )
64
 
65
  # Only manually move to CPU/eval if NOT using device_map/quantization (which handles placement)
66
- if not (VLM_MAX_MEMORY or VLM_QUANTIZATION_4BIT or VLM_QUANTIZATION_8BIT):
67
- model.to("cpu").eval()
68
- else:
69
- model.eval()
70
 
71
  # Load Flux Pipeline
72
  flux_pipeline = FluxKontextPipeline.from_pretrained(
@@ -74,7 +74,7 @@ flux_pipeline = FluxKontextPipeline.from_pretrained(
74
  torch_dtype=torch.bfloat16
75
  )
76
  # Fix VAE precision for Flux to avoid artifacts
77
- # flux_pipeline.vae.to(dtype=torch.float32)
78
  # flux_pipeline.enable_attention_slicing() # Enable attention slicing to save memory during inference
79
  # Assume we can load both models simultaneously (User request)
80
  # No CPU offloading logic here.
@@ -474,6 +474,7 @@ def model_inference(task_type, instruction_text, image1, image2, image3, progres
474
  os.makedirs("generated_images", exist_ok=True)
475
  generated_image_path = f"generated_images/flux_edit_{timestamp}.png"
476
  generated_image.save(generated_image_path)
 
477
 
478
  except Exception as e:
479
  yield f"Error generating image: {str(e)}", None
@@ -502,8 +503,8 @@ def model_inference(task_type, instruction_text, image1, image2, image3, progres
502
  messages = [{"role": "user", "content": content}]
503
 
504
  # Ensure model is on CUDA/device for evaluation (VLM handles its own placement via device_map if set)
505
- if not (VLM_MAX_MEMORY or VLM_QUANTIZATION_4BIT or VLM_QUANTIZATION_8BIT):
506
- model.to(device_vlm)
507
 
508
  # Generate and stream text
509
  prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
 
25
  VLM_QUANTIZATION_8BIT = False # Load VLM in 8-bit to save memory (mutually exclusive with 4-bit)
26
 
27
  MODEL_ID = "JasperHaozhe/RationalRewards-Both-Demo"
28
+ FLUX_MODEL_ID = "yuvraj108c/FLUX.1-Kontext-dev" # "black-forest-labs/FLUX.1-Kontext-dev"
29
 
30
  processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
31
 
 
52
  }
53
 
54
  # If VLM_MAX_MEMORY is set or using quantization, use device_map
55
+ # if VLM_MAX_MEMORY or VLM_QUANTIZATION_4BIT or VLM_QUANTIZATION_8BIT:
56
+ # model_kwargs["device_map"] = "auto"
57
+ # if VLM_MAX_MEMORY:
58
+ # model_kwargs["max_memory"] = VLM_MAX_MEMORY
59
 
60
  model = AutoModelForImageTextToText.from_pretrained(
61
  MODEL_ID,
 
63
  )
64
 
65
  # Only manually move to CPU/eval if NOT using device_map/quantization (which handles placement)
66
+ # if not (VLM_MAX_MEMORY or VLM_QUANTIZATION_4BIT or VLM_QUANTIZATION_8BIT):
67
+ # model.to("cpu").eval()
68
+ # else:
69
+ model.eval()
70
 
71
  # Load Flux Pipeline
72
  flux_pipeline = FluxKontextPipeline.from_pretrained(
 
74
  torch_dtype=torch.bfloat16
75
  )
76
  # Fix VAE precision for Flux to avoid artifacts
77
+ flux_pipeline.vae.to(dtype=torch.float32)
78
  # flux_pipeline.enable_attention_slicing() # Enable attention slicing to save memory during inference
79
  # Assume we can load both models simultaneously (User request)
80
  # No CPU offloading logic here.
 
474
  os.makedirs("generated_images", exist_ok=True)
475
  generated_image_path = f"generated_images/flux_edit_{timestamp}.png"
476
  generated_image.save(generated_image_path)
477
+ print(f">>>> generated: {generated_image_path}")
478
 
479
  except Exception as e:
480
  yield f"Error generating image: {str(e)}", None
 
503
  messages = [{"role": "user", "content": content}]
504
 
505
  # Ensure model is on CUDA/device for evaluation (VLM handles its own placement via device_map if set)
506
+ # if not (VLM_MAX_MEMORY or VLM_QUANTIZATION_4BIT or VLM_QUANTIZATION_8BIT):
507
+ model.to(device_vlm)
508
 
509
  # Generate and stream text
510
  prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)