# DemoPokemon / app.py
# Uploaded by Naahbi ("Upload 2 files", commit d0ea483, verified)
from transformers import ViTForImageClassification, ViTFeatureExtractor
from PIL import Image
import torch
# Load the Pokémon classifier and its preprocessor once at import time;
# predicted_pokemon() reuses these module-level objects on every request.
from transformers import ViTImageProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"
model = ViTForImageClassification.from_pretrained("imjeffhi/pokemon_classifier").to(device)
# ViTImageProcessor replaces the deprecated ViTFeatureExtractor (same preprocessing).
# The variable keeps its old name because predicted_pokemon() refers to it.
feature_extractor = ViTImageProcessor.from_pretrained('imjeffhi/pokemon_classifier')
def predicted_pokemon(img):
    """Classify a Pokémon image with the module-level ViT model.

    Args:
        img: The image to classify. The Gradio uploader passes a PIL image,
            or None when nothing has been uploaded.

    Returns:
        ``(predicted_id, predicted_name)``: the argmax class index and its label
        from ``model.config.id2label``. Returns ``(None, "")`` when *img* is None.
    """
    if img is None:
        # The Classify button can be clicked with an empty uploader; the
        # preprocessor would fail on None, so return empty outputs instead.
        return None, ""
    inputs = feature_extractor(images=img, return_tensors='pt').to(device)
    # Inference only: don't build an autograd graph (saves memory and time).
    with torch.no_grad():
        logits = model(**inputs).logits
    label_id = logits.argmax(-1).item()
    return label_id, model.config.id2label[label_id]
"""## Platform code
"""
import os
import gradio as gr
import json
# On-disk storage layout; everything lives under BASE_DIR:
#   images/<username>/...     images saved by save_result()
#   results/<username>.json   per-user lists of classification records
#   users.json                registered accounts written by signup()
BASE_DIR = 'my_project'
IMG_DIR = os.path.join(BASE_DIR, 'images')
JSON_DATA_DIR = os.path.join(BASE_DIR, 'results')
JSON_USERS_PATH = os.path.join(BASE_DIR, 'users.json')
for _storage_dir in (BASE_DIR, IMG_DIR, JSON_DATA_DIR):
    os.makedirs(_storage_dir, exist_ok=True)
del _storage_dir
"""### main codes"""
def show_page(page_name, username):
    """Switch the content area to *page_name*.

    Args:
        page_name: 'classifier', 'archive' or 'login'.
        username: Sidebar username value, only used to load the archive.

    Returns:
        For 'classifier' and 'login': three visibility updates for the
        (classifier, archive, login) groups, showing only the requested page.
        For 'archive': the same three updates followed by the
        (archive_msg, archive_df) updates produced by show_archive(username).
        Any other page name yields None.
    """
    if page_name == 'archive':
        msg_update, table_update = show_archive(username)
        return (gr.update(visible=False), gr.update(visible=True), gr.update(visible=False),
                msg_update, table_update)
    if page_name not in ('classifier', 'login'):
        return None
    on_classifier = page_name == 'classifier'
    return gr.update(visible=on_classifier), gr.update(visible=False), gr.update(visible=not on_classifier)
def save_result(img: Image.Image, predicted_id, predicted_pokemon, username):
    """Save the current image and its prediction to the signed-in user's archive.

    The image is written as a JPEG to ``IMG_DIR/<username>/<timestamp>.jpg``. A
    record ``{timestamp, image_path, predicted_id, predicted_pokemon}`` is then
    appended to ``JSON_DATA_DIR/<username>.json``, with ``image_path`` relative
    to ``BASE_DIR``.

    Args:
        img: Image currently in the uploader, or None if nothing was uploaded.
        predicted_id: Class id from the ID field (stored as ``int``).
        predicted_pokemon: Class name from the Pokémon field.
        username: Value of the sidebar username component. ``'user'`` or an
            empty value means nobody is signed in.

    Returns:
        A ``gr.update`` for the status textbox, always with ``visible=True``.
        The textbox starts hidden, so returning a bare string would update an
        invisible component and the user would never see the message.
    """

    def _status(message):
        return gr.update(value=message, visible=True)

    # Sanity checks
    if img is None:
        return _status("Error: no image uploaded!")
    if predicted_id is None or predicted_id == "":
        return _status("Error: classify first!")
    if predicted_pokemon is None or predicted_pokemon == "":
        return _status("Error: classify first!")
    if not username or username == 'user':
        return _status("Error: sign-in first")

    now = datetime.now()
    # Microseconds stop two saves within the same second from sharing (and
    # overwriting) a single image file.
    timestamp_file = now.strftime('%Y-%m-%d-%H-%M-%S-%f')
    timestamp_json = now.isoformat()

    # Save the image in a per-user directory.
    img_path = os.path.join(IMG_DIR, username, f'{timestamp_file}.jpg')
    os.makedirs(os.path.dirname(img_path), exist_ok=True)
    # JPEG cannot store alpha/palette modes (e.g. RGBA, P); convert those first.
    if img.mode not in ('RGB', 'L'):
        img = img.convert('RGB')
    img.save(img_path)

    # Read the existing records (if any) and append the new one.
    json_path = os.path.join(JSON_DATA_DIR, f'{username}.json')
    try:
        with open(json_path, "r") as f:
            results = json.load(f)
    except FileNotFoundError:
        results = []
    results.append({
        'timestamp': timestamp_json,
        'image_path': os.path.relpath(img_path, BASE_DIR),
        'predicted_id': int(predicted_id),
        'predicted_pokemon': predicted_pokemon,
    })
    with open(json_path, "w") as f:
        json.dump(results, f, indent=2)
    return _status("Results saved successfully!")
