SVECTOR-OFFICIAL commited on
Commit
204c402
·
verified ·
1 Parent(s): 26e01cf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -13
app.py CHANGED
@@ -15,15 +15,30 @@ def load_model():
15
  if tok is None or model is None:
16
  print("Loading model...")
17
  tok = AutoTokenizer.from_pretrained(MID, trust_remote_code=True)
18
- model = AutoModelForCausalLM.from_pretrained(
19
- MID,
20
- dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
21
- device_map="auto" if torch.cuda.is_available() else None,
22
- trust_remote_code=True,
23
- )
24
- if not torch.cuda.is_available():
25
- model = model.to('cpu')
26
- print("Model loaded successfully!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  return tok, model
28
 
29
  @spaces.GPU(duration=60)
@@ -34,6 +49,7 @@ def caption_image(image, custom_prompt=None):
34
  try:
35
  # Load model if not already loaded
36
  tok, model = load_model()
 
37
  # Convert image to RGB if needed
38
  if image.mode != "RGB":
39
  image = image.convert("RGB")
@@ -58,16 +74,20 @@ def caption_image(image, custom_prompt=None):
58
  pre_ids = tok(pre, return_tensors="pt", add_special_tokens=False).input_ids
59
  post_ids = tok(post, return_tensors="pt", add_special_tokens=False).input_ids
60
 
 
 
 
 
61
  # Insert IMAGE token id at placeholder position
62
  img_tok = torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=pre_ids.dtype)
63
- input_ids = torch.cat([pre_ids, img_tok, post_ids], dim=1).to(model.device)
64
- attention_mask = torch.ones_like(input_ids, device=model.device)
65
 
66
  # Preprocess image using model's vision tower
67
  px = model.get_vision_tower().image_processor(
68
  images=image, return_tensors="pt"
69
  )["pixel_values"]
70
- px = px.to(model.device, dtype=model.dtype)
71
 
72
  # Generate caption
73
  with torch.no_grad():
@@ -92,7 +112,9 @@ def caption_image(image, custom_prompt=None):
92
  return response
93
 
94
  except Exception as e:
95
- return f"Error generating caption: {str(e)}"
 
 
96
 
97
  # Create Gradio interface
98
  with gr.Blocks(title="Fal-2 Image Captioning") as demo:
 
15
  if tok is None or model is None:
16
  print("Loading model...")
17
  tok = AutoTokenizer.from_pretrained(MID, trust_remote_code=True)
18
+
19
+ # Determine device and dtype
20
+ device = "cuda" if torch.cuda.is_available() else "cpu"
21
+ dtype = torch.float16 if torch.cuda.is_available() else torch.float32
22
+
23
+ # Load model without device_map for CPU, or with proper device_map for CUDA
24
+ if torch.cuda.is_available():
25
+ model = AutoModelForCausalLM.from_pretrained(
26
+ MID,
27
+ torch_dtype=dtype,
28
+ device_map="auto",
29
+ trust_remote_code=True,
30
+ )
31
+ else:
32
+ # For CPU: load directly to CPU without device_map
33
+ model = AutoModelForCausalLM.from_pretrained(
34
+ MID,
35
+ torch_dtype=dtype,
36
+ trust_remote_code=True,
37
+ )
38
+ model = model.to(device)
39
+
40
+ model.eval() # Set to evaluation mode
41
+ print(f"Model loaded successfully on {device}!")
42
  return tok, model
43
 
44
  @spaces.GPU(duration=60)
 
49
  try:
50
  # Load model if not already loaded
51
  tok, model = load_model()
52
+
53
  # Convert image to RGB if needed
54
  if image.mode != "RGB":
55
  image = image.convert("RGB")
 
74
  pre_ids = tok(pre, return_tensors="pt", add_special_tokens=False).input_ids
75
  post_ids = tok(post, return_tensors="pt", add_special_tokens=False).input_ids
76
 
77
+ # Get model device and dtype
78
+ device = next(model.parameters()).device
79
+ dtype = next(model.parameters()).dtype
80
+
81
  # Insert IMAGE token id at placeholder position
82
  img_tok = torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=pre_ids.dtype)
83
+ input_ids = torch.cat([pre_ids, img_tok, post_ids], dim=1).to(device)
84
+ attention_mask = torch.ones_like(input_ids, device=device)
85
 
86
  # Preprocess image using model's vision tower
87
  px = model.get_vision_tower().image_processor(
88
  images=image, return_tensors="pt"
89
  )["pixel_values"]
90
+ px = px.to(device, dtype=dtype)
91
 
92
  # Generate caption
93
  with torch.no_grad():
 
112
  return response
113
 
114
  except Exception as e:
115
+ import traceback
116
+ error_detail = traceback.format_exc()
117
+ return f"Error generating caption: {str(e)}\n\nDetails:\n{error_detail}"
118
 
119
  # Create Gradio interface
120
  with gr.Blocks(title="Fal-2 Image Captioning") as demo: