| import gradio as gr |
| from src.llm import query_chroma |
| from src.reranker_warm import rank_anime_warm |
| import pandas as pd |
| from pathlib import Path |
| import requests |
|
|
# Custom stylesheet passed to gr.Blocks(css=css) below. It hides the page
# footer, caps and centres .gradio-container at 1200px, styles .submit-btn with
# a gradient and a hover lift, renders .title as gradient-filled text, and
# forces .output-image to full width.
# NOTE(review): the "|" markers wrapping every line look like table-extraction
# residue rather than intended content -- confirm against the original file.
| css = """ |
| footer {display: none !important} |
| .gradio-container { |
| max-width: 1200px; |
| margin: auto; |
| } |
| .contain { |
| background: rgba(255, 255, 255, 0.05); |
| border-radius: 12px; |
| padding: 20px; |
| } |
| .submit-btn { |
| background: linear-gradient(90deg, #4B79A1 0%, #283E51 100%) !important; |
| border: none !important; |
| color: white !important; |
| } |
| .submit-btn:hover { |
| transform: translateY(-2px); |
| box-shadow: 0 5px 15px rgba(0,0,0,0.2); |
| } |
| .title { |
| text-align: center; |
| font-size: 2.5em; |
| font-weight: bold; |
| margin-bottom: 1em; |
| background: linear-gradient(90deg, #4B79A1 0%, #283E51 100%); |
| -webkit-background-clip: text; |
| -webkit-text-fill-color: transparent; |
| } |
| .output-image { |
| width: 100% !important; |
| max-width: 100% !important; |
| } |
| """ |
|
|
|
|
def download_pic(names: list[str], timeout: float = 10.0) -> tuple[list, list]:
    """Return cached image paths and synopses for the given anime names.

    The catalogue ``src/data/final_anime_list.csv`` (located next to this
    file) is read, narrowed to ``names`` and re-ordered to match them. Images
    are cached in ``src/data/pics``; a file is downloaded only when it is not
    already present there.

    Args:
        names: Anime names to look up; results follow this order.
        timeout: Seconds to wait for each image download before giving up.

    Returns:
        ``(file_paths, synopsis_list)``, both aligned one-to-one with
        ``names``. A name without a usable image URL yields ``None`` in
        ``file_paths``; a missing synopsis yields ``""``.

    Raises:
        requests.HTTPError: A download answered with an error status.
        requests.Timeout: A download exceeded ``timeout``.
    """
    base_dir = Path(__file__).parent
    df = pd.read_csv(str(Path(base_dir, "src/data/final_anime_list.csv")))
    pic_dir = Path(base_dir, "src/data/pics")
    # parents=True so a fresh checkout without src/data/ still works.
    pic_dir.mkdir(parents=True, exist_ok=True)

    # reindex() refuses a source index with duplicate labels, so keep the first
    # catalogue row per name before aligning the rows with ``names``.
    df = df[df.Name.isin(names)].drop_duplicates(subset="Name")
    df = df[["Name", "Image URL", "Synopsis"]].set_index("Name").reindex(names)

    # Names absent from the catalogue come back from reindex() as NaN rows.
    synopsis_list = ["" if pd.isna(text) else text for text in df["Synopsis"]]

    file_paths = []
    for url in df["Image URL"]:
        if not isinstance(url, str) or not url:
            # Keep the slot so the result stays aligned with ``names``.
            file_paths.append(None)
            continue

        # The last two URL segments form a cache file name.
        file_name = "_".join(url.split("/")[-2:])
        file_path = Path(pic_dir, file_name)

        if not file_path.exists():
            # Without a timeout a stalled server would block the UI callback forever.
            response = requests.get(url, timeout=timeout)
            response.raise_for_status()
            file_path.write_bytes(response.content)
            print(f"Image downloaded successfully: {url}")

        file_paths.append(str(file_path))

    return file_paths, synopsis_list
