Spaces:
Running on Zero
Running on Zero
update app
Browse files
app.py
CHANGED
|
@@ -222,67 +222,112 @@ def calc_timeout_image(model_name, text, image, max_new_tokens, temperature, top
|
|
| 222 |
|
| 223 |
@spaces.GPU(duration=calc_timeout_image)
|
| 224 |
def generate_image(model_name, text, image, max_new_tokens=1024, temperature=0.6, top_p=0.9, top_k=50, repetition_penalty=1.2, gpu_timeout=60):
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
"
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
|
|
|
| 261 |
|
| 262 |
-
|
| 263 |
-
thread.start()
|
| 264 |
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 280 |
else:
|
| 281 |
-
|
|
|
|
| 282 |
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
|
|
|
|
|
|
|
|
|
| 286 |
|
| 287 |
|
| 288 |
def noop():
|
|
@@ -669,8 +714,16 @@ function init() {
|
|
| 669 |
const sb = document.getElementById('sb-run-state');
|
| 670 |
if (sb) sb.textContent = 'Done';
|
| 671 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 672 |
window.__showLoader = showLoader;
|
| 673 |
window.__hideLoader = hideLoader;
|
|
|
|
| 674 |
|
| 675 |
function flashPromptError() {
|
| 676 |
promptInput.classList.add('error-flash');
|
|
@@ -845,7 +898,12 @@ function init() {
|
|
| 845 |
showLoader();
|
| 846 |
setTimeout(() => {
|
| 847 |
const gradioBtn = document.getElementById('gradio-run-btn');
|
| 848 |
-
if (!gradioBtn)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 849 |
const btn = gradioBtn.querySelector('button');
|
| 850 |
if (btn) btn.click(); else gradioBtn.click();
|
| 851 |
}, 180);
|
|
@@ -961,6 +1019,10 @@ function watchOutputs() {
|
|
| 961 |
|
| 962 |
let lastText = '';
|
| 963 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 964 |
function syncOutput() {
|
| 965 |
const el = resultContainer.querySelector('textarea') || resultContainer.querySelector('input');
|
| 966 |
if (!el) return;
|
|
@@ -969,7 +1031,15 @@ function watchOutputs() {
|
|
| 969 |
lastText = val;
|
| 970 |
outArea.value = val;
|
| 971 |
outArea.scrollTop = outArea.scrollHeight;
|
| 972 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 973 |
}
|
| 974 |
}
|
| 975 |
|
|
@@ -1178,18 +1248,21 @@ with gr.Blocks() as demo:
|
|
| 1178 |
return None
|
| 1179 |
|
| 1180 |
def run_ocr(model_name, text, image_b64, max_new_tokens_v, temperature_v, top_p_v, top_k_v, repetition_penalty_v, gpu_timeout_v):
|
| 1181 |
-
|
| 1182 |
-
|
| 1183 |
-
|
| 1184 |
-
|
| 1185 |
-
|
| 1186 |
-
|
| 1187 |
-
|
| 1188 |
-
|
| 1189 |
-
|
| 1190 |
-
|
| 1191 |
-
|
| 1192 |
-
|
|
|
|
|
|
|
|
|
|
| 1193 |
|
| 1194 |
demo.load(fn=noop, inputs=None, outputs=None, js=gallery_js)
|
| 1195 |
demo.load(fn=noop, inputs=None, outputs=None, js=wire_outputs_js)
|
|
|
|
| 222 |
|
| 223 |
@spaces.GPU(duration=calc_timeout_image)
def generate_image(model_name, text, image, max_new_tokens=1024, temperature=0.6, top_p=0.9, top_k=50, repetition_penalty=1.2, gpu_timeout=60):
    """Stream a model response for one image plus one text instruction.

    This is a generator. While generation runs it yields the cumulative
    decoded text after every streamed chunk. Failures are never raised to the
    caller; they are reported by yielding a string carrying an "[ERROR]"
    message.

    Args:
        model_name: Key into MODEL_MAP selecting a (processor, model) pair.
        text: Instruction or query sent with the image. Rejected when empty,
            whitespace-only, or longer than MAX_INPUT_TOKEN_LENGTH * 8
            characters.
        image: Input image. None is rejected with an error message.
        max_new_tokens: Passed to model.generate after int() coercion.
        temperature: Passed to model.generate after float() coercion.
        top_p: Passed to model.generate after float() coercion.
        top_k: Passed to model.generate after int() coercion.
        repetition_penalty: Passed to model.generate after float() coercion.
        gpu_timeout: Not read inside this body; presumably consumed by
            calc_timeout_image through the decorator -- TODO confirm.

    Yields:
        str: Cumulative output text, a final Markdown export for SmolDocling
        doctag output, or an "[ERROR] ..." message.
    """
    streamed = ""
    try:
        # --- Input validation: each failure is reported, then we stop. ---
        if not model_name or model_name not in MODEL_MAP:
            yield "[ERROR] Please select a valid model."
            return
        if image is None:
            yield "[ERROR] Please upload an image."
            return
        if not text or not str(text).strip():
            yield "[ERROR] Please enter your OCR/query instruction."
            return
        if len(str(text)) > MAX_INPUT_TOKEN_LENGTH * 8:
            yield "[ERROR] Query is too long. Please shorten your input."
            return

        processor, model = MODEL_MAP[model_name]
        images = [image]
        is_smoldocling = model_name == "SmolDocling-256M-preview"

        # SmolDocling-specific preprocessing, keyed off phrases in the prompt.
        if is_smoldocling:
            if "OTSL" in text or "code" in text:
                images = [add_random_padding(img) for img in images]
            if "OCR at text at" in text or "Identify element" in text or "formula" in text:
                text = normalize_values(text, target_max=500)

        # One image placeholder per image, followed by the text instruction.
        content = [{"type": "image"} for _ in images]
        content.append({"type": "text", "text": text})
        messages = [{"role": "user", "content": content}]

        prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
        inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)

        # Use the processor's tokenizer when it has one, else the processor.
        streamer = TextIteratorStreamer(
            getattr(processor, "tokenizer", processor),
            skip_prompt=True,
            skip_special_tokens=True,
        )

        # Mutable holder so the worker thread can hand an exception back.
        failure = {"error": None}

        generation_kwargs = {
            **inputs,
            "streamer": streamer,
            "max_new_tokens": int(max_new_tokens),
            "temperature": float(temperature),
            "top_p": float(top_p),
            "top_k": int(top_k),
            "repetition_penalty": float(repetition_penalty),
        }

        def _run_generation():
            # Runs in the worker thread; records any failure for the consumer.
            try:
                model.generate(**generation_kwargs)
            except Exception as exc:
                failure["error"] = exc
                # Best effort: signal end-of-stream so the loop below is not
                # left waiting on a streamer that will never be fed again.
                try:
                    streamer.end()
                except Exception:
                    pass

        worker = Thread(target=_run_generation, daemon=True)
        worker.start()

        for chunk in streamer:
            streamed += chunk.replace("<|im_end|>", "")
            yield streamed

        worker.join(timeout=1.0)

        error = failure["error"]
        if error is not None:
            err_msg = f"[ERROR] Inference failed: {str(error)}"
            # Keep whatever was already streamed and append the error to it.
            yield (streamed + "\n\n" + err_msg) if streamed.strip() else err_msg
            return

        # Every other model: the streamed text is the final answer.
        if not is_smoldocling:
            if not streamed.strip():
                yield "[ERROR] No output was generated."
            return

        cleaned = streamed.replace("<end_of_utterance>", "").strip()
        doctag_markers = ("<doctag>", "<otsl>", "<code>", "<chart>", "<formula>")
        if not any(tag in cleaned for tag in doctag_markers):
            yield cleaned if cleaned else "[ERROR] No output was generated."
            return

        # Doctag output: convert it to Markdown through the docling document.
        try:
            if "<chart>" in cleaned:
                cleaned = cleaned.replace("<chart>", "<otsl>").replace("</chart>", "</otsl>")
                cleaned = re.sub(r'(<loc_500>)(?!.*<loc_500>)<[^>]+>', r'\1', cleaned)
            doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([cleaned], images)
            doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
            yield doc.export_to_markdown()
        except Exception as exc:
            yield f"[ERROR] Post-processing failed: {str(exc)}"

    except Exception as exc:
        yield f"[ERROR] {str(exc)}"
    finally:
        # Always release Python garbage and cached CUDA memory afterwards.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
|
| 331 |
|
| 332 |
|
| 333 |
def noop():
|
|
|
|
| 714 |
const sb = document.getElementById('sb-run-state');
|
| 715 |
if (sb) sb.textContent = 'Done';
|
| 716 |
}
|
| 717 |
+
/**
 * Put the run UI into its failed state: remove the 'active' class from the
 * output loader and set the run-state label to 'Error'. Elements that are
 * not present in the DOM are skipped.
 */
function setRunErrorState() {
    const loaderEl = document.getElementById('output-loader');
    const statusEl = document.getElementById('sb-run-state');

    if (loaderEl) {
        loaderEl.classList.remove('active');
    }
    if (statusEl) {
        statusEl.textContent = 'Error';
    }
}
|
| 723 |
+
|
| 724 |
window.__showLoader = showLoader;
|
| 725 |
window.__hideLoader = hideLoader;
|
| 726 |
+
window.__setRunErrorState = setRunErrorState;
|
| 727 |
|
| 728 |
function flashPromptError() {
|
| 729 |
promptInput.classList.add('error-flash');
|
|
|
|
| 898 |
showLoader();
|
| 899 |
setTimeout(() => {
|
| 900 |
const gradioBtn = document.getElementById('gradio-run-btn');
|
| 901 |
+
if (!gradioBtn) {
|
| 902 |
+
setRunErrorState();
|
| 903 |
+
if (outputArea) outputArea.value = '[ERROR] Run button not found.';
|
| 904 |
+
showToast('Run button not found', 'error');
|
| 905 |
+
return;
|
| 906 |
+
}
|
| 907 |
const btn = gradioBtn.querySelector('button');
|
| 908 |
if (btn) btn.click(); else gradioBtn.click();
|
| 909 |
}, 180);
|
|
|
|
| 1019 |
|
| 1020 |
let lastText = '';
|
| 1021 |
|
| 1022 |
+
/**
 * Decide whether an output value reports a backend error.
 *
 * The backend prefixes plain failures with '[ERROR]'. When inference fails
 * after some text was already streamed, it instead appends
 * '[ERROR] Inference failed: ...' to that partial text, so the marker can
 * also appear after the start of the value.
 *
 * @param {*} val - Candidate output value; non-strings are never errors.
 * @returns {boolean} True when the text carries a backend error message.
 */
function isErrorText(val) {
    if (typeof val !== 'string') return false;
    const text = val.trim();
    if (text.startsWith('[ERROR]')) return true;
    // Partial output followed by an appended inference failure message.
    return text.includes('[ERROR] Inference failed:');
}
|
| 1025 |
+
|
| 1026 |
function syncOutput() {
|
| 1027 |
const el = resultContainer.querySelector('textarea') || resultContainer.querySelector('input');
|
| 1028 |
if (!el) return;
|
|
|
|
| 1031 |
lastText = val;
|
| 1032 |
outArea.value = val;
|
| 1033 |
outArea.scrollTop = outArea.scrollHeight;
|
| 1034 |
+
|
| 1035 |
+
if (val.trim()) {
|
| 1036 |
+
if (isErrorText(val)) {
|
| 1037 |
+
if (window.__setRunErrorState) window.__setRunErrorState();
|
| 1038 |
+
if (window.__showToast) window.__showToast('OCR failed', 'error');
|
| 1039 |
+
} else {
|
| 1040 |
+
if (window.__hideLoader) window.__hideLoader();
|
| 1041 |
+
}
|
| 1042 |
+
}
|
| 1043 |
}
|
| 1044 |
}
|
| 1045 |
|
|
|
|
| 1248 |
return None
|
| 1249 |
|
| 1250 |
def run_ocr(model_name, text, image_b64, max_new_tokens_v, temperature_v, top_p_v, top_k_v, repetition_penalty_v, gpu_timeout_v):
    """Decode the base64 image and relay the output of generate_image.

    This is a generator. Every value yielded by generate_image is passed
    through unchanged. Any exception raised while decoding the image or while
    streaming is reported as a single "[ERROR] ..." string instead of being
    raised.

    Args:
        model_name: Model identifier forwarded to generate_image.
        text: Instruction or query forwarded to generate_image.
        image_b64: Image payload handed to b64_to_pil for decoding.
        max_new_tokens_v: Forwarded as max_new_tokens.
        temperature_v: Forwarded as temperature.
        top_p_v: Forwarded as top_p.
        top_k_v: Forwarded as top_k.
        repetition_penalty_v: Forwarded as repetition_penalty.
        gpu_timeout_v: Forwarded as gpu_timeout.

    Yields:
        str: Streamed output text, or an "[ERROR] ..." message.
    """
    try:
        decoded_image = b64_to_pil(image_b64)
        generation_args = {
            "model_name": model_name,
            "text": text,
            "image": decoded_image,
            "max_new_tokens": max_new_tokens_v,
            "temperature": temperature_v,
            "top_p": top_p_v,
            "top_k": top_k_v,
            "repetition_penalty": repetition_penalty_v,
            "gpu_timeout": gpu_timeout_v,
        }
        yield from generate_image(**generation_args)
    except Exception as exc:
        yield f"[ERROR] {str(exc)}"
|
| 1266 |
|
| 1267 |
demo.load(fn=noop, inputs=None, outputs=None, js=gallery_js)
|
| 1268 |
demo.load(fn=noop, inputs=None, outputs=None, js=wire_outputs_js)
|