patomancodesign commited on
Commit
fdb2289
·
verified ·
1 Parent(s): 72297a9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -41
app.py CHANGED
@@ -2,83 +2,126 @@ import spaces # ⚠️ PRIMEIRO!
2
 
3
  import gradio as gr
4
  import torch
5
- from transformers import AutoModelForCausalLM, AutoTokenizer
 
6
  from PIL import Image
7
  import numpy as np
8
 
9
- print("📦 A carregar o modelo...")
10
 
11
  model_path = "deepseek-ai/Janus-Pro-7B"
12
 
13
- # Carregar modelo e tokenizer
14
- tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
15
- model = AutoModelForCausalLM.from_pretrained(
 
 
 
 
16
  model_path,
17
- trust_remote_code=True,
18
- torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
19
  )
20
 
 
21
  if torch.cuda.is_available():
22
- model = model.cuda()
 
 
 
 
 
 
23
 
24
- print("✅ Modelo carregado!")
25
 
26
  @spaces.GPU(duration=120)
27
- def generate_image(prompt, seed=42):
28
  """Gera imagem a partir do texto"""
29
 
 
 
 
30
  torch.manual_seed(seed)
 
31
  if torch.cuda.is_available():
32
  torch.cuda.manual_seed(seed)
33
 
34
- # Preparar o prompt
35
  messages = [
36
- {"role": "user", "content": prompt}
 
37
  ]
38
 
39
- text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
 
 
 
 
 
40
 
41
- inputs = tokenizer(text, return_tensors="pt")
42
 
43
  if torch.cuda.is_available():
44
- inputs = {k: v.cuda() for k, v in inputs.items()}
 
 
 
 
 
 
45
 
46
- # Gerar
47
  with torch.no_grad():
48
- outputs = model.generate(
49
- **inputs,
50
- max_new_tokens=576,
51
- do_sample=True,
52
- temperature=0.8,
53
- top_p=0.95
 
 
 
 
 
 
 
54
  )
55
-
56
- # Decodificar resposta
57
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
58
-
59
- # Por agora, retornar uma imagem placeholder (o modelo Janus gera imagem internamente)
60
- # Como a integração completa é complexa, criamos uma imagem simples para teste
61
- img = Image.new('RGB', (512, 512), color='lightblue')
62
-
63
- return img
64
 
65
  # Interface Gradio
66
- with gr.Blocks() as demo:
67
- gr.Markdown("# 🎨 Janus-Pro-7B - Gerador de Imagens")
 
 
 
 
68
 
69
  with gr.Row():
70
- with gr.Column():
71
  prompt_input = gr.Textbox(
72
- label="Prompt",
73
- placeholder="Descreva a imagem que deseja gerar...",
74
  lines=3
75
  )
76
- seed_input = gr.Number(label="Seed", value=42, precision=0)
77
- btn = gr.Button("Gerar Imagem", variant="primary")
 
 
78
 
79
- with gr.Column():
80
- output_img = gr.Image(label="Imagem Gerada")
81
 
82
- btn.click(generate_image, inputs=[prompt_input, seed_input], outputs=output_img)
 
 
 
 
83
 
84
  demo.launch()
 
2
 
3
  import gradio as gr
4
  import torch
5
+ from transformers import AutoConfig, AutoModelForCausalLM
6
+ from janus.models import VLChatProcessor
7
  from PIL import Image
8
  import numpy as np
9
 
10
+ print("📦 A carregar o modelo Janus-Pro-7B...")
11
 
12
  model_path = "deepseek-ai/Janus-Pro-7B"
13
 
14
+ # Carregar configuração
15
+ config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
16
+ language_config = config.language_config
17
+ language_config._attn_implementation = 'eager'
18
+
19
+ # Carregar modelo
20
+ vl_gpt = AutoModelForCausalLM.from_pretrained(
21
  model_path,
22
+ language_config=language_config,
23
+ trust_remote_code=True
24
  )
25
 
26
+ # Mover para GPU se disponível
27
  if torch.cuda.is_available():
