# Copied from a Hugging Face Spaces page (status shown at capture time: Sleeping).
| from transformers import ViTForImageClassification, ViTFeatureExtractor | |
| from PIL import Image | |
| import torch | |
# Load the Pokémon ViT classifier and its preprocessing pipeline once, at import
# time, placing the model on the GPU when one is available.
_MODEL_ID = "imjeffhi/pokemon_classifier"

device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

model = ViTForImageClassification.from_pretrained(_MODEL_ID)
model = model.to(device)
feature_extractor = ViTFeatureExtractor.from_pretrained(_MODEL_ID)
def predicted_pokemon(img):
    """Classify a Pokémon image with the module-level ViT model.

    Args:
        img: PIL image from the uploader component (``type='pil'``), or None
            when the user clicks "Classify" without uploading anything.

    Returns:
        ``(label_id, label_name)``: the argmax class index and its name from
        ``model.config.id2label``.  Returns ``(None, "")`` when ``img`` is
        None, which ``save_result`` treats as "classify first".
    """
    if img is None:
        return None, ""
    # The preprocessor normalises with 3-channel statistics; RGBA / L / P
    # images (e.g. transparent PNGs) must be converted first.
    if img.mode != 'RGB':
        img = img.convert('RGB')
    inputs = feature_extractor(images=img, return_tensors='pt').to(device)
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = model(**inputs).logits
    label_id = logits.argmax(-1).item()
    return label_id, model.config.id2label[label_id]
| """## Platform code | |
| """ | |
| import os | |
| import gradio as gr | |
| import json | |
# On-disk layout used by the app:
#   my_project/images/<username>/...   saved uploads
#   my_project/results/<username>.json per-user classification records
#   my_project/users.json              registered accounts
BASE_DIR = 'my_project'
IMG_DIR = os.path.join(BASE_DIR, 'images')
JSON_DATA_DIR = os.path.join(BASE_DIR, 'results')
JSON_USERS_PATH = os.path.join(BASE_DIR, 'users.json')

for _storage_dir in (BASE_DIR, IMG_DIR, JSON_DATA_DIR):
    os.makedirs(_storage_dir, exist_ok=True)
del _storage_dir
| """### main codes""" | |
def show_page(page_name, username):
    """Switch the visible page in the content area.

    Args:
        page_name: one of 'classifier', 'archive' or 'login'.
        username: current username_sb value; only used for 'archive'.

    Returns:
        For 'classifier' and 'login', three gr.update objects (classifier,
        archive, login visibility).  For 'archive', those three followed by
        the (archive_msg, archive_df) updates from ``show_archive``.  Any
        other page name yields None.
    """
    def _visibility(classifier_on, archive_on, login_on):
        # Fresh update objects each call, one per page container.
        return (gr.update(visible=classifier_on),
                gr.update(visible=archive_on),
                gr.update(visible=login_on))

    if page_name == 'classifier':
        return _visibility(True, False, False)
    if page_name == 'archive':
        msg_update, table_update = show_archive(username)
        return _visibility(False, True, False) + (msg_update, table_update)
    if page_name == 'login':
        return _visibility(False, False, True)
    return None
