Invescoz committed on
Commit
502b633
·
verified ·
1 Parent(s): 7f1f452

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +158 -24
app.py CHANGED
@@ -1,35 +1,169 @@
 
 
 
1
  import gradio as gr
2
- from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
 
4
 
5
- MODEL_NAME = "Rapnss/VIA-01" # your uploaded model repo on HF
6
 
7
- # Load model & tokenizer
8
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
9
- model = AutoModelForCausalLM.from_pretrained(
10
- MODEL_NAME,
11
- torch_dtype=torch.float16,
12
- device_map="auto"
13
- )
 
14
 
15
- def chat_fn(prompt, history=[]):
16
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
17
- outputs = model.generate(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  **inputs,
19
- max_new_tokens=512,
20
- temperature=0.7,
21
- top_p=0.9,
 
 
 
 
 
 
 
22
  )
 
 
 
 
 
23
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
24
- return response
 
 
25
 
26
- demo = gr.Interface(
27
- fn=chat_fn,
28
- inputs=gr.Textbox(lines=2, placeholder="Ask VIA-01 something..."),
29
- outputs=gr.Textbox(label="VIA-01 Response"),
30
- title="Rapnss VIA-01",
31
- description="Lightweight Reasoning + Code Model, by Rapnss"
32
- )
 
 
 
 
33
 
34
  if __name__ == "__main__":
35
- demo.launch()
 
# app.py
import time
import os

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteriaList, StoppingCriteria

MODEL_NAME = "Rapnss/VIA-01"  # your HF repo

# Tunable knobs for loading and generation.
DEFAULT_MAX_NEW_TOKENS = 64   # keep low to meet latency targets
MAX_PROMPT_TOKENS = 512       # truncate long prompts
TEMPERATURE = 0.3
TOP_P = 0.9
DO_SAMPLE = False             # deterministic and usually faster than sampling
NUM_BEAMS = 1                 # beam=1 is fastest
WARMUP_PROMPT = "Hello."      # used to warm model after loading
# Load tokenizer and model. Preferred path: 4-bit quantized load on GPU via
# bitsandbytes; fallback: fp16 on GPU, then fp32 on CPU.
print("Loading tokenizer & model...")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)

model = None
device = "cpu"
try:
    # If CUDA is available and bitsandbytes exists, load 4-bit
    if torch.cuda.is_available():
        device = "cuda"
        print("CUDA available — attempting 4-bit load with bitsandbytes...")
        # Fix: `load_in_4bit` and bare `bnb_4bit_*` kwargs passed directly to
        # from_pretrained are deprecated/ignored by current transformers —
        # they must be wrapped in a BitsAndBytesConfig. Also drop the
        # conflicting `torch_dtype=float16` on the quantized path; the
        # compute dtype is set through the config instead.
        from transformers import BitsAndBytesConfig

        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
        )
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            quantization_config=quant_config,
            device_map="auto",
            trust_remote_code=True,  # some user repos need it
        )
    else:
        raise RuntimeError("CUDA not available; load fallback")
except Exception as e:
    print("4-bit load failed or not available:", e)
    print("Falling back to fp16/cpu (best-effort).")
    # fallback: try fp16 on GPU or float32 on CPU
    if torch.cuda.is_available():
        device = "cuda"
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        )
    else:
        device = "cpu"
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float32,
            device_map={"": "cpu"},
            trust_remote_code=True,
        )

# Inference only: disable dropout / training-mode layers.
model.eval()

# Optional: try torch.compile for small speedups (PyTorch 2.x only, may increase startup)
try:
    if torch.__version__.startswith("2"):
        print("Attempting torch.compile(model) for runtime speedups...")
        model = torch.compile(model)
except Exception as e:
    print("torch.compile not used:", e)

print(f"Model loaded on {device}")
# Utility: tokenize a prompt and move tensors onto the active device.
def prepare_inputs(prompt_text):
    """Tokenize *prompt_text*, truncated to MAX_PROMPT_TOKENS, as PyTorch tensors.

    Returns the tokenizer's encoding dict; tensors are moved to the GPU when
    the module-level `device` is "cuda", otherwise left on CPU.
    """
    encoded = tokenizer(
        prompt_text,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_PROMPT_TOKENS,
        padding=False,
    )
    if device != "cuda":
        return encoded
    return {name: tensor.cuda() for name, tensor in encoded.items()}
