Spaces:
Running
Running
| import os | |
| import time | |
| from typing import Dict, Any | |
| import numpy as np | |
| from PIL import Image | |
| from ultralytics import YOLO | |
| import torch | |
| from utils.predict_bounding_boxes import predict_bounding_boxes | |
| from utils.manga_ocr_utils import get_text_from_image | |
| from utils.translate_manga import translate_manga | |
| from utils.process_contour import process_contour | |
| from utils.write_text_on_image import add_text | |
# Path handed to YOLO(...) below. It is relative, so it resolves against the
# current working directory of the process.
MODEL_PATH = "./model_creation/runs/detect/train5/weights/best.pt"

# Choose the compute device once, at import time: CUDA when torch reports it
# as available, CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")
if device == "cuda":
    # Report the name of the first CUDA device (index 0).
    print(f"GPU: {torch.cuda.get_device_name(0)}")

# Load the detection model from MODEL_PATH and move it to the chosen device.
object_detection_model = YOLO(MODEL_PATH)
object_detection_model.to(device)
def extract_text_from_regions(
    image: np.ndarray, source_lang: str, target_lang: str, results: list
) -> None:
    """OCR, translate and re-letter every detected region of ``image``.

    For each detection the region is cropped out of ``image``, read with
    ``get_text_from_image``, passed through ``process_contour``, translated
    with ``translate_manga`` and the translation is handed to ``add_text``.

    Args:
        image: Page image, indexed as ``image[row, column]``.
        source_lang: Forwarded to ``translate_manga`` as ``source_lang``.
        target_lang: Forwarded to ``translate_manga`` as ``target_lang``.
        results: Detections. Each item must unpack to exactly six values;
            the first four are used as ``x1, y1, x2, y2`` and the last two
            are ignored.

    Returns:
        None.

    NOTE(review): nothing is returned, so the caller can only see the
    result through ``image`` itself. This presumably relies on
    ``process_contour`` / ``add_text`` drawing in place on the crop, which is
    a view of ``image`` -- confirm against those helpers.
    """
    height, width = image.shape[:2]

    for result in results:
        x1, y1, x2, y2, _, _ = result

        # Clamp the box to the image. Without this a negative coordinate is
        # interpreted by NumPy as an index from the end and silently selects
        # the wrong pixels.
        left = max(0, int(x1))
        top = max(0, int(y1))
        right = min(width, int(x2))
        bottom = min(height, int(y2))

        # A degenerate or fully out-of-frame box yields an empty crop, which
        # cannot be turned into a PIL image; skip it instead of letting one
        # bad detection abort the whole page.
        if right <= left or bottom <= top:
            continue

        detected_image = image[top:bottom, left:right]

        # Ensure detected_image is uint8 for PIL.
        # The * 255 scaling assumes float data in [0, 1] -- TODO confirm.
        # Note that this conversion produces a copy, so anything drawn on it
        # afterwards no longer lands in ``image``.
        if detected_image.dtype != np.uint8:
            detected_image = (detected_image * 255).astype(np.uint8)

        im = Image.fromarray(detected_image)
        text = get_text_from_image(im)

        processed_image, cont = process_contour(detected_image)

        translated_text = translate_manga(
            text, target_lang=target_lang, source_lang=source_lang
        )
        if translated_text is None:
            # Draw a visible placeholder rather than passing None to add_text.
            translated_text = "Translation failed"

        add_text(processed_image, translated_text, cont)
def predict(image: np.ndarray, source_lang: str = "ja-JP", target_lang: str = "en-GB"):
    """Detect text regions in ``image`` and run the translation pipeline.

    The image is converted to RGB and written to a temporary PNG because
    ``predict_bounding_boxes`` is called with a file path. The detections
    are then passed to ``extract_text_from_regions`` together with an array
    copy of the RGB image, and that array is returned.

    Args:
        image: Input image as an array accepted by ``PIL.Image.fromarray``.
        source_lang: Source language code, forwarded to
            ``extract_text_from_regions``.
        target_lang: Target language code, forwarded to
            ``extract_text_from_regions``.

    Returns:
        The RGB image array after processing, or ``None`` if detection or
        text extraction raised (the error is printed).

    Raises:
        Whatever ``Image.fromarray`` / ``Image.save`` raise for an invalid
        image or an unwritable temporary file; those are not swallowed.
    """
    # Function-scope import so this block does not depend on a change to the
    # module's import section.
    import tempfile

    image_pil = Image.fromarray(image).convert("RGB")

    # mkstemp atomically creates a uniquely named file in the system temp
    # directory. A timestamp-based name in the working directory is not
    # guaranteed unique and requires the working directory to be writable.
    fd, temp_filename = tempfile.mkstemp(prefix="image_", suffix=".png")
    os.close(fd)  # PIL reopens the file by name; release the handle first.

    try:
        # Saving happens inside the try/finally so a failed or partial write
        # is still cleaned up.
        image_pil.save(temp_filename)
        try:
            np_image = np.array(image_pil)
            results = predict_bounding_boxes(object_detection_model, temp_filename)
            extract_text_from_regions(np_image, source_lang, target_lang, results)
            return np_image
        except Exception as e:
            # Best-effort boundary: report the failure and signal it with None.
            print(f"Error: {str(e)}")
            return None
    finally:
        # Clean up the temporary file. Attempt the removal directly instead
        # of checking for existence first, which would race with other
        # processes.
        try:
            os.remove(temp_filename)
        except FileNotFoundError:
            pass
        except OSError as e:
            print(f"Warning: Could not remove temporary file {temp_filename}: {e}")