def save_result(img: Image.Image, predicted_id, predicted_pokemon, username):
    """Persist an uploaded image and its classification for the signed-in user.

    The image is written to ``IMG_DIR/<username>/<timestamp>.jpg`` and a record
    ``{timestamp, image_path, predicted_id, predicted_pokemon}`` is appended to
    ``JSON_DATA_DIR/<username>.json`` (image_path is relative to BASE_DIR).

    Args:
        img: uploaded PIL image, or None if nothing was uploaded.
        predicted_id: class index from the ID Number component.
        predicted_pokemon: class name from the Pokémon Textbox.
        username: username_sb value; 'user' or '' means nobody is signed in.
            NOTE(review): after sign-in this is the Markdown text
            ``**NAME**``, so the asterisks end up in file/dir names.

    Returns:
        A gr.update for the save_status Textbox.  It is always made visible,
        because the component starts hidden and error messages must be shown too.
    """
    # Local import: the module-level datetime import appears later in the file.
    from datetime import datetime

    def _status(message):
        return gr.update(value=message, visible=True)

    # Safety checks
    if img is None:
        return _status("Error: no image uploaded!")
    if predicted_id is None or predicted_id == "":
        return _status("Error: classify first!")
    if predicted_pokemon is None or predicted_pokemon == "":
        return _status("Error: classify first!")
    if username == 'user' or username == '':
        return _status("Error: sign-in first")

    now = datetime.now()
    # Microseconds keep two saves within the same second from overwriting
    # each other's image file.
    timestamp_file = now.strftime('%Y-%m-%d-%H-%M-%S-%f')
    timestamp_json = now.isoformat()

    # Save the image. JPEG cannot store alpha/palette modes (e.g. RGBA or P
    # PNG uploads); Pillow raises OSError for those, so convert first.
    img_path = os.path.join(IMG_DIR, username, f'{timestamp_file}.jpg')
    os.makedirs(os.path.dirname(img_path), exist_ok=True)
    if img.mode not in ('RGB', 'L'):
        img = img.convert('RGB')
    img.save(img_path)

    # Read the existing records (if any) and append the new one.
    json_path = os.path.join(JSON_DATA_DIR, f'{username}.json')
    try:
        with open(json_path, "r") as f:
            results = json.load(f)
    except FileNotFoundError:
        results = []
    results.append({
        'timestamp': timestamp_json,
        'image_path': os.path.relpath(img_path, BASE_DIR),
        'predicted_id': int(predicted_id),
        'predicted_pokemon': predicted_pokemon,
    })
    with open(json_path, "w") as f:
        json.dump(results, f, indent=2)
    return _status("Results saved successfully!")
# Load the user's archive
def show_archive(username):
    """Load the saved classification records for *username*.

    Reads ``JSON_DATA_DIR/<username>.json``.  A missing file or an empty
    record list shows the "Nessun record salvato." message and hides the table.

    Returns:
        ``(archive_msg update, archive_df update)``.  When records exist, the
        message is hidden and the table shows one row per record:
        ``[timestamp, predicted_id, predicted_pokemon, image_path]``.
    """
    json_path = os.path.join(JSON_DATA_DIR, f'{username}.json')
    archive_data = []
    if os.path.exists(json_path):
        with open(json_path, "r") as f:
            archive_data = json.load(f)

    if not archive_data:
        return gr.update(value="Nessun record salvato.", visible=True), gr.update(visible=False)

    table_rows = [
        [entry["timestamp"], entry["predicted_id"], entry["predicted_pokemon"], entry["image_path"]]
        for entry in archive_data
    ]
    return gr.update(value="", visible=False), gr.update(value=table_rows, visible=True)
def show_image(evt: gr.SelectData):
    """Preview the image belonging to the clicked archive row.

    Gradio injects the select event because of the ``gr.SelectData`` type
    hint, so the annotation must stay.  Column 3 of the row holds the image
    path relative to BASE_DIR.

    Returns:
        Updates for (selected_image, image_path, delete_btn).  The preview
        shows the file, the hidden textbox keeps the relative path for
        deletion, and the delete button is revealed.
    """
    relative_path = evt.row_value[3]
    full_path = os.path.join(BASE_DIR, relative_path)
    return (
        gr.update(value=full_path, visible=True),
        gr.update(value=relative_path),
        gr.update(visible=True),
    )
