File size: 1,787 Bytes
2cb0e9a
b844571
 
4d642ec
 
 
 
2cb0e9a
4d642ec
b844571
 
 
 
 
4d642ec
 
 
b844571
4d642ec
 
 
b844571
4d642ec
 
 
b844571
4d642ec
 
b844571
4d642ec
b844571
 
 
4d642ec
b844571
 
 
 
 
 
 
 
4d642ec
b844571
 
4d642ec
b844571
 
 
4d642ec
 
 
 
b844571
4d642ec
 
2cb0e9a
4d642ec
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import gradio as gr
from duckduckgo_search import DDGS
from fastcore.all import *
from fastdownload import download_url
from fastai.vision.all import *
from PIL import Image
from pathlib import Path

# Shared DuckDuckGo client used for every image search in this script.
ddgs = DDGS()


def search_images(term, max_images=30):
    """Search DuckDuckGo images for *term* and return the result image URLs.

    Args:
        term: the search query.
        max_images: maximum number of results to request.

    Returns:
        A fastcore ``L`` of the ``'image'`` field of each search result.
    """
    print(f"Searching for '{term}'")
    results = ddgs.images(term, max_results=max_images)
    return L(results).itemgot('image')

# Categories to classify. Each one gets its own sub-folder under `path`, because
# the DataBlock below labels every image with `parent_label` (its folder name).
searches = 'beaver', 'platypus'

# Root of the training set: images/<category>/<file>.
path = Path('images')
path.mkdir(exist_ok=True)

# Example images are stored OUTSIDE `path`: `get_image_files(path)` recurses,
# so any file placed directly in images/ would be labelled "images" and become
# a bogus third class.
examples = Path('examples')
examples.mkdir(exist_ok=True)

# Download example image: beaver
beaver_url = search_images('beaver photo', max_images=1)[0]
download_url(beaver_url, examples/'beaver.jpg', show_progress=False)

# Download another example image: platypus
platypus_url = search_images('platypus photo', max_images=1)[0]
download_url(platypus_url, examples/'platypus.jpg', show_progress=False)

# Build a 256x256-bounded thumbnail of the platypus image. `thumbnail` works in
# place and returns None, so keep the image object; nothing is displayed when
# this runs as a plain script.
thumb = Image.open(examples/'platypus.jpg')
thumb.thumbnail((256, 256))

# Download a batch of training images per category into its own sub-folder,
# then shrink them in place so training does not decode full-size photos.
for term in searches:
    dest = path/term
    dest.mkdir(exist_ok=True, parents=True)
    download_images(dest, urls=search_images(f'{term} photo'))
    resize_images(dest, max_size=400, dest=dest)

# Remove any images that failed to download or cannot be opened
failed = verify_images(get_image_files(path))
failed.map(Path.unlink)

# Assemble the DataLoaders. Each image's label is the name of the folder it sits
# in, so `images/` must contain one sub-folder per category. 20% of the files
# (fixed seed 42) are held out for validation, and every image is squished to 192px.
species_block = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    get_y=parent_label,
    splitter=RandomSplitter(valid_pct=0.2, seed=42),
    item_tfms=[Resize(192, method='squish')],
)
dls = species_block.dataloaders(path, bs=32)

# Plot up to six samples; this may not appear when run as a plain script.
dls.show_batch(max_n=6)

# Fine-tune a resnet18-based learner for 3 epochs, tracking error_rate.
learn = vision_learner(dls, resnet18, metrics=error_rate)
learn.fine_tune(3)

# Prediction function
def predict_species(img):
    """Classify *img* and report the predicted label plus the beaver probability.

    Args:
        img: the image supplied by the Gradio ``gr.Image(type="pil")`` input.

    Returns:
        A human-readable string with the predicted class and P(beaver).

    Raises:
        ValueError: if the trained model has no ``'beaver'`` class.
    """
    label, _, probs = learn.predict(img)
    # Find beaver's column through the vocab instead of assuming index 0:
    # the order of `probs` follows `learn.dls.vocab`, which depends on the
    # labels seen during training.
    vocab = learn.dls.vocab
    if 'beaver' not in vocab.o2i:
        raise ValueError(f"model has no 'beaver' class; vocab is {list(vocab)}")
    beaver_prob = probs[vocab.o2i['beaver']]
    return f"This looks like a: {label}. Probability it's a beaver: {beaver_prob:.4f}"

# Web UI: upload an image (handed to the function as a PIL image) and get
# predict_species' text description back.
demo = gr.Interface(
    fn=predict_species,
    inputs=gr.Image(type="pil"),
    outputs="text",
)
demo.launch()