import sys
import os
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login
from dotenv import load_dotenv
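
# Make the project root importable regardless of the current working directory.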
project_root = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, project_root)
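
# The 'spaces' module is only available on Hugging Face Spaces (ZeroGPU);
# fall back to a no-op decorator so the app also runs locally.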
try:
    import spaces
    print("'spaces' module imported successfully.")
except ImportError:
    print("Warning: 'spaces' module not found. Using dummy decorator for local execution.")

    class DummySpaces:
        def GPU(self, *args, **kwargs):
            def decorator(func):
                print(f"Note: Dummy @GPU decorator used for function '{func.__name__}'.")
                return func
            return decorator

    spaces = DummySpaces()
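
# Load the Hugging Face token from the environment (or a local .env file).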
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")

if not HF_TOKEN:
    raise ValueError("FATAL: Hugging Face token not found. Please set the HF_TOKEN environment variable.")

print("--- Logging in to Hugging Face Hub ---")
login(token=HF_TOKEN)

MODEL_NAME = "Gregniuki/ERNIE-4.5-0.3B-PT-Translator-EN-PL-EN"

print(f"--- Loading model from Hugging Face Hub: {MODEL_NAME} ---")
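
# Select the device and a matching dtype: bfloat16 on CUDA, float32 on CPU.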
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU detected. Using CUDA.")
else:
    device = torch.device("cpu")
    print("No GPU detected. Using CPU.")

dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
print(f"--- Using dtype: {dtype} ---")

print(f"--- Loading tokenizer from Hub: {MODEL_NAME} ---")
try:
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True
    )
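    # Marker placed between the carried-over context and each new chunk; it
    # must survive tokenization and decoding intact so the translated overlap
    # can be located and stripped from the output.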
    SPECIAL_MARKER = "<|LOC_0|>"
    print(f"--- Using special marker for overlap: {SPECIAL_MARKER} ---")
except Exception as e:
    raise RuntimeError(f"FATAL: Could not load tokenizer from the Hub. Error: {e}")

print(f"--- Loading Model with PyTorch from Hub: {MODEL_NAME} ---")
try:
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=dtype,
        trust_remote_code=True
    ).to(device)
    model.eval()
    print("--- Model Loaded Successfully ---")
except Exception as e:
    raise RuntimeError(f"FATAL: Could not load model from the Hub. Error: {e}")
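

# chunk_text fills windows of up to max_size characters, backtracking to the
# last period in the window; e.g. chunk_text("A. B. C.", 4) -> ["A.", "B.", "C."].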
def chunk_text(text: str, max_size: int) -> list[str]:
    """Split text into chunks of at most max_size characters, preferring to break at sentence endings."""
    if not text:
        return []
    chunks, start_index = [], 0
    while start_index < len(text):
        end_index = start_index + max_size
        if end_index >= len(text):
            chunks.append(text[start_index:].strip())
            break
        split_pos = text.rfind('.', start_index, end_index)
        if split_pos != -1:
            chunk, start_index = text[start_index:split_pos + 1], split_pos + 1
        else:
            chunk, start_index = text[start_index:end_index], end_index
        chunks.append(chunk.strip())
    return [c for c in chunks if c]
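

# do_translation wraps a single model call: build the chat prompt, generate,
# and decode only the newly generated tokens.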
def do_translation(text_to_translate: str) -> str:
    """Run a single translation and return the decoded completion."""
    if not text_to_translate.strip():
        return ""
    messages = [{"role": "user", "content": text_to_translate}]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # The chat template already inserts special tokens, so don't add them again.
    model_inputs = tokenizer([prompt], add_special_tokens=False, return_tensors="pt").to(device)

    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=2048,
        do_sample=True, temperature=0.7, top_p=0.95, top_k=50
    )

    # Keep only the newly generated tokens, dropping the echoed prompt.
    input_token_len = model_inputs.input_ids.shape[1]
    output_ids = generated_ids[0][input_token_len:].tolist()
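
    # Decode with skip_special_tokens=False so SPECIAL_MARKER survives in the
    # text; the caller locates it to strip the re-translated overlap.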
    return tokenizer.decode(output_ids, skip_special_tokens=False).strip()
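

# Main inference entry point. On ZeroGPU Spaces, @spaces.GPU allocates a GPU
# for the duration of each call; locally the dummy decorator is a no-op.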
@spaces.GPU
@torch.no_grad()
def translate_with_chunks(input_text: str, chunk_size: int, context_words: int, progress=gr.Progress()) -> str:
    """
    Translate long text chunk by chunk. Each chunk after the first is prefixed
    with the last few source words of the previous chunk plus SPECIAL_MARKER,
    so the re-translated overlap can be removed cleanly and reliably.
    """
    progress(0, desc="Starting...")
    print("--- Inference with special token context method started ---")
    if not input_text or not input_text.strip():
        return "Input text is empty. Please enter some text to translate."

    progress(0.1, desc="Chunking Text...")
    text_chunks = chunk_text(input_text, chunk_size) if len(input_text) > chunk_size else [input_text]
    num_chunks = len(text_chunks)
    print(f"Processing {num_chunks} chunk(s).")

    all_results = []
    # Tail of the previous source chunk, carried forward as context.
    source_context = ""

    for i, chunk in enumerate(text_chunks):
        progress(0.2 + (i / num_chunks) * 0.7, desc=f"Translating chunk {i+1}/{num_chunks}")

        final_translation_for_chunk = ""
        if source_context:
            # Prefix the context and the marker; the model translates both, and
            # everything up to the marker in the output is discarded below.
            prompt_with_marker = f"{source_context} {SPECIAL_MARKER} {chunk}"
            full_translation = do_translation(prompt_with_marker)

            marker_position = full_translation.find(SPECIAL_MARKER)

            if marker_position != -1:
                print("Special marker found in output. Removing overlap.")
                start_of_clean_text = marker_position + len(SPECIAL_MARKER)
                final_translation_for_chunk = full_translation[start_of_clean_text:].lstrip()
            else:
                # Fall back to the full output; some duplicated context may remain.
                print(f"Warning: Marker '{SPECIAL_MARKER}' not found in translation. Overlap may remain.")
                final_translation_for_chunk = full_translation
        else:
            # First chunk: there is no context to strip.
            final_translation_for_chunk = do_translation(chunk)

        all_results.append(final_translation_for_chunk)
        print(f"Chunk {i+1} processed successfully.")

        # Carry the last context_words source words forward for the next chunk.
        if context_words > 0:
            words = chunk.split()
            source_context = " ".join(words[-context_words:])

    progress(0.95, desc="Reassembling Results...")
    # Join the chunk translations and drop any stray markers that slipped through.
    full_output = " ".join(all_results).replace(SPECIAL_MARKER, "")

    progress(1.0, desc="Done!")
    return full_output
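

# Gradio UI: the three inputs map positionally to translate_with_chunks'
# (input_text, chunk_size, context_words) parameters.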
print("\n--- Initializing Gradio Interface ---")

app = gr.Interface(
    fn=translate_with_chunks,
    inputs=[
        gr.Textbox(lines=15, label="Input Text", placeholder="Enter long text to process here..."),
        gr.Slider(minimum=128, maximum=1536, value=1024, step=64, label="Character Chunk Size"),
        gr.Slider(
            minimum=0,
            maximum=50,
            value=15,
            step=5,
            label="Context Overlap (Source Words)",
            info="Number of source words from the end of the previous chunk to provide as context for the next one. Improves consistency across chunk boundaries."
        )
    ],
    outputs=gr.Textbox(lines=15, label="Model Output", interactive=False),
    title="ERNIE 4.5 Context-Aware Translator",
    description="Processes long text using a special token marker to keep translations consistent across chunks without duplicated overlap text.",
    allow_flagging="never"
)

if __name__ == "__main__":
    app.queue().launch()