|
|
|
|
|
""" |
|
|
Gradio app for Sanskrit text transcription using Qwen2.5-VL model |
|
|
Based on quick_test_improved.py |
|
|
""" |
|
|
|
|
|
import gradio as gr |
|
|
import torch |
|
|
import base64 |
|
|
import io |
|
|
from PIL import Image |
|
|
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor |
|
|
from qwen_vl_utils import process_vision_info |
|
|
import os |
|
|
import logging |
|
|
import spaces |
|
|
|
|
|
|
|
|
# Configure logging once at import time; the rest of the module logs via `logger`.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Identifier of the checkpoint handed to `from_pretrained` for both the
# processor and the model.
model_path = 'diabolic6045/Sanskrit-Qwen2.5-VL-7B-Instruct-OCR'

logger.info("Loading processor...")
processor = AutoProcessor.from_pretrained(model_path)

logger.info("Loading Sanskrit OCR model...")

# Pick placement and precision together: bfloat16 with automatic placement
# when CUDA is visible, otherwise float32 on the CPU.
if torch.cuda.is_available():
    device_map, _torch_dtype = "auto", torch.bfloat16
else:
    device_map, _torch_dtype = "cpu", torch.float32

model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_path,
    torch_dtype=_torch_dtype,
    device_map=device_map,
)

# Inference only: switch the model to evaluation mode.
model.eval()

# Device of the first parameter, recorded for the log line below.
device = next(model.parameters()).device
logger.info(f"Model loaded on device: {device}")
|
|
|
|
|
def check_model_status():
    """Report whether the module-level ``model`` and ``processor`` are usable.

    Returns:
        A short, human-readable status string for display in the UI:
        a "ready" message when both globals are non-``None``, a
        "not loaded yet" message when either is ``None``, or an error
        message if inspecting them raised.
    """
    try:
        if model is not None and processor is not None:
            return "✅ Model loaded and ready"
        return "⏳ Model not loaded yet"
    except Exception as e:
        # Reached e.g. when the globals were never defined (NameError);
        # surface the problem in the status box instead of crashing the UI.
        # NOTE(review): the glyph here was lost to mis-decoding; ❌ is the
        # best-guess restoration - confirm against the original file.
        return f"❌ Model error: {e}"
|
|
|
|
|
@spaces.GPU
def transcribe_sanskrit(image, custom_prompt, progress=gr.Progress()):
    """Transcribe the Sanskrit text in ``image`` using the pre-loaded model.

    Args:
        image: The image to transcribe, as delivered by the Gradio image
            component, or ``None`` when nothing has been uploaded.
        custom_prompt: Instruction text sent to the model with the image.
            ``None``, empty, or whitespace-only values fall back to the
            default transcription prompt.
        progress: Gradio progress tracker used to report coarse progress.

    Returns:
        The decoded transcription (``""`` if nothing was decoded), or a
        human-readable message when no image was supplied or an error
        occurred. Errors are reported in the return value rather than raised.
    """
    default_prompt = "Please transcribe the Sanskrit text shown in this image:"

    if image is None:
        return "Please upload an image first."

    try:
        progress(0.1, desc="Processing image...")

        # Guard against a missing prompt (e.g. None from a programmatic
        # caller) instead of failing on `.strip()`.
        if custom_prompt and custom_prompt.strip():
            prompt = custom_prompt
        else:
            prompt = default_prompt

        # Single-turn chat message carrying the image and the instruction.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": prompt},
                ],
            }
        ]

        # Render the chat template to text, extract the vision inputs, and
        # let the processor build the model-ready tensors.
        text = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )

        # Move every input tensor onto the device holding the model weights.
        model_device = next(model.parameters()).device
        inputs = {k: v.to(model_device) for k, v in inputs.items()}

        progress(0.5, desc="Generating transcription...")
        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                max_new_tokens=512,
                do_sample=False,
                pad_token_id=processor.tokenizer.eos_token_id,
                use_cache=True,
                repetition_penalty=1.1,
            )

        # `generate` returns prompt + continuation; drop the prompt tokens so
        # only the newly generated text is decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs['input_ids'], generated_ids)
        ]
        output_text = processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )

        progress(1.0, desc="Complete!")
        return output_text[0] if output_text else ""

    except Exception as e:
        # Top-level UI boundary: log with traceback and show the failure in
        # the output box instead of propagating.
        logger.exception("Error in transcribe_sanskrit: %s", e)
        # NOTE(review): the glyph here was lost to mis-decoding; ❌ is the
        # best-guess restoration - confirm against the original file.
        return f"❌ Error occurred: {e}\n\nPlease try again or check if the model files are properly loaded."
