Geraldine committed on
Commit
96a1472
·
verified ·
1 Parent(s): 67b3520

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -0
app.py CHANGED
@@ -23,6 +23,7 @@ from transformers import (
23
  Qwen3VLForConditionalGeneration,
24
  Qwen2_5_VLForConditionalGeneration,
25
  AutoModelForCausalLM,
 
26
  AutoProcessor,
27
  AutoModel,
28
  AutoTokenizer,
@@ -170,6 +171,14 @@ model_m = load_model_with_attention_fallback(
170
  MODEL_ID_A = "mistralai/Ministral-3-8B-Instruct-2512"
171
  ALBERT_API_URL = "https://albert.api.etalab.gouv.fr/v1/chat/completions"
172
 
 
 
 
 
 
 
 
 
173
  MODEL_MAP = {
174
  "Nanonets-OCR2-3B": (processor_v, model_v),
175
  "LightOnOCR-2-1B": (processor_y, model_y),
@@ -177,6 +186,7 @@ MODEL_MAP = {
177
  "Qwen3-VL-4B-Instruct": (processor_m, model_m),
178
  "Qwen2-VL-OCR-2B": (processor_x, model_x),
179
  "Ministral-3-8B-Instruct-2512": (None, MODEL_ID_A),
 
180
  }
181
 
182
  MODEL_CHOICES = list(MODEL_MAP.keys())
@@ -578,6 +588,37 @@ def generate_image(model_name, text, image, max_new_tokens, temperature, top_p,
578
  inputs.pop("token_type_ids", None)
579
  inputs = {k: v.to(model.device) if hasattr(v, "to") else v for k, v in inputs.items()}
580
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
581
  generation_kwargs = {
582
  **inputs,
583
  "streamer": streamer,
 
23
  Qwen3VLForConditionalGeneration,
24
  Qwen2_5_VLForConditionalGeneration,
25
  AutoModelForCausalLM,
26
+ AutoModelForMultimodalLM,
27
  AutoProcessor,
28
  AutoModel,
29
  AutoTokenizer,
 
171
  MODEL_ID_A = "mistralai/Ministral-3-8B-Instruct-2512"
172
  ALBERT_API_URL = "https://albert.api.etalab.gouv.fr/v1/chat/completions"
173
 
174
+ MODEL_ID_G = "google/gemma-4-E4B-it"
175
+ processor_g = AutoProcessor.from_pretrained(MODEL_ID_G)
176
+ model_g = load_model_with_attention_fallback(
177
+ AutoModelForMultimodalLM,
178
+ MODEL_ID_G,
179
+ torch_dtype="auto"
180
+ ).to(device).eval()
181
+
182
  MODEL_MAP = {
183
  "Nanonets-OCR2-3B": (processor_v, model_v),
184
  "LightOnOCR-2-1B": (processor_y, model_y),
 
186
  "Qwen3-VL-4B-Instruct": (processor_m, model_m),
187
  "Qwen2-VL-OCR-2B": (processor_x, model_x),
188
  "Ministral-3-8B-Instruct-2512": (None, MODEL_ID_A),
189
+ "gemma-4-E4B-it": (processor_g, model_g),
190
  }
191
 
192
  MODEL_CHOICES = list(MODEL_MAP.keys())
 
588
  inputs.pop("token_type_ids", None)
589
  inputs = {k: v.to(model.device) if hasattr(v, "to") else v for k, v in inputs.items()}
590
 
591
+ generation_kwargs = {
592
+ **inputs,
593
+ "streamer": streamer,
594
+ "max_new_tokens": int(max_new_tokens),
595
+ "do_sample": True,
596
+ "temperature": float(temperature),
597
+ "top_p": float(top_p),
598
+ "top_k": int(top_k),
599
+ "repetition_penalty": float(repetition_penalty),
600
+ }
601
+ elif model_name == "gemma-4-E4B-it":
602
+ messages = [
603
+ {
604
+ "role": "user",
605
+ "content": [
606
+ {"type": "image", "image": image},
607
+ {"type": "text", "text": text},
608
+ ],
609
+ }
610
+ ]
611
+
612
+ inputs = processor.apply_chat_template(
613
+ messages,
614
+ tokenize=True,
615
+ add_generation_prompt=True,
616
+ return_dict=True,
617
+ return_tensors="pt",
618
+ enable_thinking=False,
619
+ )
620
+ inputs = {k: v.to(model.device) if hasattr(v, "to") else v for k, v in inputs.items()}
621
+
622
  generation_kwargs = {
623
  **inputs,
624
  "streamer": streamer,