Spaces:
Running on Zero
Running on Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -23,6 +23,7 @@ from transformers import (
|
|
| 23 |
Qwen3VLForConditionalGeneration,
|
| 24 |
Qwen2_5_VLForConditionalGeneration,
|
| 25 |
AutoModelForCausalLM,
|
|
|
|
| 26 |
AutoProcessor,
|
| 27 |
AutoModel,
|
| 28 |
AutoTokenizer,
|
|
@@ -170,6 +171,14 @@ model_m = load_model_with_attention_fallback(
|
|
| 170 |
MODEL_ID_A = "mistralai/Ministral-3-8B-Instruct-2512"
|
| 171 |
ALBERT_API_URL = "https://albert.api.etalab.gouv.fr/v1/chat/completions"
|
| 172 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
MODEL_MAP = {
|
| 174 |
"Nanonets-OCR2-3B": (processor_v, model_v),
|
| 175 |
"LightOnOCR-2-1B": (processor_y, model_y),
|
|
@@ -177,6 +186,7 @@ MODEL_MAP = {
|
|
| 177 |
"Qwen3-VL-4B-Instruct": (processor_m, model_m),
|
| 178 |
"Qwen2-VL-OCR-2B": (processor_x, model_x),
|
| 179 |
"Ministral-3-8B-Instruct-2512": (None, MODEL_ID_A),
|
|
|
|
| 180 |
}
|
| 181 |
|
| 182 |
MODEL_CHOICES = list(MODEL_MAP.keys())
|
|
@@ -578,6 +588,37 @@ def generate_image(model_name, text, image, max_new_tokens, temperature, top_p,
|
|
| 578 |
inputs.pop("token_type_ids", None)
|
| 579 |
inputs = {k: v.to(model.device) if hasattr(v, "to") else v for k, v in inputs.items()}
|
| 580 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 581 |
generation_kwargs = {
|
| 582 |
**inputs,
|
| 583 |
"streamer": streamer,
|
|
|
|
| 23 |
Qwen3VLForConditionalGeneration,
|
| 24 |
Qwen2_5_VLForConditionalGeneration,
|
| 25 |
AutoModelForCausalLM,
|
| 26 |
+
AutoModelForMultimodalLM,
|
| 27 |
AutoProcessor,
|
| 28 |
AutoModel,
|
| 29 |
AutoTokenizer,
|
|
|
|
| 171 |
MODEL_ID_A = "mistralai/Ministral-3-8B-Instruct-2512"
|
| 172 |
ALBERT_API_URL = "https://albert.api.etalab.gouv.fr/v1/chat/completions"
|
| 173 |
|
| 174 |
+
MODEL_ID_G = "google/gemma-4-E4B-it"
|
| 175 |
+
processor_g = AutoProcessor.from_pretrained(MODEL_ID_G)
|
| 176 |
+
model_g = load_model_with_attention_fallback(
|
| 177 |
+
AutoModelForMultimodalLM,
|
| 178 |
+
MODEL_ID_G,
|
| 179 |
+
torch_dtype="auto"
|
| 180 |
+
).to(device).eval()
|
| 181 |
+
|
| 182 |
MODEL_MAP = {
|
| 183 |
"Nanonets-OCR2-3B": (processor_v, model_v),
|
| 184 |
"LightOnOCR-2-1B": (processor_y, model_y),
|
|
|
|
| 186 |
"Qwen3-VL-4B-Instruct": (processor_m, model_m),
|
| 187 |
"Qwen2-VL-OCR-2B": (processor_x, model_x),
|
| 188 |
"Ministral-3-8B-Instruct-2512": (None, MODEL_ID_A),
|
| 189 |
+
"gemma-4-E4B-it": (processor_g, model_g),
|
| 190 |
}
|
| 191 |
|
| 192 |
MODEL_CHOICES = list(MODEL_MAP.keys())
|
|
|
|
| 588 |
inputs.pop("token_type_ids", None)
|
| 589 |
inputs = {k: v.to(model.device) if hasattr(v, "to") else v for k, v in inputs.items()}
|
| 590 |
|
| 591 |
+
generation_kwargs = {
|
| 592 |
+
**inputs,
|
| 593 |
+
"streamer": streamer,
|
| 594 |
+
"max_new_tokens": int(max_new_tokens),
|
| 595 |
+
"do_sample": True,
|
| 596 |
+
"temperature": float(temperature),
|
| 597 |
+
"top_p": float(top_p),
|
| 598 |
+
"top_k": int(top_k),
|
| 599 |
+
"repetition_penalty": float(repetition_penalty),
|
| 600 |
+
}
|
| 601 |
+
elif model_name == "gemma-4-E4B-it":
|
| 602 |
+
messages = [
|
| 603 |
+
{
|
| 604 |
+
"role": "user",
|
| 605 |
+
"content": [
|
| 606 |
+
{"type": "image", "image": image},
|
| 607 |
+
{"type": "text", "text": text},
|
| 608 |
+
],
|
| 609 |
+
}
|
| 610 |
+
]
|
| 611 |
+
|
| 612 |
+
inputs = processor.apply_chat_template(
|
| 613 |
+
messages,
|
| 614 |
+
tokenize=True,
|
| 615 |
+
add_generation_prompt=True,
|
| 616 |
+
return_dict=True,
|
| 617 |
+
return_tensors="pt",
|
| 618 |
+
enable_thinking=False,
|
| 619 |
+
)
|
| 620 |
+
inputs = {k: v.to(model.device) if hasattr(v, "to") else v for k, v in inputs.items()}
|
| 621 |
+
|
| 622 |
generation_kwargs = {
|
| 623 |
**inputs,
|
| 624 |
"streamer": streamer,
|