|
|
|
|
|
def create_gradio_interface():
    """Build the Gradio Blocks UI for Sanskrit transcription.

    The layout has two columns: image upload, prompt box, transcribe button
    and instructions on the left; transcription output and model status on
    the right. Event handlers are wired before returning.

    Returns:
        The constructed ``gr.Blocks`` app. It is not launched here.
    """
    default_prompt = "Please transcribe the Sanskrit text shown in this image:"

    with gr.Blocks(
        title="Sanskrit Text Transcription",
        theme=gr.themes.Soft()
    ) as app:

        # NOTE(review): the emoji in this block's strings had been corrupted
        # by a mis-decoding; the glyphs below are best-guess restorations -
        # confirm against the original file.
        gr.HTML("""
        <div class="main-header">
            <h1>🕉️ Sanskrit Text Transcription</h1>
            <p>Upload an image containing Sanskrit text and get an accurate transcription using the specialized Sanskrit OCR model</p>
            <p><strong>🚀 Powered by ZeroGPU:</strong> Dynamic GPU allocation for efficient processing</p>
        </div>
        """)

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Upload Image")
                image_input = gr.Image(
                    type="pil",
                    label="Sanskrit Text Image",
                    height=400
                )

                gr.Markdown("### Custom Prompt (Optional)")
                custom_prompt = gr.Textbox(
                    label="Custom transcription prompt",
                    placeholder=default_prompt,
                    lines=2,
                    value=default_prompt
                )

                transcribe_btn = gr.Button(
                    "🕉️ Transcribe Sanskrit Text",
                    variant="primary",
                    size="lg"
                )

                gr.Markdown("""
                ### Instructions:
                1. Upload an image containing Sanskrit text
                2. Optionally modify the prompt for better results
                3. Click the transcribe button
                4. View the transcribed text below
                """)

            with gr.Column(scale=1):
                gr.Markdown("### Transcription Result")
                output_text = gr.Textbox(
                    label="Transcribed Sanskrit Text",
                    lines=10,
                    max_lines=20,
                    show_copy_button=True
                )

                gr.Markdown("### Model Information")
                model_status = gr.Textbox(
                    label="Model Status",
                    value="Checking...",
                    interactive=False
                )

                check_status_btn = gr.Button("🔄 Check Model Status", size="sm")

                gr.Markdown("""
                **Model:** diabolic6045/Sanskrit-Qwen2.5-VL-7B-Instruct-OCR

                **Features:**
                - Multimodal vision-language model
                - Pre-trained specifically for Sanskrit OCR
                - Supports various Sanskrit scripts
                - High accuracy Sanskrit text transcription
                """)

        # Explicit transcription on button click.
        transcribe_btn.click(
            fn=transcribe_sanskrit,
            inputs=[image_input, custom_prompt],
            outputs=output_text,
            show_progress=True
        )

        # Also transcribe automatically whenever the image changes.
        image_input.change(
            fn=transcribe_sanskrit,
            inputs=[image_input, custom_prompt],
            outputs=output_text,
            show_progress=True
        )

        # Refresh the status box on demand.
        check_status_btn.click(
            fn=check_model_status,
            outputs=model_status
        )

        # Replace the "Checking..." placeholder once the page loads.
        app.load(
            fn=check_model_status,
            outputs=model_status
        )

    return app
|
|
|
|
|
def main():
    """Build the Gradio interface and serve it.

    The server listens on all interfaces (0.0.0.0) at port 7860, without a
    public share link, using at most 4 worker threads.
    """
    logger.info("Starting Sanskrit Transcription Gradio App...")

    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "max_threads": 4,
    }
    create_gradio_interface().launch(**launch_options)


if __name__ == "__main__":
    main()
|
|
|