Sandhya commited on
Commit
8b252d4
·
1 Parent(s): 136a3f5

First commit

Browse files
Files changed (2) hide show
  1. app.py +40 -0
  2. requirments.txt +108 -0
app.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torchvision
2
+ import torch
3
+ from torch import nn
4
+ from PIL import Image
5
+ from torchvision import transforms
6
+ import onnxruntime
7
+ import numpy as np
8
+ import torch.nn.functional as F
9
+ from safetensors.torch import save_file,load_file,safe_open
10
+ import numpy as np
11
+
12
+ def predict(img_path,model=None):
13
+ if model is None:
14
+ pretrained_weights_resnet18=torchvision.models.ResNet18_Weights.DEFAULT
15
+ model=torchvision.models.resnet18(weights=pretrained_weights_resnet18)
16
+ class_names=pretrained_weights_resnet18.meta["categories"]
17
+ transform=transforms.Compose([transforms.Resize((64,64)),transforms.ToTensor(),transforms.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])])
18
+ if isinstance(img_path,np.ndarray):
19
+ image=Image.fromarray(img_path).convert("RGB")
20
+ else:
21
+ image=Image.open(img_path).convert("RGB")
22
+ img_transform=transform(image).unsqueeze(0)
23
+
24
+ model.eval()
25
+ with torch.inference_mode():
26
+ logit=model(img_transform)
27
+ pred_prob=torch.softmax(logit,dim=1).squeeze().numpy()
28
+ predict_dict={}
29
+ for i in range(len(class_names)):
30
+ predict_dict[class_names[i]]=float(pred_prob[i])
31
+
32
+ return predict_dict
33
+
34
+
35
+ import numpy as np
36
+ import gradio as gr
37
+
38
+ demo = gr.Interface(predict, gr.Image(), outputs=gr.Label(num_top_classes=3))
39
+ if __name__ == "__main__":
40
+ demo.launch()
requirments.txt ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==24.1.0
2
+ annotated-types==0.7.0
3
+ anyio==4.9.0
4
+ asttokens==3.0.0
5
+ certifi==2025.1.31
6
+ charset-normalizer==3.4.1
7
+ click==8.1.8
8
+ colorama==0.4.6
9
+ coloredlogs==15.0.1
10
+ comm==0.2.2
11
+ contourpy==1.3.2
12
+ cycler==0.12.1
13
+ debugpy==1.8.14
14
+ decorator==5.2.1
15
+ dnspython==2.7.0
16
+ email_validator==2.2.0
17
+ executing==2.2.0
18
+ fastapi==0.115.12
19
+ fastapi-cli==0.0.7
20
+ ffmpy==0.5.0
21
+ filelock==3.18.0
22
+ flatbuffers==25.2.10
23
+ fonttools==4.57.0
24
+ fsspec==2025.3.2
25
+ gradio==5.25.2
26
+ gradio_client==1.8.0
27
+ groovy==0.1.2
28
+ h11==0.14.0
29
+ httpcore==1.0.7
30
+ httptools==0.6.4
31
+ httpx==0.28.1
32
+ huggingface-hub==0.30.2
33
+ humanfriendly==10.0
34
+ idna==3.10
35
+ ipykernel==6.29.5
36
+ ipython==9.1.0
37
+ ipython_pygments_lexers==1.1.1
38
+ jedi==0.19.2
39
+ Jinja2==3.1.6
40
+ jupyter_client==8.6.3
41
+ jupyter_core==5.7.2
42
+ kiwisolver==1.4.8
43
+ mangum==0.19.0
44
+ markdown-it-py==3.0.0
45
+ MarkupSafe==3.0.2
46
+ matplotlib==3.10.1
47
+ matplotlib-inline==0.1.7
48
+ mdurl==0.1.2
49
+ mpmath==1.3.0
50
+ nest-asyncio==1.6.0
51
+ networkx==3.4.2
52
+ numpy==2.2.4
53
+ onnxruntime==1.21.1
54
+ orjson==3.10.16
55
+ packaging==24.2
56
+ pandas==2.2.3
57
+ parso==0.8.4
58
+ pillow==11.2.1
59
+ platformdirs==4.3.7
60
+ prompt_toolkit==3.0.51
61
+ protobuf==6.30.2
62
+ psutil==7.0.0
63
+ pure_eval==0.2.3
64
+ pydantic==2.11.3
65
+ pydantic_core==2.33.1
66
+ pydub==0.25.1
67
+ Pygments==2.19.1
68
+ pyparsing==3.2.3
69
+ pyreadline3==3.5.4
70
+ python-dateutil==2.9.0.post0
71
+ python-dotenv==1.1.0
72
+ python-multipart==0.0.20
73
+ pytz==2025.2
74
+ pywin32==310
75
+ PyYAML==6.0.2
76
+ pyzmq==26.4.0
77
+ regex==2024.11.6
78
+ requests==2.32.3
79
+ rich==14.0.0
80
+ rich-toolkit==0.14.1
81
+ ruff==0.11.6
82
+ safehttpx==0.1.6
83
+ safetensors==0.5.3
84
+ semantic-version==2.10.0
85
+ shellingham==1.5.4
86
+ six==1.17.0
87
+ sniffio==1.3.1
88
+ stack-data==0.6.3
89
+ starlette==0.46.1
90
+ sympy==1.13.1
91
+ tokenizers==0.21.1
92
+ tomlkit==0.13.2
93
+ torch==2.6.0
94
+ torchinfo==1.8.0
95
+ torchvision==0.21.0
96
+ tornado==6.4.2
97
+ tqdm==4.67.1
98
+ traitlets==5.14.3
99
+ transformers==4.51.3
100
+ typer==0.15.2
101
+ typing-inspection==0.4.0
102
+ typing_extensions==4.13.2
103
+ tzdata==2025.2
104
+ urllib3==2.4.0
105
+ uvicorn==0.34.0
106
+ watchfiles==1.0.5
107
+ wcwidth==0.2.13
108
+ websockets==15.0.1