# Custom stopping criterion (currently unused — see commented-out
# `stopping_criteria` in chat_fn's gen_kwargs).
class StopOnDoubleNewline(StoppingCriteria):
    # NOTE(review): despite the name, this does NOT detect two consecutive
    # newlines — it stops as soon as either of the last two generated token
    # ids equals the tokenizer's EOS id. Rename or rewrite if true
    # double-newline stopping is wanted.
    def __call__(self, input_ids, scores, **kwargs):
        if input_ids.shape[-1] < 2:
            return False
        last_two = (input_ids[0, -2].item(), input_ids[0, -1].item())
        return tokenizer.eos_token_id in last_two

stop_criteria = StoppingCriteriaList([StopOnDoubleNewline()])
# Warm-up: run one tiny greedy generation so kernels/caches are primed,
# reducing latency of the first real request.
def warm_up_model():
    """Best-effort warm-up generation; failures are logged, never raised."""
    try:
        warm_inputs = prepare_inputs(WARMUP_PROMPT)
        with torch.inference_mode():
            model.generate(
                **warm_inputs,
                max_new_tokens=8,
                do_sample=False,
                use_cache=True,
            )
        print("Warmup complete.")
    except Exception as e:
        print("Warmup failed:", e)

# Warm up once at startup to reduce first-request latency
warm_up_model()
# The actual chat function used by Gradio
def chat_fn(prompt: str, max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS, temperature: float = TEMPERATURE):
    """Generate a response for *prompt* and append a latency footer.

    Args:
        prompt: user prompt; blank input short-circuits with a hint message.
        max_new_tokens: generation budget, clamped to [1, 256].
        temperature: > 0 enables sampling at that temperature (with TOP_P);
            0 gives deterministic greedy decoding.

    Returns:
        The decoded model output followed by a debug line with latency,
        token budget, and device.
    """
    t0 = time.time()
    prompt = prompt.strip()
    if not prompt:
        return "Please enter a prompt."

    # safety: clamp max_new_tokens to avoid huge generations
    max_new_tokens = int(max(1, min(max_new_tokens, 256)))

    inputs = prepare_inputs(prompt)

    # Generation arguments tuned for speed
    gen_kwargs = dict(
        **inputs,
        max_new_tokens=max_new_tokens,
        num_beams=NUM_BEAMS,
        eos_token_id=tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.sep_token_id,
        pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
        use_cache=True,
        # stopping_criteria=stop_criteria,  # enable if you want custom stopping
    )
    # Fix: the original always passed temperature/top_p with do_sample=False,
    # so the UI temperature slider had no effect and transformers warned about
    # unused generation flags (likewise early_stopping with num_beams=1).
    # Sample only when a positive temperature is requested.
    if float(temperature) > 0:
        gen_kwargs.update(
            do_sample=True,
            temperature=float(temperature),
            top_p=float(TOP_P),
        )
    else:
        gen_kwargs["do_sample"] = False
    if NUM_BEAMS > 1:
        gen_kwargs["early_stopping"] = True  # only meaningful for beam search

    # Inference context to reduce overhead
    with torch.inference_mode():
        outputs = model.generate(**gen_kwargs)

    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    latency = time.time() - t0
    # Return response and latency for debugging
    return f"{response}\n\n---\nLatency: {latency:.2f}s (max_new_tokens={max_new_tokens}, device={device})"
# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# Rapnss VIA-01")
    with gr.Row():
        prompt_box = gr.Textbox(lines=3, placeholder="Ask VIA-01 something...", label="Prompt")
    with gr.Row():
        tokens_slider = gr.Slider(16, 256, value=DEFAULT_MAX_NEW_TOKENS, step=16, label="Max new tokens")
        temp_slider = gr.Slider(0.0, 1.0, value=TEMPERATURE, step=0.05, label="Temperature")
    response_box = gr.Textbox(label="VIA-01 Response", lines=12)
    generate_btn = gr.Button("Generate")
    # Slider order must match chat_fn's (prompt, max_new_tokens, temperature).
    generate_btn.click(fn=chat_fn, inputs=[prompt_box, tokens_slider, temp_slider], outputs=response_box)
 
if __name__ == "__main__":
    # Bind on all interfaces; honor the platform-provided PORT, default 7860.
    listen_port = int(os.environ.get("PORT", 7860))
    demo.launch(share=False, server_name="0.0.0.0", server_port=listen_port)