andrejrad commited on
Commit
09f8ba4
·
verified ·
1 Parent(s): a2b2314

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -4
app.py CHANGED
@@ -2,22 +2,38 @@ import os, json, re
2
  import gradio as gr
3
  from PIL import Image
4
  import torch
5
- from transformers import AutoProcessor, AutoModelForCausalLM
6
 
7
  MODEL_ID = os.environ.get("MODEL_ID", "GrassData/cliptagger-12b")
 
8
  HF_TOKEN = os.environ.get("HF_TOKEN")
9
 
10
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
11
  DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
12
 
13
- # Load processor & model
14
- processor = AutoProcessor.from_pretrained(MODEL_ID, token=HF_TOKEN, trust_remote_code=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  model = AutoModelForCausalLM.from_pretrained(
16
  MODEL_ID,
17
  token=HF_TOKEN,
18
  torch_dtype=DTYPE,
19
  device_map="auto",
20
- trust_remote_code=True
21
  )
22
 
23
  # Prompts (system + user, as given)
 
2
  import gradio as gr
3
  from PIL import Image
4
  import torch
5
+ from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM
6
 
7
  MODEL_ID = os.environ.get("MODEL_ID", "GrassData/cliptagger-12b")
8
+ BASE_PROCESSOR_ID = os.environ.get("BASE_PROCESSOR_ID", "google/gemma-3-12b-it")
9
  HF_TOKEN = os.environ.get("HF_TOKEN")
10
 
11
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
12
  DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
13
 
14
+ # ---- Load processor (from base) + model (from your FT) ----
15
+ try:
16
+ # Processor comes from base VLM repo (has preprocessor_config.json)
17
+ processor = AutoProcessor.from_pretrained(
18
+ BASE_PROCESSOR_ID, token=HF_TOKEN, trust_remote_code=True
19
+ )
20
+ except Exception as e:
21
+ raise RuntimeError(f"Failed to load processor from {BASE_PROCESSOR_ID}: {e}")
22
+
23
+ # Optional: get a fast tokenizer if processor doesn't expose one
24
+ tokenizer = getattr(processor, "tokenizer", None)
25
+ if tokenizer is None:
26
+ tokenizer = AutoTokenizer.from_pretrained(
27
+ BASE_PROCESSOR_ID, token=HF_TOKEN, trust_remote_code=True, use_fast=True
28
+ )
29
+
30
+ # Your fine-tuned weights
31
  model = AutoModelForCausalLM.from_pretrained(
32
  MODEL_ID,
33
  token=HF_TOKEN,
34
  torch_dtype=DTYPE,
35
  device_map="auto",
36
+ trust_remote_code=True,
37
  )
38
 
39
  # Prompts (system + user, as given)