def delete_record(table_data, img_path, username):
    """Delete one archived record (image file and JSON entry) for *username*.

    Args:
        table_data: current archive_df value; returned unchanged when nothing
            is selected.
        img_path: image path relative to BASE_DIR, matching a record's
            "image_path".  It comes from the hidden image_path Textbox, i.e.
            client-supplied, so it is only acted on if it is one of this
            user's records.
        username: username_sb value used to locate the results JSON file.

    Returns:
        Three gr.update objects for (archive_df, selected_image, delete_btn).
        The preview and delete button are hidden afterwards.
    """
    # Nothing selected. The handler declares three outputs, so three values
    # must be returned. Also guards against os.remove(BASE_DIR) when the
    # path is "".
    if not img_path:
        return gr.update(value=table_data), gr.update(value=None, visible=False), gr.update(visible=False)

    json_path = os.path.join(JSON_DATA_DIR, f'{username}.json')
    results = []
    if os.path.exists(json_path):
        with open(json_path, "r") as f:
            results = json.load(f)

    remaining = [r for r in results if r["image_path"] != img_path]
    # Only touch the filesystem when img_path belongs to one of this user's
    # records; otherwise a crafted path could delete arbitrary files.
    if len(remaining) != len(results):
        file_path = os.path.join(BASE_DIR, img_path)
        if os.path.isfile(file_path):
            os.remove(file_path)
        with open(json_path, "w") as f:
            json.dump(remaining, f, indent=2)

    rows = [
        [r["timestamp"], r["predicted_id"], r["predicted_pokemon"], r["image_path"]]
        for r in remaining
    ]
    return gr.update(value=rows, visible=True), gr.update(value=None, visible=False), gr.update(visible=False)
def signin(user, psw):
    """Authenticate *user*/*psw* against the accounts in JSON_USERS_PATH.

    Usernames match case-insensitively; passwords must match exactly.
    A missing accounts file counts as a failed login.

    Returns:
        An 8-tuple of gr.update objects in the order wired to the sign-in
        button: (login_msg, username_sb, login_pg, archive_sb_btn,
        sign_sb_btn, delete_sb_btn, save_btn, classifier_pg).  On success
        username_sb becomes the bold upper-cased name and the logged-in
        controls are shown.  On failure an error message is shown and the
        login page stays visible.
    """
    matched = False
    if os.path.exists(JSON_USERS_PATH):
        with open(JSON_USERS_PATH, 'r') as f:
            accounts = json.load(f)
        # NOTE(review): passwords are stored and compared in plain text.
        matched = any(
            acct['username'].lower() == user.lower() and acct['password'] == psw
            for acct in accounts
        )

    if not matched:
        return (
            gr.update(value='Password e/o Nome utente errato', visible=True),  # login_msg
            gr.update(value='user', visible=False),  # username_sb
            gr.update(visible=True),  # login_pg
            gr.update(visible=False),  # archive_sb_btn
            gr.update(visible=True),  # sign_sb_btn
            gr.update(visible=False),  # delete_sb_btn
            gr.update(visible=False),  # save_btn
            gr.update(visible=False),  # classifier_pg
        )

    return (
        gr.update(value='Login effettuato', visible=True),  # login_msg
        gr.update(value=f"**{user.upper()}**", visible=True),  # username_sb
        gr.update(visible=False),  # login_pg
        gr.update(visible=True),  # archive_sb_btn
        gr.update(visible=False),  # sign_sb_btn
        gr.update(visible=True),  # delete_sb_btn
        gr.update(visible=True),  # save_btn
        gr.update(visible=True),  # classifier_pg
    )