# Load a user's saved archive
def show_archive(username):
    """Load the saved classification records of *username*.

    Reads ``JSON_DATA_DIR/<username>.json`` and returns two Gradio updates for
    (archive_msg, archive_df).

    If the file is missing or holds no records, the message
    "Nessun record salvato." ("No saved records.") is shown and the table is
    hidden. Otherwise the message is cleared and hidden, and the table is shown
    with one row per record: [timestamp, predicted_id, predicted_pokemon,
    image_path].
    """
    records_file = os.path.join(JSON_DATA_DIR, f'{username}.json')
    records = []
    if os.path.exists(records_file):
        with open(records_file, "r") as fh:
            records = json.load(fh)
    if not records:
        return gr.update(value="Nessun record salvato.", visible=True), gr.update(visible=False)
    table = [
        [rec["timestamp"], rec["predicted_id"], rec["predicted_pokemon"], rec["image_path"]]
        for rec in records
    ]
    return gr.update(value="", visible=False), gr.update(value=table, visible=True)
def show_image(evt: gr.SelectData):
    """Preview the image of the archive row the user clicked.

    The fourth column of the selected row ('Image Path') holds the image path
    relative to BASE_DIR. Returns updates for (selected_image, image_path,
    delete_btn):
    - the preview, shown with the image;
    - the relative path, stored for a later delete;
    - the delete button, revealed.
    """
    relative_path = evt.row_value[3]
    preview = gr.update(value=os.path.join(BASE_DIR, relative_path), visible=True)
    path_holder = gr.update(value=relative_path)
    delete_button = gr.update(visible=True)
    return preview, path_holder, delete_button
def delete_record(table_data, img_path, username):
    """Delete one saved record (and its image) from *username*'s archive.

    Args:
        table_data: Current contents of the archive table. They are returned
            unchanged when nothing is selected.
        img_path: Image path of the selected record, relative to BASE_DIR.
            It comes from the client, so it is only acted on when it matches a
            record of this user.
        username: Sidebar username value, selecting ``JSON_DATA_DIR/<username>.json``.

    Returns:
        Three updates for (archive_df, selected_image, delete_btn). The table
        shows the remaining records; the preview and the delete button are
        hidden.
    """
    hide_preview = gr.update(value=None, visible=False)
    hide_button = gr.update(visible=False)

    # Nothing selected (None or ""): keep the table, still return all 3 outputs.
    if not img_path:
        return gr.update(value=table_data), hide_preview, hide_button

    json_path = os.path.join(JSON_DATA_DIR, f'{username}.json')
    try:
        with open(json_path, "r") as f:
            results = json.load(f)
    except FileNotFoundError:
        results = []

    remaining = [r for r in results if r["image_path"] != img_path]
    # Only touch the filesystem when img_path belongs to one of this user's
    # records; otherwise an arbitrary client-supplied path could be deleted.
    if len(remaining) != len(results):
        full_path = os.path.join(BASE_DIR, img_path)
        if os.path.isfile(full_path):
            os.remove(full_path)
        with open(json_path, "w") as f:
            json.dump(remaining, f, indent=2)

    rows = [
        [r["timestamp"], r["predicted_id"], r["predicted_pokemon"], r["image_path"]]
        for r in remaining
    ]
    return gr.update(value=rows, visible=True), hide_preview, hide_button
