# Branis333's picture
# Update app.py
# 2136843 verified
import torch
import gradio as gr
from PIL import Image
import janus # noqa: F401
from janus.models import VLChatProcessor
from transformers import AutoModelForCausalLM
from peft import PeftModel
# Hub identifiers: the base checkpoint and the adapter that is layered on top
# of its language model below.
BASE_MODEL = "deepseek-ai/Janus-Pro-1B"
ADAPTER_REPO = "Branis333/Janus_Grade"

# Half precision is only selected when a CUDA device is available; on CPU the
# weights are kept in float32.
if torch.cuda.is_available():
    device = "cuda"
    dtype = torch.float16
else:
    device = "cpu"
    dtype = torch.float32

# The processor and the model are both loaded from the base checkpoint.
processor = VLChatProcessor.from_pretrained(BASE_MODEL)

base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    trust_remote_code=True,
    torch_dtype=dtype,
    low_cpu_mem_usage=False,
)
base_model = base_model.to(device)

# Only the language-model submodule is wrapped with the adapter; the rest of
# the base model is left as loaded.
base_model.language_model = PeftModel.from_pretrained(
    base_model.language_model,
    ADAPTER_REPO,
)

# Inference only: switch the whole model to eval mode once at import time.
base_model.eval()
def _vision_dtype() -> torch.dtype:
    """Return the dtype of the vision model's weights.

    Falls back to float16 when CUDA is available and float32 otherwise if the
    model exposes no ``vision_model`` attribute or that module has no
    parameters.
    """
    try:
        return next(base_model.vision_model.parameters()).dtype
    except (AttributeError, StopIteration):
        # Only the two expected lookup failures are handled here; anything
        # else is a real error and should surface instead of being masked.
        return torch.float16 if torch.cuda.is_available() else torch.float32


def transcribe(image: "Image.Image | None", prompt: str) -> str:
    """Transcribe the text in ``image`` using the adapter-augmented model.

    Args:
        image: The uploaded picture, or ``None`` when nothing was uploaded.
        prompt: Instruction sent to the model alongside the image. A missing
            or whitespace-only prompt is replaced by a default transcription
            instruction.

    Returns:
        The decoded model output with special tokens removed and surrounding
        whitespace stripped, or a short message asking for an upload when
        ``image`` is ``None``.
    """
    if image is None:
        return "Please upload an image."
    if not prompt or not prompt.strip():
        prompt = "Read the handwritten text in this image and transcribe it exactly."

    image = image.convert("RGB")

    # Single-turn chat: the user message carries the image placeholder plus
    # the instruction, and the empty assistant turn is what gets generated.
    conversation = [
        {
            "role": "User",
            "content": "<image_placeholder>\n" + prompt,
            # Placeholder entry only; the actual PIL image is handed to the
            # processor through its ``images=`` argument below.
            "images": ["uploaded_image"],
        },
        {"role": "Assistant", "content": ""},
    ]

    prepare_inputs = processor(
        conversations=conversation,
        images=[image],
        force_batchify=True,
    ).to(base_model.device)

    # Make sure the pixel tensor matches the vision model's weight dtype.
    # NOTE(review): the processor output's ``.to()`` may apply its own dtype to
    # pixel_values - confirm against the Janus sources.
    pixel_values = getattr(prepare_inputs, "pixel_values", None)
    if pixel_values is not None:
        prepare_inputs.pixel_values = pixel_values.to(
            base_model.device,
            dtype=_vision_dtype(),
        )

    tokenizer = processor.tokenizer
    with torch.no_grad():
        inputs_embeds = base_model.prepare_inputs_embeds(**prepare_inputs)
        outputs = base_model.language_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=prepare_inputs.attention_mask,
            # EOS doubles as the padding token for generation.
            pad_token_id=tokenizer.eos_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            max_new_tokens=220,
            do_sample=False,
            use_cache=True,
        )

    return tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True).strip()
# Widgets are created in the order the callback receives its arguments:
# first the image, then the prompt; the output box comes last.
_image_input = gr.Image(type="pil", label="Handwritten Image")
# The prompt box is deliberately created without a default prompt.
_prompt_input = gr.Textbox(label="Prompt")
_text_output = gr.Textbox(label="Transcription")

demo = gr.Interface(
    fn=transcribe,
    inputs=[_image_input, _prompt_input],
    outputs=_text_output,
    title="Janus Handwritten OCR (LoRA)",
    description="Base: deepseek-ai/Janus-Pro-1B + LoRA adapter: Branis333/Janus_Grade",
)

if __name__ == "__main__":
    # Start the web UI with server-side rendering explicitly turned off.
    demo.launch(ssr_mode=False)