Antharee committed on
Commit
adfd1fe
·
verified ·
1 Parent(s): eaa18ed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -41
app.py CHANGED
@@ -1,42 +1,13 @@
1
- import os
2
- import torch
3
- import gradio as gr
4
  from PIL import Image
5
- from transformers import AutoProcessor, AutoModelForVision2Seq
6
-
7
- hf_token = os.getenv("HUGGINGFACE_TOKEN")
8
-
9
- token_args = {}
10
- if hf_token:
11
- token_args = {"use_auth_token": hf_token}
12
-
13
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
-
15
- processor = AutoProcessor.from_pretrained("scb10x/typhoon-ocr-3b", **token_args)
16
- model = AutoModelForVision2Seq.from_pretrained(
17
- "scb10x/typhoon-ocr-3b",
18
- torch_dtype=torch.float16,
19
- device_map="auto",
20
- **token_args
21
- )
22
-
23
- def ocr_infer(image):
24
- if image is None:
25
- return "❌ Error: No image provided"
26
- try:
27
- image = image.convert("RGB")
28
- inputs = processor(images=image, return_tensors="pt")
29
-
30
- if inputs is None or "pixel_values" not in inputs:
31
- return "❌ Error: Invalid processor output"
32
-
33
- inputs = {k: v.to(device) for k, v in inputs.items()}
34
- generated_ids = model.generate(**inputs, max_new_tokens=256)
35
- result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
36
- return result
37
-
38
- except Exception as e:
39
- return f"❌ Error during inference: {e}"
40
-
41
- iface = gr.Interface(fn=ocr_infer, inputs=gr.Image(type="pil"), outputs="text", title="Typhoon OCR 3B")
42
- iface.launch()
 
 
 
 
1
  from PIL import Image
2
+ import requests
3
+ from io import BytesIO
4
+
5
+ image_url = "https://raw.githubusercontent.com/scb10x/typhoon-ocr/main/assets/test-image.png"
6
+ response = requests.get(image_url)
7
+ image = Image.open(BytesIO(response.content)).convert("RGB")
8
+
9
+ inputs = processor(images=image, return_tensors="pt")
10
+ inputs = {k: v.to(device) for k, v in inputs.items()}
11
+ generated_ids = model.generate(**inputs, max_new_tokens=256)
12
+ result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
13
+ print(result)