Madras1 committed
Commit cb91d24 · verified · 1 Parent(s): 9e48a7b

Update app.py

Files changed (1): app.py +59 -70
app.py CHANGED
@@ -1,44 +1,76 @@
 import gradio as gr
 import spaces
 import torch
-import os
+import gc
 from transformers import AutoModelForCausalLM, AutoTokenizer
-from groq import Groq
 
-# --- Local config (H200) ---
-LOCAL_MODEL_ID = "Qwen/Qwen2.5-Coder-32B-Instruct"
+# --- MODEL CATALOG ---
+# Add as many as you want here (as long as they fit in VRAM one at a time)
+MODEL_MAP = {
+    "qwen-32b": "Qwen/Qwen2.5-Coder-32B-Instruct",
+    "llama-8b": "meta-llama/Llama-3.1-8B-Instruct",
+    "mistral-7b": "mistralai/Mistral-7B-Instruct-v0.3"
+}
+
+# --- Global state ---
+current_model_id = None
 model = None
 tokenizer = None
 
-# --- Groq config ---
-# It tries to grab the key from the Space secrets
-groq_client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
-
-# --- Function 1: runs on the H200 (uses quota) ---
-# Lowered to 60s to help with your Colab reset
-@spaces.GPU(duration=60)
-def run_local_qwen(messages):
+# --- VRAM cleanup function ---
+def free_memory():
     global model, tokenizer
+    if model is not None:
+        del model
+        del tokenizer
+    gc.collect()
+    torch.cuda.empty_cache()
+    print("🧹 VRAM cleared!")
+
+# --- The GPU routing magic ---
+# Raised the duration to 90s because swapping models takes about 20s
+@spaces.GPU(duration=90)
+def router(message, history, model_name_key):
+    global model, tokenizer, current_model_id
+
+    target_id = MODEL_MAP.get(model_name_key)
 
-    # Lazy loading
-    if model is None:
-        print(f"🚀 Loading {LOCAL_MODEL_ID} on the H200...")
-        tokenizer = AutoTokenizer.from_pretrained(LOCAL_MODEL_ID)
+    if not target_id:
+        return f"❌ Error: model '{model_name_key}' not found in the catalog."
+
+    # --- SWAP LOGIC ---
+    if current_model_id != target_id:
+        print(f"🔄 Swapping {current_model_id} for {target_id}...")
+        free_memory()  # Empty the GPU
+
+        print("🚀 Loading the new model...")
+        tokenizer = AutoTokenizer.from_pretrained(target_id)
         model = AutoModelForCausalLM.from_pretrained(
-            LOCAL_MODEL_ID,
+            target_id,
             torch_dtype=torch.bfloat16,
             device_map="cuda"
         )
-
-    # Prepare the prompt
+        current_model_id = target_id
+        print("✅ Model loaded!")
+    else:
+        print("⚡ Model is already in VRAM. Using the cache.")
+
+    # --- INFERENCE ---
+    # Format the history
+    messages = []
+    for user_msg, bot_msg in history:
+        if user_msg: messages.append({"role": "user", "content": user_msg})
+        if bot_msg: messages.append({"role": "assistant", "content": bot_msg})
+    messages.append({"role": "user", "content": message})
+
     text = tokenizer.apply_chat_template(
         messages,
         tokenize=False,
         add_generation_prompt=True
    )
+
     inputs = tokenizer([text], return_tensors="pt").to(model.device)
 
-    # Generate
     outputs = model.generate(
         **inputs,
         max_new_tokens=1024,
@@ -49,64 +81,21 @@ def run_local_qwen(messages):
     response = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
     return response
 
-# --- Function 2: runs on Groq (does NOT use GPU quota) ---
-def run_groq(messages, model_id="llama3-70b-8192"):
-    print(f"⚡ Calling Groq: {model_id}...")
-    try:
-        completion = groq_client.chat.completions.create(
-            model=model_id,
-            messages=messages,
-            temperature=0.7,
-            max_tokens=1024,
-            top_p=1,
-            stream=False,
-            stop=None,
-        )
-        return completion.choices[0].message.content
-    except Exception as e:
-        return f"❌ Groq error: {str(e)}"
-
-# --- The central router (the intelligence) ---
-def router(message, history, model_selector):
-    # Format the history into the OpenAI/Groq schema
-    messages = []
-    for user_msg, bot_msg in history:
-        if user_msg: messages.append({"role": "user", "content": user_msg})
-        if bot_msg: messages.append({"role": "assistant", "content": bot_msg})
-    messages.append({"role": "user", "content": message})
-
-    # The routing logic
-    if model_selector == "Local: Qwen 2.5 32B (H200)":
-        return run_local_qwen(messages)
-
-    elif model_selector == "Groq: Llama 3 70B":
-        return run_groq(messages, "llama3-70b-8192")
-
-    elif model_selector == "Groq: Mixtral 8x7B":
-        return run_groq(messages, "mixtral-8x7b-32768")
-
-    else:
-        return "Model not recognized."
-
 # --- Interface ---
 with gr.Blocks() as demo:
-    gr.Markdown("# 🔀 APIDOST Router")
-    gr.Markdown("Hybrid routing: local H200 (ZeroGPU) + Groq Cloud (LPU)")
+    gr.Markdown("# 🧠 Gabriel's Multi-Model Switcher")
 
     with gr.Row():
-        model_dropdown = gr.Dropdown(
-            choices=[
-                "Local: Qwen 2.5 32B (H200)",
-                "Groq: Llama 3 70B",
-                "Groq: Mixtral 8x7B"
-            ],
-            value="Groq: Llama 3 70B",  # Default to Groq to save your quota
-            label="Choose the Brain"
+        # Dropdown to choose which HF model to load
+        model_selector = gr.Dropdown(
+            choices=list(MODEL_MAP.keys()),
+            value="qwen-32b",
+            label="Choose the Model (this swaps it on the GPU)"
         )
 
     chat = gr.ChatInterface(
         fn=router,
-        additional_inputs=[model_dropdown]  # Pass the dropdown to the router
+        additional_inputs=[model_selector]
    )
 
 if __name__ == "__main__":
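
The heart of this commit is the evict-then-load swap: free_memory() drops whichever model is resident before the next one is pulled onto the GPU, and current_model_id short-circuits the reload when the requested model is already loaded. Below is a minimal standalone sketch of that pattern, assuming only PyTorch and transformers; the swap_to helper, the CATALOG entries, and the memory printout are illustrative additions, not part of app.py.

import gc
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative two-entry catalog; any checkpoints that fit in VRAM one at a
# time would do (the Llama repo is gated and needs Hugging Face access).
CATALOG = {
    "llama-8b": "meta-llama/Llama-3.1-8B-Instruct",
    "mistral-7b": "mistralai/Mistral-7B-Instruct-v0.3",
}

model = tokenizer = current = None

def swap_to(key):
    """Make CATALOG[key] the resident model, evicting the previous one first."""
    global model, tokenizer, current
    if current == key:
        return  # already resident; nothing to do
    if model is not None:
        del model, tokenizer        # drop the Python references
        gc.collect()                # reclaim the tensors
        torch.cuda.empty_cache()    # hand the cached blocks back to the driver
    tokenizer = AutoTokenizer.from_pretrained(CATALOG[key])
    model = AutoModelForCausalLM.from_pretrained(
        CATALOG[key], torch_dtype=torch.bfloat16, device_map="cuda"
    )
    current = key
    print(f"{key} resident; {torch.cuda.memory_allocated() / 1e9:.1f} GB allocated")

swap_to("mistral-7b")  # cold load
swap_to("llama-8b")    # evicts Mistral before Llama comes in

Note that del alone only releases the Python references; gc.collect() actually reclaims the tensors, and torch.cuda.empty_cache() returns the freed blocks to the driver, which keeps fragmentation down and makes the drop visible in tools like nvidia-smi. That is why the commit runs all three before loading the next checkpoint.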