elderprince committed on
Commit
d74a4be
·
0 Parent(s):

init commit

Browse files
Files changed (3) hide show
  1. .DS_Store +0 -0
  2. app.py +72 -0
  3. requirements.txt +66 -0
.DS_Store ADDED
Binary file (6.15 kB). View file
 
app.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ from transformers import DonutProcessor, VisionEncoderDecoderModel
4
+ import torch
5
+ import re
6
+
7
+ # Load your model
8
# Pick the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The processor and the encoder-decoder weights come from the same checkpoint.
processor = DonutProcessor.from_pretrained('elderprince/HeR-T')
model = VisionEncoderDecoderModel.from_pretrained('elderprince/HeR-T')

# Move the weights to the chosen device and switch to inference mode.
model.to(device)
model.eval()
13
+
14
+ # Convert texts to JSON
15
def convert_text_to_json(sequence):
    """Convert a decoded model output string into a JSON-like structure.

    Args:
        sequence: Decoded text produced by the model.

    Returns:
        Whatever ``processor.token2json`` returns for the cleaned text.
    """
    tokenizer = processor.tokenizer

    # Strip every end-of-sequence and padding marker from the text.
    cleaned = sequence.replace(tokenizer.eos_token, "")
    cleaned = cleaned.replace(tokenizer.pad_token, "")

    # Drop only the first <...> tag (the task start token), then trim
    # surrounding whitespace.
    cleaned = re.sub(r"<.*?>", "", cleaned, count=1).strip()

    return processor.token2json(cleaned)
23
+
24
+ # Preprocessing function
25
def preprocess(image):
    """Turn an input image into a pixel-value tensor on the model's device.

    Args:
        image: Either a ``PIL.Image.Image`` (what the Gradio input declared
            with ``type="pil"`` passes to the callback) or anything
            ``Image.open`` accepts, such as a file path or a binary file
            object.

    Returns:
        The ``pixel_values`` tensor produced by ``processor``, placed on
        ``device``.

    Raises:
        ValueError: If ``image`` is ``None`` (nothing was uploaded).
    """
    if image is None:
        raise ValueError("No image provided.")

    # An already-decoded PIL image must not be passed to Image.open, which
    # expects a path or file object; only open when we were given one.
    if not isinstance(image, Image.Image):
        image = Image.open(image)

    # Resize to the fixed (width, height) of 1200x1600 used by this app,
    # then force three-channel RGB.
    image = image.resize((1200, 1600)).convert("RGB")

    # Let the processor build the tensor input for the model.
    pixel_values = processor(images=image, return_tensors="pt").pixel_values

    # .to() returns the same tensor when it is already on the target device.
    return pixel_values.to(device)
34
+
35
+ # Prediction function
36
def predict(image):
    """Run the model on one image and return the parsed prediction.

    Args:
        image: The image passed through to ``preprocess``.

    Returns:
        The structure returned by ``convert_text_to_json`` for the
        generated sequence.
    """
    pixel_values = preprocess(image)

    # Seed the decoder with the task start token.
    task_prompt = "<s_herbarium>"
    decoder_input_ids = processor.tokenizer(
        task_prompt,
        add_special_tokens=False,
        return_tensors="pt",
    ).input_ids.to(device)

    with torch.no_grad():
        output = model.generate(
            pixel_values=pixel_values,
            decoder_input_ids=decoder_input_ids,
            # The length limit is the decoder's position capacity; the pad
            # token id is an unrelated vocabulary index, not a length.
            max_length=model.decoder.config.max_position_embeddings,
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
            use_cache=True,
            # Greedy decoding; early_stopping only applies to beam search,
            # so it is not passed here.
            num_beams=1,
            bad_words_ids=[[processor.tokenizer.unk_token_id]],
            # Needed so the result exposes `.sequences`; without it
            # generate() returns a bare tensor.
            return_dict_in_generate=True,
        )

    # Keep special tokens while decoding: convert_text_to_json removes the
    # eos/pad/task-start markers itself and parses the remaining tags.
    sequence = processor.batch_decode(output.sequences)[0]

    return convert_text_to_json(sequence)
61
+
62
+ # Gradio interface
63
# Single-image web UI: the uploaded image is handed to `predict` as a PIL
# image and the result is shown as text.
image_input = gr.Image(type="pil")

demo = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs="text",
    title="Herbarium specimen label Recognition Transformer (HeR-T) Demo",
    description="Upload a single-specimen image to see the model's output.",
)


# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1
2
+ annotated-types==0.7.0
3
+ anyio==4.9.0
4
+ audioop-lts==0.2.1
5
+ certifi==2025.1.31
6
+ charset-normalizer==3.4.1
7
+ click==8.1.8
8
+ fastapi==0.115.12
9
+ ffmpy==0.5.0
10
+ filelock==3.18.0
11
+ fsspec==2025.3.2
12
+ gradio==5.23.3
13
+ gradio_client==1.8.0
14
+ groovy==0.1.2
15
+ h11==0.14.0
16
+ httpcore==1.0.7
17
+ httpx==0.28.1
18
+ huggingface-hub==0.30.1
19
+ idna==3.10
20
+ Jinja2==3.1.6
21
+ markdown-it-py==3.0.0
22
+ MarkupSafe==3.0.2
23
+ mdurl==0.1.2
24
+ mpmath==1.3.0
25
+ networkx==3.4.2
26
+ numpy==2.2.4
27
+ orjson==3.10.16
28
+ packaging==24.2
29
+ pandas==2.2.3
30
+ pillow==11.1.0
31
+ pydantic==2.11.1
32
+ pydantic_core==2.33.0
33
+ pydub==0.25.1
34
+ Pygments==2.19.1
35
+ python-dateutil==2.9.0.post0
36
+ python-multipart==0.0.20
37
+ pytz==2025.2
38
+ PyYAML==6.0.2
39
+ regex==2024.11.6
40
+ requests==2.32.3
41
+ rich==14.0.0
42
+ ruff==0.11.2
43
+ safehttpx==0.1.6
44
+ safetensors==0.5.3
45
+ semantic-version==2.10.0
46
+ setuptools==75.8.0
47
+ shellingham==1.5.4
48
+ six==1.17.0
49
+ sniffio==1.3.1
50
+ starlette==0.46.1
51
+ sympy==1.13.1
52
+ tokenizers==0.21.1
53
+ tomlkit==0.13.2
54
+ torch==2.6.0
55
+ torchaudio==2.6.0
56
+ torchvision==0.21.0
57
+ tqdm==4.67.1
58
+ transformers==4.50.3
59
+ typer==0.15.2
60
+ typing-inspection==0.4.0
61
+ typing_extensions==4.13.0
62
+ tzdata==2025.2
63
+ urllib3==2.3.0
64
+ uvicorn==0.34.0
65
+ websockets==15.0.1
66
+ wheel==0.45.1