def signin(user, psw):
    """Authenticate *user*/*psw* against the accounts stored in JSON_USERS_PATH.

    Usernames match case-insensitively; passwords must match exactly. A
    missing users file means no account can match.

    Returns:
        An 8-tuple of Gradio updates, in this order: login_msg, username_sb,
        login_pg, archive_sb_btn, sign_sb_btn, delete_sb_btn, save_btn,
        classifier_pg.
    """
    import hmac

    def _outputs(message, logged_in):
        # Every return has the same shape; `logged_in` swaps the login page
        # for the classifier page and the account-only controls.
        if logged_in:
            name_update = gr.update(value=f"**{user.upper()}**", visible=True)
        else:
            name_update = gr.update(value='user', visible=False)
        return (
            gr.update(value=message, visible=True),  # login_msg
            name_update,  # username_sb
            gr.update(visible=not logged_in),  # login_pg
            gr.update(visible=logged_in),  # archive_sb_btn
            gr.update(visible=not logged_in),  # sign_sb_btn
            gr.update(visible=logged_in),  # delete_sb_btn
            gr.update(visible=logged_in),  # save_btn
            gr.update(visible=logged_in),  # classifier_pg
        )

    user = user or ''
    psw = psw or ''
    try:
        with open(JSON_USERS_PATH, 'r') as f:
            users = json.load(f)
    except FileNotFoundError:
        users = []

    # NOTE(review): passwords are stored and compared in plaintext; hashing
    # would require a matching change in signup().
    for u in users:
        name_ok = u['username'].lower() == user.lower()
        # Constant-time comparison, so response timing does not reveal how much
        # of the password matched.
        psw_ok = hmac.compare_digest(u['password'].encode('utf-8'), psw.encode('utf-8'))
        if name_ok and psw_ok:
            return _outputs('Login effettuato', True)
    return _outputs('Password e/o Nome utente errato', False)
def signup(user, psw):
    """Register a new account in JSON_USERS_PATH and sign the user in.

    Accounts are stored as a JSON list of ``{'username', 'password'}`` objects.
    Usernames are compared case-insensitively against existing ones. The
    following are rejected:
    - empty or whitespace-only usernames;
    - empty passwords;
    - usernames containing '/' or '\\'. The signed-in name is later used to
      build per-user file paths, where a separator would point into missing
      subdirectories.

    Returns:
        An 8-tuple of Gradio updates, in this order: login_msg, username_sb,
        login_pg, archive_sb_btn, sign_sb_btn, delete_sb_btn, save_btn,
        classifier_pg.
    """

    def _outputs(message, signed_in):
        # Every return has the same shape; `signed_in` swaps the login page
        # for the classifier page and the account-only controls.
        if signed_in:
            name_update = gr.update(value=f"**{user.upper()}**", visible=True)
        else:
            name_update = gr.update(value='user', visible=False)
        return (
            gr.update(value=message, visible=True),  # login_msg
            name_update,  # username_sb
            gr.update(visible=not signed_in),  # login_pg
            gr.update(visible=signed_in),  # archive_sb_btn
            gr.update(visible=not signed_in),  # sign_sb_btn
            gr.update(visible=signed_in),  # delete_sb_btn
            gr.update(visible=signed_in),  # save_btn
            gr.update(visible=signed_in),  # classifier_pg
        )

    user = user or ''
    psw = psw or ''
    # Validate the input before touching the users file.
    if not user.strip() or psw == '' or '/' in user or '\\' in user:
        return _outputs('Nome utente e/o password non validi', False)

    try:
        with open(JSON_USERS_PATH, "r") as f:
            results = json.load(f)
    except FileNotFoundError:
        results = []

    if any(u['username'].lower() == user.lower() for u in results):
        return _outputs('Utente già esistente', False)

    # NOTE(review): the password is stored in plaintext; hashing it would
    # require a matching change in signin().
    results.append({
        'username': user,
        'password': psw,
    })
    with open(JSON_USERS_PATH, "w") as f:
        json.dump(results, f, indent=2)
    return _outputs('Login effettuato', True)