| |
|
|
|
|
def integration_warm(query: str, userid: int = 12, top_k: int = 4):
    """Run retrieval and re-ranking for ``query`` and build the UI outputs.

    ``query_chroma`` is asked for 100 candidates, ``rank_anime_warm`` re-ranks
    them for ``userid``, and the first ``top_k`` ranked entries are kept.
    Pictures and synopses for those entries are obtained from ``download_pic``.

    Args:
        query: Free-text description typed by the user.
        userid: User the ranking is computed for (defaults to the user id the
            page is built around).
        top_k: Number of result slots to fill; must match the number of
            result columns wired to this callback.

    Returns:
        A flat list of ``3 * top_k`` values: the ranked entries, then the
        image paths, then the synopses. Groups shorter than ``top_k`` are
        padded so the number of values always matches the wired outputs.
    """
    candidates = query_chroma(query=query, anime_count=100)
    ranked = rank_anime_warm(userid=userid, anime_list=candidates)[:top_k]

    # The first element of every ranked entry is used as the anime name.
    final_names = [entry[0] for entry in ranked]
    anime_pic_list, synopsis_list = download_pic(final_names)

    def _pad(values, filler):
        """Extend ``values`` with ``filler`` up to ``top_k`` items."""
        values = list(values)
        return values + [filler] * (top_k - len(values))

    # NOTE(review): the name textboxes receive the ranked entries as returned
    # by rank_anime_warm, not just entry[0] -- confirm this is intended.
    return [
        *_pad(ranked, ""),
        *_pad(anime_pic_list, None),
        *_pad(synopsis_list, ""),
    ]
|
|
|
|
|
|
|
|
def clear_prompt():
    """Return the empty string that resets the query textbox."""
    emptied_value = ""
    return emptied_value
|
|
|
|
def feedback_button(action, anime_name):
    """Build the confirmation message for a feedback action on an anime.

    A ``d`` is appended to ``action`` to form the past tense, so ``"like"``
    becomes ``"liked"``.
    """
    template = "You {}d {}!"
    return template.format(action, anime_name)
|
|
|
|
# Build the AniQuest page: a title and banner, the query box with Submit/Clear
# buttons, and four result columns (name, like/dislike buttons, image,
# description) that are filled by integration_warm.
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.HTML('<div class="title">AniQuest</div>')
    gr.HTML(
        '<div style="text-align: center; margin-bottom: 2em; color: #666; font-size: 24px;">We recommendate animes based on your description</div>')
    # NOTE(review): the emoji in this block were restored from mis-encoded
    # text. The warning sign and thumbs are unambiguous; the Submit and Clear
    # glyphs are best guesses -- confirm against the original design.
    gr.HTML("""
    <div style="color: red; margin-bottom: 1em; text-align: center; padding: 10px; background: rgba(255,0,0,0.1); border-radius: 8px;">
    ⚠️ Welcome, [user_id: 12] to this recommendation system ⚠️
    </div>
    """)

    with gr.Column():
        prompt = gr.Textbox(
            label="Query",
            placeholder="Describe the anime you want to watch next ...",
            lines=1
        )
        with gr.Row():
            generate_btn = gr.Button(
                "🚀 Submit",
                elem_classes=["submit-btn"]
            )
            clear_btn = gr.Button(
                "🗑️ Clear",
                elem_classes=["submit-btn"]
            )
        with gr.Row():
            # Components are collected in lists instead of being created
            # through exec() under generated variable names.
            anime_boxes = []
            anime_images = []
            anime_descriptions = []
            for slot in range(1, 5):
                with gr.Column(scale=1, elem_classes=["anime-block"]):
                    anime_boxes.append(gr.Textbox(label=f"Anime {slot}"))

                    with gr.Row():
                        # No handler is attached to these buttons in this block.
                        gr.Button("👍 Like")
                        gr.Button("👎 Dislike")

                    anime_images.append(
                        gr.Image(label="Image", elem_classes=["output-image", "fixed-width"])
                    )
                    anime_descriptions.append(
                        gr.HTML(
                            '<div class="anime-description" style="margin-top: 10px; font-size: 14px; color: #666;">'
                            f"Description for anime {slot}</div>"
                        )
                    )

    # Keep the numbered module-level names the exec() calls used to create.
    anime1, anime2, anime3, anime4 = anime_boxes
    image1, image2, image3, image4 = anime_images
    description1, description2, description3, description4 = anime_descriptions

    # Output order must match integration_warm's return value:
    # four names, then four images, then four descriptions.
    generate_btn.click(
        fn=integration_warm,
        inputs=[prompt],
        outputs=[*anime_boxes, *anime_images, *anime_descriptions]
    )

    clear_btn.click(
        fn=clear_prompt,
        inputs=[],
        outputs=[prompt]
    )
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Serve the app on port 7860, bound to the 0.0.0.0 address.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_options)
|
|