File size: 3,622 Bytes
8683d51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38f30e6
8683d51
 
 
 
 
 
 
 
 
cab2c2a
8683d51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import html
import os
import pickle

import networkx as nx
import pandas as pd
import streamlit as st
import tqdm

from analysis import build_graph, parse_page

def clean(results):
    """Return a copy of *results* keeping only entries whose value is truthy.

    Entries with no data (``None``, empty list/dict/string, ...) are dropped so
    that downstream graph building only sees pages that produced content.

    Args:
        results: mapping of key -> parsed data (in this script, page filename
            -> ``parse_page`` output).

    Returns:
        A new dict, in the same insertion order, without the empty entries.
    """
    # Truthiness already covers "non-empty"; no separate len() check needed.
    return {key: value for key, value in results.items() if value}

# One-time per-session initialisation of the degree threshold and the
# bipartite graph. The session_state guards keep the expensive parsing and
# graph building to the first run of each session (Streamlit re-executes the
# whole script on every interaction).
if "B_degree_threshold" not in st.session_state:
    st.session_state.B_degree_threshold = 10
if "B" not in st.session_state:
    if not os.path.exists('data.pkl'):
        # No cache yet: parse every entry in the pages folder, then pickle the
        # results so later sessions can load them instead of re-parsing.
        page_folder = 'pages'
        # Sorted so the parse order (and the resulting dict order) is
        # deterministic rather than dependent on the filesystem.
        pages = sorted(os.listdir(page_folder))
        results = {}

        for p in tqdm.tqdm(pages):
            try:
                results[p] = parse_page(os.path.join(page_folder, p))
            except Exception as e:
                # Best-effort: one unparseable page must not abort the scan,
                # but report it instead of discarding the error silently.
                # tqdm.write prints without corrupting the progress bar.
                tqdm.tqdm.write(f"Skipping {p}: {type(e).__name__}: {e}")

        with open('data.pkl', 'wb') as f:
            pickle.dump(results, f)

    else:
        # NOTE(review): unpickling can execute arbitrary code; only load a
        # data.pkl produced by this app, never one from an untrusted source.
        with open('data.pkl', 'rb') as f:
            results = pickle.load(f)

    st.session_state.results = clean(results)
    st.session_state.B = build_graph(st.session_state.results, st.session_state.B_degree_threshold)



# Streamlit app
_CIVITAI_IMAGE_URL = 'https://civitai.com/images/'


def _render_link_table(name_column, names, pages):
    """Render an HTML table of node names next to clickable CivitAI image links.

    Args:
        name_column: header for the first column (e.g. "Lora Names").
        names: node names to list, one row each.
        pages: image page ids (strings), parallel to *names*.

    The table is emitted with ``unsafe_allow_html=True`` so the links are
    clickable. The names originate from parsed page data (presumably external
    content), so every cell is HTML-escaped first to prevent markup/script
    injection.
    """
    urls = [_CIVITAI_IMAGE_URL + page for page in pages]
    links = [
        f'<a href="{html.escape(url)}" target="_blank">{html.escape(url)}</a>'
        for url in urls
    ]
    df = pd.DataFrame({
        name_column: [html.escape(str(name)) for name in names],
        "Image Link": links,
    })
    # escape=False keeps the <a> tags; all cell text was escaped above.
    st.markdown(df.to_html(escape=False, index=False), unsafe_allow_html=True)


def main():
    """Render the app: a degree-threshold slider plus model->lora and lora->model lookups.

    The graph cached in ``st.session_state.B`` is rebuilt from
    ``st.session_state.results`` whenever the sidebar threshold differs from
    the one it was last built with.
    """
    st.title("SD BaseModel Lora Connections")

    # Sidebar for degree_threshold
    B_degree_threshold = st.sidebar.slider("Select Degree Threshold", 1, 100, 10)

    # Rebuild the graph only when the threshold actually changed.
    if B_degree_threshold != st.session_state.B_degree_threshold:
        st.session_state.B_degree_threshold = B_degree_threshold
        st.session_state.B = build_graph(st.session_state.results, B_degree_threshold)

    graph = st.session_state.B
    st.sidebar.write(f"There are {len(graph)} nodes analyzed.")

    # Nodes with bipartite == 0 are models; every other node is a lora.
    model_nodes = {n for n, d in graph.nodes(data=True) if d['bipartite'] == 0}
    lora_nodes = set(graph) - model_nodes

    def by_degree(nodes):
        # Highest degree first; ties broken by name so the option order does
        # not depend on set iteration order.
        return sorted(nodes, key=lambda n: (-graph.degree(n), str(n)))

    sorted_models = by_degree(model_nodes)
    sorted_loras = by_degree(lora_nodes)

    # Model selection: list every lora connected to the chosen model.
    selected_model = st.selectbox("Select Model (sorted by degree)", sorted_models)
    if selected_model:
        loras_for_model = list(graph.neighbors(selected_model))
        pages = [graph[selected_model][lora]['page'] for lora in loras_for_model]
        _render_link_table("Lora Names", loras_for_model, pages)

    # Lora selection: list every model connected to the chosen lora.
    selected_lora = st.selectbox("Select Lora (sorted by degree)", sorted_loras)
    if selected_lora:
        models_for_lora = list(graph.neighbors(selected_lora))
        pages = [graph[model][selected_lora]['page'] for model in models_for_lora]
        _render_link_table("Model Names", models_for_lora, pages)


if __name__ == "__main__":
    # Script entry point (presumably launched with `streamlit run`): render the UI.
    main()