import torch
import gradio as gr
from PIL import Image
import numpy as np
import janus  # noqa: F401 (imported for its model-registration side effects)
from janus.models import VLChatProcessor
from transformers import AutoModelForCausalLM
from peft import PeftModel
from transformers.modeling_utils import PreTrainedModel

BASE_MODEL = "deepseek-ai/Janus-Pro-1B"
ADAPTER_REPO = "Branis333/Janus_grade_final"

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

processor = VLChatProcessor.from_pretrained(BASE_MODEL)

# Compatibility shim: some transformers versions expect `all_tied_weights_keys`
# to be a dict, but the Janus remote code may leave it unset or as a list.
# Normalize it before delegating to the original implementation.
_original_adjust_tied = PreTrainedModel._adjust_tied_keys_with_tied_pointers


def _safe_adjust_tied_keys_with_tied_pointers(self, tied_weights_keys_by_pointers):
    existing = getattr(self, "all_tied_weights_keys", None)
    if existing is None:
        tied = getattr(self, "_tied_weights_keys", []) or []
        self.all_tied_weights_keys = {k: True for k in tied}
    elif not hasattr(existing, "keys"):
        self.all_tied_weights_keys = {k: True for k in existing}
    return _original_adjust_tied(self, tied_weights_keys_by_pointers)


PreTrainedModel._adjust_tied_keys_with_tied_pointers = _safe_adjust_tied_keys_with_tied_pointers


def _load_base_model():
    return AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        trust_remote_code=True,
        torch_dtype=dtype,
        low_cpu_mem_usage=False,
        device_map=None,
        _fast_init=False,
    )


try:
    base_model = _load_base_model()
except RuntimeError as err:
    # Some transformers versions initialize weights on the "meta" device and
    # the Janus init code then calls torch.linspace(...).item(), which fails
    # on meta tensors. Retry with a numpy-backed linspace whose elements
    # support .item(), then restore the original function.
    if "Tensor.item() cannot be called on meta tensors" not in str(err):
        raise

    original_linspace = torch.linspace

    class _ScalarFloat:
        """Minimal float wrapper exposing .item(), mimicking a 0-dim tensor."""

        def __init__(self, value):
            self._value = float(value)

        def item(self):
            return self._value

    def _safe_linspace(*args, **kwargs):
        start = kwargs.get("start", args[0] if len(args) > 0 else 0.0)
        end = kwargs.get("end", args[1] if len(args) > 1 else 1.0)
        steps = kwargs.get("steps", args[2] if len(args) > 2 else 100)
        values = np.linspace(float(start), float(end), int(steps), dtype=np.float32)
        return [_ScalarFloat(v) for v in values]

    torch.linspace = _safe_linspace
    try:
        base_model = _load_base_model()
    finally:
        torch.linspace = original_linspace

base_model = base_model.to(device)

# Attach the LoRA adapter to the language-model component only; the vision
# tower and aligner keep their base weights.
base_model.language_model = PeftModel.from_pretrained(
    base_model.language_model,
    ADAPTER_REPO,
)
base_model.eval()
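# Optional sanity check (a minimal sketch, safe to delete): confirm the LoRA
# adapter actually wrapped the language model before the app starts serving.
assert isinstance(base_model.language_model, PeftModel), "LoRA adapter did not attach"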
def transcribe(image: Image.Image, prompt: str):
    if image is None:
        return "Please upload an image."
    if not prompt or not prompt.strip():
        return "Please enter a prompt."

    image = image.convert("RGB")

    # Janus chat format: the <image_placeholder> tag marks where the processor
    # splices in the image embeddings. The "images" entry here is only a
    # label; the actual PIL image is passed to the processor below.
    conversation = [
        {
            "role": "User",
            "content": "<image_placeholder>\n" + prompt.strip(),
            "images": ["uploaded_image"],
        },
        {"role": "Assistant", "content": ""},
    ]

    prepare_inputs = processor(
        conversations=conversation,
        images=[image],
        force_batchify=True,
    ).to(base_model.device)

    # Match pixel_values to the vision tower's dtype (fp16 on GPU, fp32 on CPU).
    if hasattr(prepare_inputs, "pixel_values") and prepare_inputs.pixel_values is not None:
        try:
            vision_dtype = next(base_model.vision_model.parameters()).dtype
        except Exception:
            vision_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        prepare_inputs.pixel_values = prepare_inputs.pixel_values.to(
            base_model.device, dtype=vision_dtype
        )

    with torch.no_grad():
        # Fuse image and text embeddings, then decode greedily.
        inputs_embeds = base_model.prepare_inputs_embeds(**prepare_inputs)
        output_ids = base_model.language_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=prepare_inputs.attention_mask,
            pad_token_id=processor.tokenizer.eos_token_id,
            bos_token_id=processor.tokenizer.bos_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
            max_new_tokens=220,
            do_sample=False,
            use_cache=True,
        )

    return processor.tokenizer.decode(
        output_ids[0].cpu().tolist(), skip_special_tokens=True
    ).strip()


demo = gr.Interface(
    fn=transcribe,
    inputs=[
        gr.Image(type="pil", label="Handwritten Image"),
        gr.Textbox(label="Prompt"),
    ],
    outputs=gr.Textbox(label="Transcription"),
    title="Janus Handwritten OCR (LoRA)",
    description=f"Base: {BASE_MODEL} + LoRA adapter: {ADAPTER_REPO}",
)

if __name__ == "__main__":
    demo.launch(ssr_mode=False)
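# A minimal usage sketch, kept as a comment so it never runs on import; the
# image path is a hypothetical example, not a file shipped with this app:
#
#   from PIL import Image
#   img = Image.open("samples/handwriting.png")
#   print(transcribe(img, "Transcribe the handwritten text in this image."))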