Spaces:
Sleeping
Sleeping
| import os | |
| import random | |
| import gradio as gr | |
| from dotenv import load_dotenv | |
| from sqlalchemy import ( | |
| TIMESTAMP, | |
| Boolean, | |
| Column, | |
| ForeignKey, | |
| Integer, | |
| String, | |
| Text, | |
| create_engine, | |
| or_, | |
| ) | |
| from sqlalchemy.ext.declarative import declarative_base | |
| from sqlalchemy.orm import Mapped, relationship, sessionmaker | |
| from sqlalchemy.sql import func | |
| from datasets import load_dataset | |
# Load variables from a local .env file first, so that everything below
# (database URL, dataset loading) sees them in the process environment.
load_dotenv()

DATABASE_URL = os.getenv("DATABASE_URL")
if not DATABASE_URL:
    # Without this check the missing value would be handed to create_engine()
    # as None and rejected there with a far less actionable error.
    raise RuntimeError("DATABASE_URL environment variable is not set")

engine = create_engine(DATABASE_URL)
# Session factory used by every database access in this module.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()

# Dataset that supplies the images displayed in the blind test; loaded after
# the configuration check so a misconfigured deployment fails quickly.
ds = load_dataset("bilguun/flickr30k-mn")
class ImagesCaptions(Base):
    """ORM mapping for the ``images_captions`` table.

    Each row ties an image to one caption and to the two candidate
    translations that the blind test asks users to compare.
    """

    __tablename__ = "images_captions"

    id = Column(Integer, index=True, primary_key=True)
    # Used by random_task() as the row index into ds["train"].
    image_id = Column(Integer)
    image_name = Column(String)
    # NOTE(review): presumably the ordinal of this caption among the
    # captions of the same image -- confirm against the data loader.
    caption_num = Column(Integer)
    # Source caption; currently not shown in the UI.
    caption = Column(Text)
    # The two candidate captions offered as the choices in the blind test.
    caption_mn_v1 = Column(Text)
    caption_mn_v2 = Column(Text)
class Task(Base):
    """ORM mapping for the ``task`` table: one comparison to be judged.

    A task points at an ``images_captions`` row and records whether the two
    candidate captions are presented in swapped order.
    """

    __tablename__ = "task"

    id = Column(Integer, index=True, primary_key=True)
    image_caption_id = Column(Integer, ForeignKey("images_captions.id"))
    caption_num = Column(Integer)
    # When true, caption_mn_v2 is shown as the first choice and caption_mn_v1
    # as the second; submitted choices are mapped back accordingly.
    reverse_caption = Column(Boolean)
    # Values used in this module: "pending", "in_progress", "done".
    status = Column(String)

    # The captioned image this task refers to.
    image_caption: Mapped[ImagesCaptions] = relationship("ImagesCaptions")
class TaskSubmission(Base):
    """ORM mapping for the ``task_submission`` table: one user's answer."""

    __tablename__ = "task_submission"

    id = Column(Integer, index=True, primary_key=True)
    task_id = Column(Integer, ForeignKey("task.id"))
    # Which caption was preferred (1 or 2), already corrected for
    # Task.reverse_caption by the submit handler.
    choice = Column(Text)
    # Name the user typed into the username textbox.
    created_by = Column(String)
    # Filled in by the database server at insert time.
    created_at = Column(TIMESTAMP, server_default=func.now())

    # The task this submission answers.
    task: Mapped[Task] = relationship("Task")
def get_random_task(limit: int = 500) -> Task | None:
    """Return a randomly chosen open task, or ``None`` if there is none.

    Open tasks are those whose status is ``"pending"`` or ``"in_progress"``.
    To spread concurrent users over different tasks, the candidates are first
    restricted to tasks whose ``image_caption_id`` is divisible by a random
    integer in ``[1, 5]``; the first ``limit`` of them (ascending
    ``image_caption_id``) are fetched and one is picked uniformly at random.
    If that restricted query finds nothing, the query is repeated without the
    divisibility filter so that remaining open tasks are not missed.

    Args:
        limit: Maximum number of candidate rows fetched per query.

    Returns:
        A ``Task`` instance (detached, because its session is closed before
        returning), or ``None`` when no open task exists.
    """
    modulus = random.randint(1, 5)
    with SessionLocal() as db:
        open_tasks = (
            db.query(Task)
            .filter(Task.status.in_(("pending", "in_progress")))
            .order_by(Task.image_caption_id.asc())
        )
        candidates = (
            open_tasks.filter(Task.image_caption_id % modulus == 0)
            .limit(limit)
            .all()
        )
        if not candidates and modulus != 1:
            # The random residue class was empty; fall back to every open task.
            candidates = open_tasks.limit(limit).all()
    return random.choice(candidates) if candidates else None