def signup(user, psw):
    """Register a new account in JSON_USERS_PATH and log it in immediately.

    Rejects blank usernames, empty passwords, usernames that could escape the
    storage directories, and usernames that already exist (case-insensitive).

    Returns:
        An 8-tuple of gr.update objects in the order wired to the sign-up
        button: (login_msg, username_sb, login_pg, archive_sb_btn,
        sign_sb_btn, delete_sb_btn, save_btn, classifier_pg).
    """
    def _rejected(message):
        # Keep the login page up and every logged-in control hidden.
        return (
            gr.update(value=message, visible=True),  # login_msg
            gr.update(value='user', visible=False),  # username_sb
            gr.update(visible=True),  # login_pg
            gr.update(visible=False),  # archive_sb_btn
            gr.update(visible=True),  # sign_sb_btn
            gr.update(visible=False),  # delete_sb_btn
            gr.update(visible=False),  # save_btn
            gr.update(visible=False),  # classifier_pg
        )

    # The username becomes part of filesystem paths (image folder and results
    # JSON name), so reject blank names and anything containing path
    # separators or parent-directory references.
    unsafe_name = '/' in user or '\\' in user or '..' in user
    if not user.strip() or psw == '' or unsafe_name:
        return _rejected('Nome utente e/o password non validi')

    try:
        with open(JSON_USERS_PATH, "r") as f:
            results = json.load(f)
    except FileNotFoundError:
        results = []

    if any(u['username'].lower() == user.lower() for u in results):
        return _rejected('Utente già esistente')

    # NOTE(review): the password is stored in plain text. Hashing it needs a
    # coordinated change in signin and a migration of existing users.json.
    results.append({
        'username': user,
        'password': psw,
    })
    with open(JSON_USERS_PATH, "w") as f:
        json.dump(results, f, indent=2)

    return (
        gr.update(value='Login effettuato', visible=True),  # login_msg
        gr.update(value=f"**{user.upper()}**", visible=True),  # username_sb
        gr.update(visible=False),  # login_pg
        gr.update(visible=True),  # archive_sb_btn
        gr.update(visible=False),  # sign_sb_btn
        gr.update(visible=True),  # delete_sb_btn
        gr.update(visible=True),  # save_btn
        gr.update(visible=True),  # classifier_pg
    )