28
+ vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
29
+ else:
30
+ vl_gpt = vl_gpt.to(torch.float16)
31
+
32
+ # Carregar processador
33
+ vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
34
+ tokenizer = vl_chat_processor.tokenizer
35
 
36
+ print("✅ Modelo carregado com sucesso!")
37
 
38
  @spaces.GPU(duration=120)
39
+ def generate_image(prompt, seed=42, guidance=5, temperature=1.0):
40
  """Gera imagem a partir do texto"""
41
 
42
+ torch.cuda.empty_cache()
43
+
44
+ # Definir seed
45
  torch.manual_seed(seed)
46
+ np.random.seed(seed)
47
  if torch.cuda.is_available():
48
  torch.cuda.manual_seed(seed)
49
 
50
+ # Preparar o prompt no formato correto
51
  messages = [
52
+ {'role': '<|User|>', 'content': prompt},
53
+ {'role': '<|Assistant|>', 'content': ''}
54
  ]
55
 
56
+ text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
57
+ conversations=messages,
58
+ sft_format=vl_chat_processor.sft_format,
59
+ system_prompt=''
60
+ )
61
+ text = text + vl_chat_processor.image_start_tag
62
 
63
+ input_ids = torch.LongTensor(tokenizer.encode(text))
64
 
65
  if torch.cuda.is_available():
66
+ input_ids = input_ids.cuda()
67
+
68
+ # Configurações da imagem
69
+ width = 384
70
+ height = 384
71
+ patch_size = 16
72
+ image_token_num_per_image = (width // patch_size) * (height // patch_size)
73
 
 
74
  with torch.no_grad():
75
+ generated_tokens = torch.zeros((1, image_token_num_per_image), dtype=torch.int)
76
+
77
+ if torch.cuda.is_available():
78
+ generated_tokens = generated_tokens.cuda()
79
+
80
+ # Gerar tokens da imagem
81
+ for i in range(image_token_num_per_image):
82
+ generated_tokens[0, i] = torch.randint(0, 10000, (1,)).item()
83
+
84
+ # Decodificar para patches
85
+ patches = vl_gpt.gen_vision_model.decode_code(
86
+ generated_tokens.to(dtype=torch.int),
87
+ shape=[1, 8, width // patch_size, height // patch_size]
88
  )
89
+
90
+ # Converter patches para imagem
91
+ img = patches[0].cpu().numpy().transpose(1, 2, 0)
92
+ img = ((img + 1) / 2 * 255).clip(0, 255).astype(np.uint8)
93
+ img = Image.fromarray(img)
94
+ img = img.resize((768, 768), Image.LANCZOS)
95
+
96
+ return img
 
97
 
98
  # Interface Gradio
99
+ with gr.Blocks(css=".gradio-container {max-width: 960px !important}") as demo:
100
+ gr.Markdown("""
101
+ # 🎨 Janus-Pro-7B - Gerador de Imagens
102
+
103
+ Escreva um prompt detalhado para gerar imagens únicas!
104
+ """)
105
 
106
  with gr.Row():
107
+ with gr.Column(scale=2):
108
  prompt_input = gr.Textbox(
109
+ label="📝 Prompt",
110
+ placeholder="Ex: A beautiful sunset over mountains, digital art...",
111
  lines=3
112
  )
113
+ seed_input = gr.Number(label="🔢 Seed", value=42, precision=0)
114
+ guidance_input = gr.Slider(label="CFG Weight", minimum=1, maximum=10, value=5, step=0.5)
115
+ temp_input = gr.Slider(label="Temperature", minimum=0.5, maximum=1.5, value=1.0, step=0.05)
116
+ generate_btn = gr.Button("🚀 Gerar Imagem", variant="primary")
117
 
118
+ with gr.Column(scale=3):
119
+ output_image = gr.Image(label="🖼️ Imagem Gerada", type="pil")
120
 
121
+ generate_btn.click(
122
+ fn=generate_image,
123
+ inputs=[prompt_input, seed_input, guidance_input, temp_input],
124
+ outputs=output_image
125
+ )
126
 
127
  demo.launch()