def random_task():
    """Pick a random open task and return what the UI needs to display it.

    Returns:
        A 4-tuple ``(image, caption1, caption2, task_id)`` where ``image`` is
        taken from ``ds["train"]`` at the caption row's ``image_id``, the two
        captions are ``caption_mn_v1`` / ``caption_mn_v2`` (swapped when the
        task has ``reverse_caption`` set) and ``task_id`` is the task's id.
        ``(None, None, None, None)`` is returned when no task is available,
        the task vanished in the meantime, or it has no caption row.
    """
    nothing = (None, None, None, None)

    picked = get_random_task()
    if picked is None:
        return nothing

    with SessionLocal() as db:
        # Reload inside a live session so the image_caption relationship can
        # be loaded (the instance from get_random_task() is detached).
        task = db.query(Task).filter(Task.id == picked.id).first()
        if task is None or task.image_caption is None:
            return nothing

        record = task.image_caption
        caption1 = str(record.caption_mn_v1)
        caption2 = str(record.caption_mn_v2)
        if task.reverse_caption:
            caption1, caption2 = caption2, caption1

        return (
            ds["train"][record.image_id]["image"],
            caption1,
            caption2,
            int(task.id),
        )
# Styling for the two caption buttons (light and dark theme variants).
css = """
.caption-btn {
background: #fcdccc;
border: 2px solid #f09162;
}
.dark .caption-btn {
background: #26201f;
border: 2px solid #40271a;
}
"""

# Number of submissions after which a task is marked "done".
REQUIRED_SUBMISSIONS = 3

with gr.Blocks(css=css) as blind_test:
    username = gr.Textbox(
        label="Нэрээ оруулна уу", placeholder="Нэр", max_lines=1, max_length=40
    )
    # Browser-side storage holding a one-element list with the username.
    local_storage = gr.BrowserState([""])

    def load_from_local_storage(saved_values):
        """Return the username stored in the browser state ("" if absent)."""
        print("loading from local storage", saved_values)
        return saved_values[0] if saved_values else ""

    def save_to_local_storage(username):
        """Wrap the username in the list shape kept in the browser state."""
        return [username]

    # Id of the task currently on screen; None until a task has been loaded.
    task_id = gr.State(None)

    img_preview = gr.Image(
        None, label="Зураг", show_label=True, show_download_button=False, height=400
    )
    md_desc = gr.Markdown(
        "### Доорх хоёр тайлбараас зурагтай хамгийн сайн тохирч буйг сонгоно уу."
    )
    with gr.Row(equal_height=True, variant="panel"):
        with gr.Column(scale=1):
            caption_choice1_button = gr.Button(
                None, variant="secondary", elem_classes="caption-btn"
            )
        with gr.Column(scale=1):
            caption_choice2_button = gr.Button(
                None, variant="secondary", elem_classes="caption-btn"
            )

    # Components refreshed whenever a new task is shown; the order matches the
    # tuple returned by random_task() and on_submit().
    task_outputs = [
        img_preview,
        caption_choice1_button,
        caption_choice2_button,
        task_id,
    ]

    def _keep_current_task():
        """Return updates that leave all four task outputs as they are."""
        return tuple(gr.update(visible=True) for _ in task_outputs)

    def on_submit(username: str, choice: int, task_id: int | None):
        """Record the user's choice for a task and return the next task.

        Args:
            username: Name typed by the user; must not be blank.
            choice: 1 or 2, the button position that was clicked.
            task_id: Id of the task currently displayed.

        Returns:
            A 4-tuple for ``task_outputs``: either "no change" updates (when
            the input is invalid) or ``(image, caption1, caption2, task_id)``
            of the next task as produced by ``random_task()``.
        """
        print("on_submit", username, choice, task_id)

        if not username or not username.strip():
            gr.Warning("Нэрээ оруулна уу!")
            return _keep_current_task()
        if choice not in (1, 2):
            return _keep_current_task()

        with SessionLocal() as db:
            task = db.query(Task).filter(Task.id == task_id).first()
            if task is None:
                # Stale or missing task id: move on to a fresh task instead
                # of blanking the UI.
                return random_task()

            if task.reverse_caption:
                # The captions were shown swapped; map the clicked position
                # back to the v1/v2 numbering.
                choice = 2 if choice == 1 else 1

            # The choice column is Text, so store the value as a string.
            db.add(
                TaskSubmission(
                    task_id=task.id,
                    choice=str(choice),
                    created_by=username.strip(),
                )
            )
            # Commit first so the count below includes this submission
            # (the session is configured with autoflush=False).
            db.commit()

            submission_count = (
                db.query(TaskSubmission)
                .filter(TaskSubmission.task_id == task.id)
                .count()
            )
            task.status = (
                "done" if submission_count >= REQUIRED_SUBMISSIONS else "in_progress"
            )
            db.commit()

        return random_task()

    def submit_choice1(username, task_id):
        """Click handler for the first caption button."""
        return on_submit(username, 1, task_id)

    def submit_choice2(username, task_id):
        """Click handler for the second caption button."""
        return on_submit(username, 2, task_id)

    # Restore the saved username on page load and persist every edit.
    blind_test.load(
        fn=load_from_local_storage, inputs=[local_storage], outputs=[username]
    )
    username.change(
        fn=save_to_local_storage, inputs=[username], outputs=[local_storage]
    )

    # Attach the submit handlers to the two caption buttons.
    caption_choice1_button.click(
        fn=submit_choice1, inputs=[username, task_id], outputs=task_outputs
    )
    caption_choice2_button.click(
        fn=submit_choice2, inputs=[username, task_id], outputs=task_outputs
    )

    # Show an initial task as soon as the page loads.
    blind_test.load(fn=random_task, outputs=task_outputs)

blind_test.launch()