from PIL import Image
import gradio as gr
from datetime import datetime
import json
css = """
#container {min-height: 100vh;}
.sidebar {
background-color: #f0f0f0;
padding: 10px;
height: 100vh;
overflow-y: auto;
}
.content {padding: 20px; padding-bottom: 60px; overflow-y:auto}
.fixed_image img{
max-width: 224px;
height: 224px; /* altezza massima */
object-fit: contain; /* mantiene proporzioni */
}
"""
# UI
# Builds the whole interface: a sidebar (navigation and account buttons) next
# to a content column with three page groups (classifier, archive, login).
# show_page(), signin() and signup() toggle which groups are visible.
# NOTE(review): this block's original indentation was lost. The nesting below
# is a reconstruction, notably which archive components share the first
# gr.Row and where save_status sits. Confirm against the intended layout.
with gr.Blocks(css=css) as demo:
    with gr.Row(elem_id='container'):
        # Sidebar
        with gr.Column(scale=1, min_width=200, elem_classes='sidebar'):
            # Signed-in user's name, hidden until sign-in. Its value is also
            # passed as the `username` input of the save/archive/delete
            # handlers; 'user' is the not-signed-in placeholder.
            username_sb = gr.Markdown('user', visible=False)
            cls_sb_bt = gr.Button('Classifier', visible=True)
            # Account-only buttons start hidden; signin()/signup() reveal them.
            archive_sb_btn = gr.Button('Archive', visible=False)
            sign_sb_btn = gr.Button('Sign-in/Sign-up', visible=True)
            delete_sb_btn = gr.Button('Delete Account', visible=False)
        # Content
        with gr.Column(scale=4):
            with gr.Column(scale=4, min_width=400, elem_classes="content"):
                # Classifier page (the only page visible at startup)
                with gr.Group(visible=True) as classifier:
                    gr.Markdown('### Classifier')
                    with gr.Row():
                        uploader = gr.Image(label='Upload image', sources=['upload', 'webcam'], type='pil',
                                            elem_classes='fixed_image')
                    with gr.Row():
                        output_id = gr.Number(value=0, label="ID", interactive=False)
                        output_name = gr.Textbox(value="", label="Pokémon", interactive=False)
                    with gr.Row():
                        # Buttons
                        cls_btn = gr.Button("Classify", visible=True)
                        # Hidden until the user signs in.
                        save_btn = gr.Button('Save Results', visible=False)
                    save_status = gr.Textbox(label='Status', interactive=False, visible=False)
                    # Events
                    cls_btn.click(
                        fn=predicted_pokemon,
                        inputs=uploader,
                        outputs=[output_id, output_name]
                    )
                    save_btn.click(
                        fn=save_result,
                        inputs=[uploader, output_id, output_name, username_sb],
                        outputs=save_status
                    )
                # Archive page
                with gr.Group(visible=False) as archive:
                    gr.Markdown('### Archive')
                    with gr.Row():
                        # show_archive() shows either archive_msg (no records)
                        # or archive_df (the records table).
                        archive_msg = gr.Markdown()
                        archive_df = gr.Dataframe(
                            headers=['Timestamp', 'ID', 'Pokémon', 'Image Path'],
                            datatype=['str', 'number', 'str', 'str'],
                            interactive=False,
                            visible=False
                        )
                        # Preview of the clicked row's image. The hidden
                        # image_path textbox keeps that row's relative path
                        # for delete_record().
                        selected_image = gr.Image(label='Immagine selezionata.', visible=False, type='filepath',
                                                  interactive=False, elem_classes='fixed_image')
                        image_path = gr.Textbox(visible=False)
                    with gr.Row():
                        # Buttons
                        load_btn = gr.Button('Refresh Archive')
                        delete_btn = gr.Button('Delete Record', visible=False)
                    # Events
                    load_btn.click(fn=show_archive, inputs=[username_sb], outputs=[archive_msg, archive_df])
                    archive_df.select(
                        fn=show_image,
                        inputs=[],
                        outputs=[selected_image, image_path, delete_btn]
                    )
                    delete_btn.click(
                        fn=delete_record,
                        inputs=[archive_df, image_path, username_sb],
                        outputs=[archive_df, selected_image, delete_btn]
                    )
                # Login page
                with gr.Group(visible=False) as login:
                    gr.Markdown('### Login')
                    with gr.Row():
                        username_tbox = gr.Textbox(label='Username')
                        password_tbox = gr.Textbox(label='Password', type='password')
                    with gr.Row():
                        login_msg = gr.Textbox(label='Status', interactive=False, visible=False)
                    with gr.Row():
                        # Buttons
                        signin_btn = gr.Button('Sign-in')
                        signup_btn = gr.Button('Sign-up')
                    # Event wiring: signin() and signup() both return updates
                    # for these same eight components, in the order listed.
                    signin_btn.click(
                        fn=signin,
                        inputs=[username_tbox, password_tbox],
                        outputs=[login_msg, username_sb, login, archive_sb_btn, sign_sb_btn, delete_sb_btn, save_btn,
                                 classifier]
                    )
                    signup_btn.click(
                        fn=signup,
                        inputs=[username_tbox, password_tbox],
                        outputs=[login_msg, username_sb, login, archive_sb_btn, sign_sb_btn, delete_sb_btn, save_btn,
                                 classifier]
                    )
    # Sidebar navigation: each button switches the visible page via show_page().
    cls_sb_bt.click(fn=lambda username: show_page('classifier', username), inputs=[username_sb],
                    outputs=[classifier, archive, login])
    archive_sb_btn.click(fn=lambda username: show_page('archive', username), inputs=[username_sb],
                         outputs=[classifier, archive, login, archive_msg, archive_df])
    sign_sb_btn.click(fn=lambda username: show_page('login', username), inputs=[username_sb],
                      outputs=[classifier, archive, login])
    # TODO: account deletion is not implemented. No handler function is passed
    # here, so the "Delete Account" button currently has no effect.
    delete_sb_btn.click()
def _main():
    """Start the Gradio server for the demo UI."""
    demo.launch()


if __name__ == '__main__':
    _main()