Hugging Face Spaces — status: Sleeping (page-scrape artifact; not part of the app source).
| import torch | |
| import gradio as gr | |
| from PIL import Image | |
| import janus # noqa: F401 | |
| from janus.models import VLChatProcessor | |
| from transformers import AutoModelForCausalLM | |
| from peft import PeftModel | |
# Model identifiers: the public Janus-Pro-1B base and the LoRA adapter
# trained for handwriting transcription.
BASE_MODEL = "deepseek-ai/Janus-Pro-1B"
ADAPTER_REPO = "Branis333/Janus_Grade"

# Prefer GPU when available; fp16 on CUDA, fp32 on CPU (half precision is
# poorly supported on CPU).
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# Janus-specific processor: chat templating + image preprocessing + tokenizer.
processor = VLChatProcessor.from_pretrained(BASE_MODEL)

# Load the multimodal base model. trust_remote_code is needed because the
# Janus architecture ships as custom modeling code on the Hub.
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    trust_remote_code=True,
    torch_dtype=dtype,
    low_cpu_mem_usage=False,
).to(device)

# Attach the LoRA adapter to the language-model sub-module only; the vision
# tower keeps its base weights. Order matters: base must be loaded first.
base_model.language_model = PeftModel.from_pretrained(
    base_model.language_model,
    ADAPTER_REPO,
)
base_model.eval()  # inference mode: disables dropout / training-only behavior
def transcribe(image: Image.Image, prompt: str):
    """Transcribe handwritten text from an uploaded image.

    Builds a single-turn Janus chat conversation around the image, runs the
    LoRA-adapted language model greedily, and returns the decoded text.

    Args:
        image: PIL image from the Gradio widget, or None when nothing was uploaded.
        prompt: user instruction; a default OCR instruction is used when blank.

    Returns:
        The decoded transcription string (or a short message when no image given).
    """
    # Gradio passes None when the image slot is empty.
    if image is None:
        return "Please upload an image."

    # Substitute the default instruction for an empty/whitespace-only prompt.
    if not (prompt and prompt.strip()):
        prompt = "Read the handwritten text in this image and transcribe it exactly."

    rgb = image.convert("RGB")

    chat = [
        {
            "role": "User",
            "content": "<image_placeholder>\n" + prompt,
            "images": ["uploaded_image"],
        },
        {"role": "Assistant", "content": ""},
    ]

    batch = processor(
        conversations=chat,
        images=[rgb],
        force_batchify=True,
    ).to(base_model.device)

    # Align the pixel tensor with the vision tower's device/dtype; fall back to
    # a dtype guess if the parameter probe fails for any reason.
    pixels = getattr(batch, "pixel_values", None)
    if pixels is not None:
        try:
            vision_dtype = next(base_model.vision_model.parameters()).dtype
        except Exception:
            vision_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        batch.pixel_values = pixels.to(base_model.device, dtype=vision_dtype)

    with torch.no_grad():
        # Fuse image + text into one embedding sequence, then decode greedily.
        embeds = base_model.prepare_inputs_embeds(**batch)
        generated = base_model.language_model.generate(
            inputs_embeds=embeds,
            attention_mask=batch.attention_mask,
            pad_token_id=processor.tokenizer.eos_token_id,
            bos_token_id=processor.tokenizer.bos_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
            max_new_tokens=220,
            do_sample=False,
            use_cache=True,
        )

    token_ids = generated[0].cpu().tolist()
    return processor.tokenizer.decode(token_ids, skip_special_tokens=True).strip()
# --- Gradio UI wiring ------------------------------------------------------
image_input = gr.Image(type="pil", label="Handwritten Image")
prompt_input = gr.Textbox(label="Prompt")  # intentionally ships with no default prompt
result_box = gr.Textbox(label="Transcription")

demo = gr.Interface(
    fn=transcribe,
    inputs=[image_input, prompt_input],
    outputs=result_box,
    title="Janus Handwritten OCR (LoRA)",
    description="Base: deepseek-ai/Janus-Pro-1B + LoRA adapter: Branis333/Janus_Grade",
)

if __name__ == "__main__":
    # NOTE(review): ssr_mode=False presumably works around server-side-rendering
    # issues on Spaces — confirm against the deployment environment.
    demo.launch(ssr_mode=False)