DeadfoxX commited on
Commit
204945f
·
1 Parent(s): 1afc11f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +145 -12
app.py CHANGED
@@ -1,16 +1,149 @@
1
- import torch
2
- import transformers
3
 
4
- # Load the checkpoint file into a PyTorch model
5
- model = YourModelClass()
6
- state_dict = torch.load('souleater-diffusion.ckpt')
7
- model.load_state_dict(state_dict)
8
 
9
- # Save the model architecture and weights to a file
10
- torch.save(model.state_dict(), 'souleater-diffusion.pth')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- # Register the model with Hugging Face
13
- model_name = "DeadfoxX/souleater-diffusion"
14
- model_id = transformers.Model.upload(model_name, "souleater-diffusion.pth")
15
 
16
- print(f"Model uploaded with ID: {model_id}")
 
 
 
1
 
 
 
 
 
2
 
3
+ import gradio as gr
4
+ import sys
5
+ import random
6
+ import paddlehub as hub
7
+
8
+ language_translation_model = hub.Module(directory=f'./baidu_translate')
9
+ def getTextTrans(text, source='zh', target='en'):
10
+ try:
11
+ text_translation = language_translation_model.translate(text, source, target)
12
+ return text_translation
13
+ except Exception as e:
14
+ return text
15
+
16
+
17
+ model_ids = {
18
+ "models/DeadfoxX/Souleater_Diffusion":"sd-v1-0",
19
+
20
+ }
21
+ tab_actions = []
22
+ tab_titles = []
23
+ for model_id in model_ids.keys():
24
+ print(model_id, model_ids[model_id])
25
+ try:
26
+ tab = gr.Interface.load(model_id)
27
+ tab_actions.append(tab)
28
+ tab_titles.append(model_ids[model_id])
29
+ except:
30
+ pass
31
+
32
+ def infer(prompt):
33
+ prompt = getTextTrans(prompt, source='zh', target='en') + f',{random.randint(0,sys.maxsize)}'
34
+ return prompt
35
+
36
+ start_work = """async() => {
37
+ function isMobile() {
38
+ try {
39
+ document.createEvent("TouchEvent"); return true;
40
+ } catch(e) {
41
+ return false;
42
+ }
43
+ }
44
+ function getClientHeight()
45
+ {
46
+ var clientHeight=0;
47
+ if(document.body.clientHeight&&document.documentElement.clientHeight) {
48
+ var clientHeight = (document.body.clientHeight<document.documentElement.clientHeight)?document.body.clientHeight:document.documentElement.clientHeight;
49
+ } else {
50
+ var clientHeight = (document.body.clientHeight>document.documentElement.clientHeight)?document.body.clientHeight:document.documentElement.clientHeight;
51
+ }
52
+ return clientHeight;
53
+ }
54
+
55
+ function setNativeValue(element, value) {
56
+ const valueSetter = Object.getOwnPropertyDescriptor(element.__proto__, 'value').set;
57
+ const prototype = Object.getPrototypeOf(element);
58
+ const prototypeValueSetter = Object.getOwnPropertyDescriptor(prototype, 'value').set;
59
+
60
+ if (valueSetter && valueSetter !== prototypeValueSetter) {
61
+ prototypeValueSetter.call(element, value);
62
+ } else {
63
+ valueSetter.call(element, value);
64
+ }
65
+ }
66
+ var gradioEl = document.querySelector('body > gradio-app').shadowRoot;
67
+ if (!gradioEl) {
68
+ gradioEl = document.querySelector('body > gradio-app');
69
+ }
70
+
71
+ if (typeof window['gradioEl'] === 'undefined') {
72
+ window['gradioEl'] = gradioEl;
73
+
74
+ tabitems = window['gradioEl'].querySelectorAll('.tabitem');
75
+ for (var i = 0; i < tabitems.length; i++) {
76
+ tabitems[i].childNodes[0].children[0].style.display='none';
77
+ tabitems[i].childNodes[0].children[1].children[0].style.display='none';
78
+ tabitems[i].childNodes[0].children[1].children[1].children[0].children[1].style.display="none";
79
+ }
80
+ tab_demo = window['gradioEl'].querySelectorAll('#tab_demo')[0];
81
+ tab_demo.style.display = "block";
82
+ tab_demo.setAttribute('style', 'height: 100%;');
83
+ const page1 = window['gradioEl'].querySelectorAll('#page_1')[0];
84
+ const page2 = window['gradioEl'].querySelectorAll('#page_2')[0];
85
+
86
+ page1.style.display = "none";
87
+ page2.style.display = "block";
88
+ window['prevPrompt'] = '';
89
+ window['doCheckPrompt'] = 0;
90
+ window['checkPrompt'] = function checkPrompt() {
91
+ try {
92
+ texts = window['gradioEl'].querySelectorAll('textarea');
93
+ text0 = texts[0];
94
+ text1 = texts[1];
95
+ if (window['doCheckPrompt'] == 0 && window['prevPrompt'] !== text1.value) {
96
+ window['doCheckPrompt'] = 1;
97
+ window['prevPrompt'] = text1.value;
98
+ for (var i = 2; i < texts.length; i++) {
99
+ setNativeValue(texts[i], text1.value);
100
+ texts[i].dispatchEvent(new Event('input', { bubbles: true }));
101
+ }
102
+ setTimeout(function() {
103
+ //text1 = window['gradioEl'].querySelectorAll('textarea')[1];
104
+
105
+ btns = window['gradioEl'].querySelectorAll('button');
106
+ for (var i = 0; i < btns.length; i++) {
107
+ if (btns[i].innerText == 'Submit') {
108
+ //btns[i].focus();
109
+ btns[i].click();
110
+ }
111
+ }
112
+ window['doCheckPrompt'] = 0;
113
+ }, 10);
114
+ }
115
+ } catch(e) {
116
+ }
117
+ }
118
+ window['checkPrompt_interval'] = window.setInterval("window.checkPrompt()", 100);
119
+ }
120
+
121
+ return false;
122
+ }"""
123
+
124
+ with gr.Blocks(title='Text to Image') as demo:
125
+ with gr.Group(elem_id="page_1", visible=True) as page_1:
126
+ with gr.Box():
127
+ with gr.Row():
128
+ start_button = gr.Button("Let's GO!", elem_id="start-btn", visible=True)
129
+ start_button.click(fn=None, inputs=[], outputs=[], _js=start_work)
130
+
131
+ with gr.Group(elem_id="page_2", visible=False) as page_2:
132
+ with gr.Row(elem_id="prompt_row"):
133
+ prompt_input0 = gr.Textbox(lines=4, label="prompt")
134
+ prompt_input1 = gr.Textbox(lines=4, label="prompt", visible=True)
135
+ with gr.Row():
136
+ submit_btn = gr.Button(value = "submit",elem_id="erase-btn").style(
137
+ margin=True,
138
+ rounded=(True, True, True, True),
139
+ )
140
+ with gr.Row(elem_id='tab_demo', visible=True).style(height=5):
141
+ tab_demo = gr.TabbedInterface(tab_actions, tab_titles)
142
+
143
+ submit_btn.click(fn=infer, inputs=[prompt_input0], outputs=[prompt_input1])
144
+
145
+ if __name__ == "__main__":
146
+ demo.launch()
147
+
148
 
 
 
 
149