import torch
import gradio as gr
from PIL import Image
import numpy as np

import janus  # noqa: F401
from janus.models import VLChatProcessor
from transformers import AutoModelForCausalLM
from peft import PeftModel
from transformers.modeling_utils import PreTrainedModel
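
# Base Janus-Pro checkpoint and the LoRA adapter carrying the fine-tuned
# handwriting-transcription weights.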
BASE_MODEL = "deepseek-ai/Janus-Pro-1B"
ADAPTER_REPO = "Branis333/Janus_grade_final"

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

processor = VLChatProcessor.from_pretrained(BASE_MODEL)
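
# Compatibility shim: wrap PreTrainedModel._adjust_tied_keys_with_tied_pointers
# so that `all_tied_weights_keys` is always a dict before the original method
# runs (Janus' remote code can leave it unset or as a plain list).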
_original_adjust_tied = PreTrainedModel._adjust_tied_keys_with_tied_pointers

def _safe_adjust_tied_keys_with_tied_pointers(self, tied_weights_keys_by_pointers):
    existing = getattr(self, "all_tied_weights_keys", None)
    if existing is None:
        tied = getattr(self, "_tied_weights_keys", []) or []
        self.all_tied_weights_keys = {k: True for k in tied}
    elif not hasattr(existing, "keys"):
        self.all_tied_weights_keys = {k: True for k in existing}
    return _original_adjust_tied(self, tied_weights_keys_by_pointers)

PreTrainedModel._adjust_tied_keys_with_tied_pointers = _safe_adjust_tied_keys_with_tied_pointers
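
# Load eagerly on CPU: low_cpu_mem_usage=False and device_map=None keep the
# weights off the meta device, and the private _fast_init=False flag disables
# transformers' fast-init path.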
def _load_base_model():
    return AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        trust_remote_code=True,
        torch_dtype=dtype,
        low_cpu_mem_usage=False,
        device_map=None,
        _fast_init=False,
    )
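
# First load attempt. If Janus' remote code still trips over meta tensors,
# retry with a patched torch.linspace (see the except branch below).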
try:
    base_model = _load_base_model()
except RuntimeError as err:
    if "Tensor.item() cannot be called on meta tensors" not in str(err):
        raise
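
    # Retry path: the init code calls .item() on torch.linspace output, which
    # raises on meta tensors. Temporarily swap in a numpy-backed linspace whose
    # elements support only .item(), then restore the original.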
    original_linspace = torch.linspace

    class _ScalarFloat:
        def __init__(self, value):
            self._value = float(value)

        def item(self):
            return self._value

    def _safe_linspace(*args, **kwargs):
        start = kwargs.get("start", args[0] if len(args) > 0 else 0.0)
        end = kwargs.get("end", args[1] if len(args) > 1 else 1.0)
        steps = kwargs.get("steps", args[2] if len(args) > 2 else 100)
        values = np.linspace(float(start), float(end), int(steps), dtype=np.float32)
        return [_ScalarFloat(v) for v in values]

    torch.linspace = _safe_linspace
    try:
        base_model = _load_base_model()
    finally:
        torch.linspace = original_linspace
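
# Move the model to the target device and attach the LoRA adapter to the
# language-model component only; the vision tower is left untouched.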
base_model = base_model.to(device)
base_model.language_model = PeftModel.from_pretrained(
    base_model.language_model,
    ADAPTER_REPO,
)
base_model.eval()
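
# Gradio handler: wrap the uploaded image and prompt in a Janus chat turn and
# greedily decode a transcription.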
def transcribe(image: Image.Image, prompt: str):
    if image is None:
        return "Please upload an image."
    if not prompt or not prompt.strip():
        return "Please enter a prompt."
    image = image.convert("RGB")
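    # Janus chat format: <image_placeholder> marks where image embeddings are
    # spliced into the prompt; the real PIL image is passed to the processor below.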
    conversation = [
        {
            "role": "User",
            "content": "<image_placeholder>\n" + prompt.strip(),
            "images": ["uploaded_image"],
        },
        {"role": "Assistant", "content": ""},
    ]
    prepare_inputs = processor(
        conversations=conversation,
        images=[image],
        force_batchify=True,
    ).to(base_model.device)
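    # Cast pixel_values to the vision tower's dtype to avoid fp16/fp32 mismatches.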
    if hasattr(prepare_inputs, "pixel_values") and prepare_inputs.pixel_values is not None:
        try:
            vision_dtype = next(base_model.vision_model.parameters()).dtype
        except Exception:
            vision_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        prepare_inputs.pixel_values = prepare_inputs.pixel_values.to(base_model.device, dtype=vision_dtype)
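    # Fuse image and text into one embedding sequence, then decode greedily
    # (do_sample=False) for deterministic transcriptions.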
    with torch.no_grad():
        inputs_embeds = base_model.prepare_inputs_embeds(**prepare_inputs)
        output_ids = base_model.language_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=prepare_inputs.attention_mask,
            pad_token_id=processor.tokenizer.eos_token_id,
            bos_token_id=processor.tokenizer.bos_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
            max_new_tokens=220,
            do_sample=False,
            use_cache=True,
        )
    return processor.tokenizer.decode(output_ids[0].cpu().tolist(), skip_special_tokens=True).strip()
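
# Minimal Gradio UI: image + prompt in, transcription out.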
demo = gr.Interface(
    fn=transcribe,
    inputs=[
        gr.Image(type="pil", label="Handwritten Image"),
        gr.Textbox(label="Prompt"),
    ],
    outputs=gr.Textbox(label="Transcription"),
    title="Janus Handwritten OCR (LoRA)",
    description=f"Base: {BASE_MODEL} + LoRA adapter: {ADAPTER_REPO}",
)
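
# ssr_mode=False opts out of Gradio's server-side rendering, a common
# workaround for rendering issues on Spaces.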
if __name__ == "__main__":
    demo.launch(ssr_mode=False)