| from PIL import Image | |
| import gradio as gr | |
| from datetime import datetime | |
| import json | |
# Custom stylesheet passed to gr.Blocks(css=...).  Selectors target the
# elem_id / elem_classes used in the UI: #container (the root Row),
# .sidebar (left navigation Column), .content (page Column) and
# .fixed_image (both gr.Image components).
# NOTE(review): the inline "altezza massima" (max height) remark is
# inaccurate. `height: 224px` fixes the height rather than capping it.
css = """
#container {min-height: 100vh;}
.sidebar {
background-color: #f0f0f0;
padding: 10px;
height: 100vh;
overflow-y: auto;
}
.content {padding: 20px; padding-bottom: 60px; overflow-y:auto}
.fixed_image img{
max-width: 224px;
height: 224px; /* altezza massima */
object-fit: contain; /* mantiene proporzioni */
}
"""
# UI
# A left sidebar with navigation buttons, plus a content column holding three
# pages (classifier, archive, login). The sidebar handlers at the end of this
# block toggle which page is visible.
with gr.Blocks(css=css) as demo:
    with gr.Row(elem_id='container'):
        # Sidebar
        with gr.Column(scale=1, min_width=200, elem_classes='sidebar'):
            # Current username; the value 'user' means nobody is signed in.
            username_sb = gr.Markdown('user', visible=False)
            cls_sb_bt = gr.Button('Classifier', visible=True)
            archive_sb_btn = gr.Button('Archive', visible=False)
            sign_sb_btn = gr.Button('Sign-in/Sign-up', visible=True)
            delete_sb_btn = gr.Button('Delete Account', visible=False)
        # Content
        with gr.Column(scale=4):
            with gr.Column(scale=4, min_width=400, elem_classes="content"):
                # Classifier page
                with gr.Group(visible=True) as classifier:
                    gr.Markdown('### Classifier')
                    with gr.Row():
                        uploader = gr.Image(label='Upload image', sources=['upload', 'webcam'], type='pil',
                                            elem_classes='fixed_image')
                    with gr.Row():
                        output_id = gr.Number(value=0, label="ID", interactive=False)
                        output_name = gr.Textbox(value="", label="Pokémon", interactive=False)
                    with gr.Row():
                        # Buttons; "Save Results" is revealed only after sign-in.
                        cls_btn = gr.Button("Classify", visible=True)
                        save_btn = gr.Button('Save Results', visible=False)
                    save_status = gr.Textbox(label='Status', interactive=False, visible=False)
                    # Events
                    cls_btn.click(
                        fn=predicted_pokemon,
                        inputs=uploader,
                        outputs=[output_id, output_name]
                    )
                    save_btn.click(
                        fn=save_result,
                        inputs=[uploader, output_id, output_name, username_sb],
                        outputs=save_status
                    )
                # Archive page
                with gr.Group(visible=False) as archive:
                    gr.Markdown('### Archive')
                    with gr.Row():
                        archive_msg = gr.Markdown()
                        archive_df = gr.Dataframe(
                            headers=['Timestamp', 'ID', 'Pokémon', 'Image Path'],
                            datatype=['str', 'number', 'str', 'str'],
                            interactive=False,
                            visible=False
                        )
                        selected_image = gr.Image(label='Immagine selezionata.', visible=False, type='filepath',
                                                  interactive=False, elem_classes='fixed_image')
                    # Hidden holder for the selected row's image path (relative to BASE_DIR).
                    image_path = gr.Textbox(visible=False)
                    with gr.Row():
                        # Buttons
                        load_btn = gr.Button('Refresh Archive')
                        delete_btn = gr.Button('Delete Record', visible=False)
                    # Events
                    load_btn.click(fn=show_archive, inputs=[username_sb], outputs=[archive_msg, archive_df])
                    # Clicking a row previews its image and reveals the delete button.
                    archive_df.select(
                        fn=show_image,
                        inputs=[],
                        outputs=[selected_image, image_path, delete_btn]
                    )
                    delete_btn.click(
                        fn=delete_record,
                        inputs=[archive_df, image_path, username_sb],
                        outputs=[archive_df, selected_image, delete_btn]
                    )
                # Login page
                with gr.Group(visible=False) as login:
                    gr.Markdown('### Login')
                    with gr.Row():
                        username_tbox = gr.Textbox(label='Username')
                        password_tbox = gr.Textbox(label='Password', type='password')
                    with gr.Row():
                        login_msg = gr.Textbox(label='Status', interactive=False, visible=False)
                    with gr.Row():
                        # Buttons
                        signin_btn = gr.Button('Sign-in')
                        signup_btn = gr.Button('Sign-up')
                    # Event wiring: both handlers return 8 updates in this output order.
                    signin_btn.click(
                        fn=signin,
                        inputs=[username_tbox, password_tbox],
                        outputs=[login_msg, username_sb, login, archive_sb_btn, sign_sb_btn, delete_sb_btn, save_btn,
                                 classifier]
                    )
                    signup_btn.click(
                        fn=signup,
                        inputs=[username_tbox, password_tbox],
                        outputs=[login_msg, username_sb, login, archive_sb_btn, sign_sb_btn, delete_sb_btn, save_btn,
                                 classifier]
                    )
    # Sidebar navigation. The 'archive' page returns two extra updates
    # (message + table), hence its longer outputs list.
    cls_sb_bt.click(fn=lambda username: show_page('classifier', username), inputs=[username_sb],
                    outputs=[classifier, archive, login])
    archive_sb_btn.click(fn=lambda username: show_page('archive', username), inputs=[username_sb],
                         outputs=[classifier, archive, login, archive_msg, archive_df])
    sign_sb_btn.click(fn=lambda username: show_page('login', username), inputs=[username_sb],
                      outputs=[classifier, archive, login])
    # TODO: "Delete Account" has no handler (no fn/inputs/outputs), so
    # clicking it does nothing yet.
    delete_sb_btn.click()
# Start the Gradio server only when executed as a script, not on import.
if __name__ == '__main__':
    demo.launch()