Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -13,22 +13,19 @@ import pandas as pd
|
|
| 13 |
import difflib
|
| 14 |
from concurrent.futures import ThreadPoolExecutor
|
| 15 |
|
| 16 |
-
# Define the device
|
| 17 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 18 |
-
|
| 19 |
# OCR Correction Model
|
| 20 |
ocr_model_name = "PleIAs/OCRonos-Vintage"
|
| 21 |
|
| 22 |
import torch
|
| 23 |
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
| 24 |
|
|
|
|
|
|
|
| 25 |
# Load pre-trained model and tokenizer
|
| 26 |
model_name = "PleIAs/OCRonos-Vintage"
|
| 27 |
model = GPT2LMHeadModel.from_pretrained(model_name)
|
| 28 |
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
|
| 29 |
|
| 30 |
-
# Set the device to GPU if available, otherwise use CPU
|
| 31 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 32 |
model.to(device)
|
| 33 |
|
| 34 |
# CSS for formatting
|
|
@@ -169,7 +166,9 @@ def split_text(text, max_tokens=500):
|
|
| 169 |
|
| 170 |
|
| 171 |
# Function to generate text
|
| 172 |
-
|
|
|
|
|
|
|
| 173 |
prompt = f"""### Text ###\n{prompt}\n\n\n### Correction ###\n"""
|
| 174 |
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
|
| 175 |
|
|
@@ -177,9 +176,7 @@ def ocr_correction(prompt, max_new_tokens=600, num_threads=os.cpu_count()):
|
|
| 177 |
torch.set_num_threads(num_threads)
|
| 178 |
|
| 179 |
# Generate text
|
| 180 |
-
|
| 181 |
-
future = executor.submit(
|
| 182 |
-
model.generate,
|
| 183 |
input_ids,
|
| 184 |
max_new_tokens=max_new_tokens,
|
| 185 |
pad_token_id=tokenizer.eos_token_id,
|
|
@@ -188,8 +185,6 @@ def ocr_correction(prompt, max_new_tokens=600, num_threads=os.cpu_count()):
|
|
| 188 |
do_sample=True,
|
| 189 |
temperature=0.7
|
| 190 |
)
|
| 191 |
-
output = future.result()
|
| 192 |
-
|
| 193 |
# Decode and return the generated text
|
| 194 |
result = tokenizer.decode(output[0], skip_special_tokens=True)
|
| 195 |
print(result)
|
|
|
|
| 13 |
import difflib
|
| 14 |
from concurrent.futures import ThreadPoolExecutor
|
| 15 |
|
|
|
|
|
|
|
|
|
|
| 16 |
# OCR Correction Model
|
| 17 |
ocr_model_name = "PleIAs/OCRonos-Vintage"
|
| 18 |
|
| 19 |
import torch
|
| 20 |
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
| 21 |
|
| 22 |
+
device = "cuda"
|
| 23 |
+
|
| 24 |
# Load pre-trained model and tokenizer
|
| 25 |
model_name = "PleIAs/OCRonos-Vintage"
|
| 26 |
model = GPT2LMHeadModel.from_pretrained(model_name)
|
| 27 |
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
|
| 28 |
|
|
|
|
|
|
|
| 29 |
model.to(device)
|
| 30 |
|
| 31 |
# CSS for formatting
|
|
|
|
| 166 |
|
| 167 |
|
| 168 |
# Function to generate text
|
| 169 |
+
@spaces.GPU
|
| 170 |
+
def ocr_correction(prompt, max_new_tokens=500):
|
| 171 |
+
|
| 172 |
prompt = f"""### Text ###\n{prompt}\n\n\n### Correction ###\n"""
|
| 173 |
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
|
| 174 |
|
|
|
|
| 176 |
torch.set_num_threads(num_threads)
|
| 177 |
|
| 178 |
# Generate text
|
| 179 |
+
output = model.generate,
|
|
|
|
|
|
|
| 180 |
input_ids,
|
| 181 |
max_new_tokens=max_new_tokens,
|
| 182 |
pad_token_id=tokenizer.eos_token_id,
|
|
|
|
| 185 |
do_sample=True,
|
| 186 |
temperature=0.7
|
| 187 |
)
|
|
|
|
|
|
|
| 188 |
# Decode and return the generated text
|
| 189 |
result = tokenizer.decode(output[0], skip_special_tokens=True)
|
| 190 |
print(result)
|