DINGDINGBELLS commited on
Commit
dea6353
·
verified ·
1 Parent(s): e674be8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -53
app.py CHANGED
@@ -1,63 +1,32 @@
1
- import gradio as gr
2
  import torch
3
- import gc
4
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
5
- from threading import Thread
6
 
7
- # Чистка памяти перед стартом
8
- gc.collect()
9
-
10
- MODEL_ID = "."
11
 
12
  print("🍌 BananaGPT: Попытка загрузки в float16 (Эконом-режим)...")
13
 
14
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 
15
 
16
- # Используем dtype=torch.float16, чтобы веса весили в 2 раза меньше
 
17
  model = AutoModelForCausalLM.from_pretrained(
18
- MODEL_ID,
19
- device_map="auto", # Теперь, когда accelerate в requirements, это сработает эффективно
20
- torch_dtype=torch.float16,
21
  low_cpu_mem_usage=True
22
  )
23
 
24
- # Красивый интерфейс
25
- custom_css = """
26
- footer {visibility: hidden}
27
- .gradio-container {background-color: #0b1117 !important; color: #e6edf3 !important;}
28
- .main-title {text-align: center; color: #f1c40f; font-size: 2.5em; font-weight: bold; margin-bottom: 20px;}
29
- .message.user {border: 1px solid #30363d !important;}
30
- .message.bot {background-color: #21262d !important; border: 1px solid #30363d !important;}
31
- """
32
-
33
- def predict(message, history):
34
- # Ограничиваем вход, чтобы не вешать процессор
35
- inputs = tokenizer(message, return_tensors="pt").to(model.device)
36
- streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
37
-
38
- generate_kwargs = dict(
39
- inputs,
40
- streamer=streamer,
41
- max_new_tokens=80,
42
- temperature=0.7,
43
- do_sample=True,
44
- )
45
-
46
- t = Thread(target=model.generate, kwargs=generate_kwargs)
47
- t.start()
48
-
49
- partial_message = ""
50
- for new_token in streamer:
51
- partial_message += new_token
52
- yield partial_message
53
-
54
- with gr.Blocks(css=custom_css, title="BananaGPT") as demo:
55
- gr.HTML("<div class='main-title'>🍌 BananaGPT</div>")
56
-
57
- gr.ChatInterface(
58
- fn=predict,
59
- type="messages",
60
- )
61
-
62
- if __name__ == "__main__":
63
- demo.queue().launch(show_api=False)
 
 
1
  import torch
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
3
 
4
+ model_name = азвание_твоей_модели" # Например, "gpt2" или путь к папке
 
 
 
5
 
6
  print("🍌 BananaGPT: Попытка загрузки в float16 (Эконом-режим)...")
7
 
8
+ # Загружаем токенизатор
9
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
10
 
11
+ # Загружаем модель
12
+ # ВНИМАНИЕ: заменено torch_dtype на dtype, чтобы не было ворнингов
13
  model = AutoModelForCausalLM.from_pretrained(
14
+ model_name,
15
+ dtype=torch.float16,
 
16
  low_cpu_mem_usage=True
17
  )
18
 
19
+ # Переносим на видеокарту, если она есть
20
+ device = "cuda" if torch.cuda.is_available() else "cpu"
21
+ model.to(device)
22
+
23
+ print(f"✅ Модель успешно загружена на {device}!")
24
+
25
+ # Тестовый запуск
26
+ prompt = "Привет, BananaGPT!"
27
+ inputs = tokenizer(prompt, return_tensors="pt").to(device)
28
+
29
+ with torch.no_grad():
30
+ output = model.generate(**inputs, max_new_tokens=50)
31
+
32
+ print(tokenizer.decode(output[0], skip_special_tokens=True))