# Branis333's picture
# Update app.py
# 96f09b6 verified
import torch
import gradio as gr
from PIL import Image
import numpy as np
import janus # noqa: F401
from janus.models import VLChatProcessor
from transformers import AutoModelForCausalLM
from peft import PeftModel
from transformers.modeling_utils import PreTrainedModel
# Base Janus checkpoint and the LoRA adapter repo trained on top of it.
BASE_MODEL = "deepseek-ai/Janus-Pro-1B"
ADAPTER_REPO = "Branis333/Janus_grade_final"
# Prefer GPU + fp16 when CUDA is available; otherwise CPU + fp32.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
# Janus-specific processor: chat templating + image preprocessing + tokenizer.
processor = VLChatProcessor.from_pretrained(BASE_MODEL)
# Workaround: transformers' PreTrainedModel._adjust_tied_keys_with_tied_pointers
# reads `all_tied_weights_keys` as a mapping, but some remote-code models leave
# it unset or as a plain list, which breaks weight loading.  Wrap the method so
# the attribute is normalized to a dict before delegating to the original.
# The private method is version-specific, so look it up defensively instead of
# attribute access (which would raise AttributeError at import time on
# transformers versions that do not define it).
_original_adjust_tied = getattr(
    PreTrainedModel, "_adjust_tied_keys_with_tied_pointers", None
)

if _original_adjust_tied is not None:
    def _safe_adjust_tied_keys_with_tied_pointers(self, tied_weights_keys_by_pointers):
        """Normalize ``all_tied_weights_keys`` to a dict, then call the original method."""
        existing = getattr(self, "all_tied_weights_keys", None)
        if existing is None:
            # Attribute missing entirely: rebuild from the class-level tied-keys list.
            tied = getattr(self, "_tied_weights_keys", []) or []
            self.all_tied_weights_keys = {k: True for k in tied}
        elif not hasattr(existing, "keys"):
            # Present but not a mapping (e.g. a list of key names): coerce to dict.
            self.all_tied_weights_keys = {k: True for k in existing}
        return _original_adjust_tied(self, tied_weights_keys_by_pointers)

    PreTrainedModel._adjust_tied_keys_with_tied_pointers = _safe_adjust_tied_keys_with_tied_pointers
def _load_base_model():
    """Load the Janus base checkpoint fully materialized on the host.

    All lazy-loading paths are disabled (no low-CPU-mem loading, no device
    map, no fast init) so every parameter is a real tensor, not a meta
    tensor, before the model is moved to the target device.
    """
    load_kwargs = {
        "trust_remote_code": True,
        "torch_dtype": dtype,
        "low_cpu_mem_usage": False,
        "device_map": None,
        "_fast_init": False,
    }
    return AutoModelForCausalLM.from_pretrained(BASE_MODEL, **load_kwargs)
# Load the base model, with a fallback for a known failure mode: on some
# torch/transformers combinations the load raises
# "Tensor.item() cannot be called on meta tensors" (presumably from
# torch.linspace-produced tensors during remote-code model init — TODO confirm
# exact origin).  The retry temporarily replaces torch.linspace with a
# numpy-backed shim whose elements only support .item().
try:
    base_model = _load_base_model()
except RuntimeError as err:
    # Only the known meta-tensor failure is handled; anything else re-raises.
    if "Tensor.item() cannot be called on meta tensors" not in str(err):
        raise
    original_linspace = torch.linspace
    class _ScalarFloat:
        # Minimal stand-in for a 0-dim tensor: supports only .item().
        def __init__(self, value):
            self._value = float(value)
        def item(self):
            return self._value
    def _safe_linspace(*args, **kwargs):
        # Mirror torch.linspace(start, end, steps) using numpy so no meta
        # tensors are involved.  NOTE(review): extra kwargs such as dtype and
        # device are silently ignored; returns a plain list, not a tensor.
        start = kwargs.get("start", args[0] if len(args) > 0 else 0.0)
        end = kwargs.get("end", args[1] if len(args) > 1 else 1.0)
        steps = kwargs.get("steps", args[2] if len(args) > 2 else 100)
        values = np.linspace(float(start), float(end), int(steps), dtype=np.float32)
        return [_ScalarFloat(v) for v in values]
    torch.linspace = _safe_linspace  # patched only for the retry below
    try:
        base_model = _load_base_model()
    finally:
        # Always restore the real linspace, even if the retry also fails.
        torch.linspace = original_linspace
# Move the fully-loaded model to the target device, then attach the LoRA
# adapter to the language-model submodule only; the vision tower is left
# untouched.
base_model = base_model.to(device)
base_model.language_model = PeftModel.from_pretrained(
    base_model.language_model,
    ADAPTER_REPO,
)
base_model.eval()  # inference mode: disables dropout etc.
def transcribe(image: Image.Image, prompt: str):
    """Run the LoRA-tuned Janus model on one image and a text prompt.

    Returns the decoded transcription string, or a short instruction
    message when either input is missing.
    """
    # Guard clauses for the two Gradio inputs.
    if image is None:
        return "Please upload an image."
    if not prompt or not prompt.strip():
        return "Please enter a prompt."

    rgb_image = image.convert("RGB")
    cleaned_prompt = prompt.strip()

    # Janus chat format: a user turn containing the image placeholder plus
    # the prompt, then an empty assistant turn for the model to complete.
    conversation = [
        {
            "role": "User",
            "content": "<image_placeholder>\n" + cleaned_prompt,
            "images": ["uploaded_image"],
        },
        {"role": "Assistant", "content": ""},
    ]

    prepare_inputs = processor(
        conversations=conversation,
        images=[rgb_image],
        force_batchify=True,
    ).to(base_model.device)

    # Cast pixel values to the vision tower's parameter dtype so fp16 GPU
    # runs do not mix dtypes.
    pixel_values = getattr(prepare_inputs, "pixel_values", None)
    if pixel_values is not None:
        try:
            vision_dtype = next(base_model.vision_model.parameters()).dtype
        except Exception:
            # Fallback if the vision submodule has no parameters to inspect.
            vision_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        prepare_inputs.pixel_values = pixel_values.to(base_model.device, dtype=vision_dtype)

    with torch.no_grad():
        inputs_embeds = base_model.prepare_inputs_embeds(**prepare_inputs)
        output_ids = base_model.language_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=prepare_inputs.attention_mask,
            pad_token_id=processor.tokenizer.eos_token_id,
            bos_token_id=processor.tokenizer.bos_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
            max_new_tokens=220,
            do_sample=False,  # greedy decoding for reproducible transcriptions
            use_cache=True,
        )

    token_ids = output_ids[0].cpu().tolist()
    return processor.tokenizer.decode(token_ids, skip_special_tokens=True).strip()
# Gradio UI: image + prompt in, transcription text out.
demo = gr.Interface(
    fn=transcribe,
    inputs=[
        gr.Image(type="pil", label="Handwritten Image"),
        gr.Textbox(label="Prompt"),
    ],
    outputs=gr.Textbox(label="Transcription"),
    title="Janus Handwritten OCR (LoRA)",
    description=f"Base: {BASE_MODEL} + LoRA adapter: {ADAPTER_REPO}",
)
if __name__ == "__main__":
    # ssr_mode=False: presumably disables Gradio server-side rendering to
    # work around a hosting quirk — verify against the deployment target.
    demo.launch(ssr_mode=False)