# "Spaces: Sleeping" — Hugging Face Spaces status banner captured by the page
# scrape, not part of the program; commented out so the file parses as Python.
# Standard library
from os import path  # NOTE(review): shadowed by `path = Path(...)` below — consider removing or renaming
from pathlib import Path
from urllib.error import HTTPError, URLError

# Third-party
import fastai
from fastai.vision.all import *
import timm
from PIL import Image
from tqdm.auto import tqdm
def search_images_ddg(term, max_images=200):
    """Search DuckDuckGo Images for `term` and return up to `max_images` unique URLs.

    Parameters
    ----------
    term : str
        The search query.
    max_images : int
        Maximum number of image URLs to return; must be < 1000.

    Returns
    -------
    L
        fastcore list of unique image URLs (order unspecified: collected in a set).

    Raises
    ------
    ValueError
        If `max_images` is not < 1000.
    RuntimeError
        If the DuckDuckGo `vqd` token cannot be extracted from the landing page.
    """
    # `assert` is stripped under `python -O`; validate input with a real exception.
    if max_images >= 1000:
        raise ValueError("max_images must be < 1000")
    url = 'https://duckduckgo.com/'
    # First request fetches the landing page to extract the `vqd` token that
    # DuckDuckGo's (unofficial) image endpoint requires.
    res = urlread(url, data={'q': term})
    searchObj = re.search(r'vqd=([\d-]+)\&', res)
    if not searchObj:
        raise RuntimeError("could not extract vqd token from DuckDuckGo response")
    requestUrl = url + 'i.js'
    params = dict(l='us-en', o='json', q=term, vqd=searchObj.group(1), f=',,,', p='1', v7exp='a')
    urls, data = set(), {'next': 1}
    headers = dict(referer='https://duckduckgo.com/')
    failures = 0
    while len(urls) < max_images and 'next' in data:
        try:
            res = urlread(requestUrl, data=params, headers=headers)
            data = json.loads(res) if res else {}
            # `results`/`next` may be absent on the final page; the original
            # raised an uncaught KeyError here — use .get and stop cleanly.
            urls.update(L(data.get('results', [])).itemgot('image'))
            if 'next' not in data:
                break
            requestUrl = url + data['next']
        except (URLError, HTTPError):
            # Best-effort scrape: transient network errors are retried, but the
            # retry count is bounded so a dead endpoint cannot spin forever
            # (the original looped indefinitely because `data` stayed unchanged).
            failures += 1
            if failures >= 10:
                break
        time.sleep(1)  # be polite to the endpoint between pages
    return L(urls)[:max_images]
# Class labels: one folder (and thus one category, via parent_label) per name.
tool_names = "resistor", "bipolar transistor", "mosfet", "capacitor", "inductor", "wire", "led", "diode", "thermistor", "switch", "battery", "hammer", "screwdriver", "scissors", "wrench", "mallet", "axe"

# NOTE(review): `path` shadows the `from os import path` import at the top of
# the file; kept as-is because later statements reference this name.
path = Path("data", "tools")
# Idempotent creation replaces the original exists()-then-mkdir check (racy and
# more verbose); the discarded `path.absolute()` call was a no-op and is removed.
path.mkdir(parents=True, exist_ok=True)

# Download ~20 candidate images per class into its own subfolder.
for o in tqdm(tool_names):
    dest = path / o
    dest.mkdir(exist_ok=True)
    results = search_images_ddg(o, max_images=20)  # `o` is already a str; f-string was redundant
    download_images(dest, urls=results, n_workers=2)

# Remove any files that failed to download or cannot be opened as images, so
# they don't poison the DataLoaders later.
fns = get_image_files(path)
failed = verify_images(fns)
failed.map(Path.unlink)
# DataBlock template: images labelled by parent folder name, default 80/20
# random split with a fixed seed for reproducibility, presized to 224px.
data_config = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    splitter=RandomSplitter(seed=42),
    get_y=parent_label,
    item_tfms=Resize(224),
)

# Derive an augmented variant: random crops (>=50% of the image) plus the
# standard fastai batch augmentations.
connectors = data_config.new(
    item_tfms=RandomResizedCrop(224, min_scale=0.5),
    batch_tfms=aug_transforms(),
)
# NOTE(review): the original built `data_config.dataloaders(path)` first and
# immediately overwrote it with this one — that redundant (and slow, since it
# touches every image) construction is removed.
dls = connectors.dataloaders(path)

# Fine-tune a timm ConvNeXt backbone: 1 epoch with the body frozen, then 7
# epochs fully unfrozen; track error rate on the validation split.
learn = vision_learner(dls, 'convnext_small.fb_in22k_ft_in1k', metrics=error_rate)
learn.fine_tune(7, freeze_epochs=1)

# Export relative to the current directory so the model lands in ./export.pkl
# (instead of inside the data folder the Learner would otherwise use).
learn.path = Path('.')
learn.export()