theostos committed on
Commit
7618ac2
·
1 Parent(s): 033a4b3

Initial commit

Browse files
Files changed (3) hide show
  1. .gitignore +1 -0
  2. app.py +86 -54
  3. requirements.txt +5 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .vscode
app.py CHANGED
@@ -1,70 +1,102 @@
 
 
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
 
 
3
 
 
 
4
 
5
- def respond(
6
- message,
7
- history: list[dict[str, str]],
8
- system_message,
9
- max_tokens,
10
- temperature,
11
- top_p,
12
- hf_token: gr.OAuthToken,
13
- ):
14
- """
15
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
16
- """
17
- client = InferenceClient(token=hf_token.token, model="openai/gpt-oss-20b")
18
 
19
- messages = [{"role": "system", "content": system_message}]
20
 
21
- messages.extend(history)
 
 
22
 
23
- messages.append({"role": "user", "content": message})
24
 
25
- response = ""
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
- for message in client.chat_completion(
28
- messages,
29
- max_tokens=max_tokens,
30
- stream=True,
31
- temperature=temperature,
32
- top_p=top_p,
33
- ):
34
- choices = message.choices
35
- token = ""
36
- if len(choices) and choices[0].delta.content:
37
- token = choices[0].delta.content
38
 
39
- response += token
40
- yield response
 
 
41
 
 
 
 
 
42
 
43
- """
44
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
45
- """
46
- chatbot = gr.ChatInterface(
47
- respond,
48
- type="messages",
49
- additional_inputs=[
50
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
51
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
52
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
53
- gr.Slider(
54
- minimum=0.1,
55
- maximum=1.0,
56
- value=0.95,
57
- step=0.05,
58
- label="Top-p (nucleus sampling)",
59
- ),
60
- ],
61
- )
 
 
 
62
 
63
- with gr.Blocks() as demo:
64
- with gr.Sidebar():
65
- gr.LoginButton()
66
- chatbot.render()
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
  if __name__ == "__main__":
70
  demo.launch()
 
1
+ import os
2
+ import torch
3
  import gradio as gr
4
+ import spaces
5
+ from threading import Thread
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM, FineGrainedFP8Config, TextIteratorStreamer
7
 
8
+ # >>>> CHANGE THIS <<<<
9
# Hub repository to load; the MODEL_ID environment variable overrides the default.
MODEL_ID = os.environ.get("MODEL_ID", "theostos/LLM4Docq-annotator")

# Mirrors the training format:
#   messages=[{"role": "user", "content": template.format(term=..., dependencies=...)}]
# The template is filled with str.format; its placeholders are {term} and {dependencies}.
INSTRUCTION_TEMPLATE = "".join(
    [
        "You are a Rocq code annotator. Given the Coq term and its dependencies, ",
        "produce helpful inline comments and explanations.\n\n",
        "Term:\n{term}\n\n",
        "Dependencies:\n{dependencies}\n",
    ]
)
 
 
 
 
 
 
 
17
 
18
# Hub access token read from the environment; None when the variable is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")

# The tokenizer is created at import time, unlike the model, which is loaded lazily.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True, token=HF_TOKEN)

# Fall back to the EOS token for padding when the tokenizer defines no pad token.
if tokenizer.pad_token_id is None:
    if tokenizer.eos_token_id is not None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

# Quantization settings handed to from_pretrained when the model is loaded.
quant_config = FineGrainedFP8Config()
25
 
26
# Module-level cache so the weights are loaded at most once per process.
_model = None


def load_model():
    """Return the shared causal-LM instance, loading it on the first call.

    The first call runs AutoModelForCausalLM.from_pretrained for MODEL_ID with
    the FP8 quantization config and stores the result in the module-level
    cache; every later call returns that cached object.
    """
    global _model
    if _model is not None:
        return _model

    load_options = {
        "token": HF_TOKEN,
        "device_map": "auto",
        "dtype": "auto",  # load base weights in their optimal dtype
        "quantization_config": quant_config,  # FP8 quantization
        "trust_remote_code": True,
    }
    _model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **load_options)
    return _model
39
 
40
def build_messages(term: str, deps: str):
    """Build the single-turn chat history for one annotation request.

    Fills INSTRUCTION_TEMPLATE with ``term`` and ``deps`` and returns a
    one-element list holding the resulting user message.
    """
    prompt = INSTRUCTION_TEMPLATE.format(dependencies=deps, term=term)
    user_turn = {"role": "user", "content": prompt}
    return [user_turn]
 
 
 
 
 
 
 
 
43
 
44
# Estimate duration for ZeroGPU (default is 60s). Shorter = better queue priority.
def _duration(term, deps, temperature, top_p, max_new_tokens, repetition_penalty=1.0):
    """Return the GPU time budget, in seconds, to request for one generate() call.

    Takes the same arguments as generate(); only ``max_new_tokens`` is used.
    The estimate is crude: ~2.5 tokens/s plus 30 s of headroom, clamped to
    the range [60, 300].
    """
    estimate = int(max_new_tokens) / 2.5 + 30
    return int(min(300, max(60, estimate)))


@spaces.GPU(duration=_duration)
def generate(term, deps, temperature, top_p, max_new_tokens, repetition_penalty=1.0):
    """Stream the model's annotation for ``term`` / ``deps``.

    Args:
        term: Text substituted for the template's ``{term}`` placeholder.
        deps: Text substituted for the template's ``{dependencies}`` placeholder.
        temperature: Sampling temperature. A value of 0 or below selects
            greedy decoding, because sampling requires a strictly positive
            temperature.
        top_p: Nucleus-sampling threshold; only used when sampling.
        max_new_tokens: Upper bound on the number of generated tokens.
        repetition_penalty: Forwarded to ``model.generate``. Defaults to 1.0
            so callers that supply only the first five arguments keep working.

    Yields:
        The text generated so far, wrapped in a ```rocq code fence, once per
        streamed chunk.

    Raises:
        Exception: Whatever ``model.generate`` raised in the worker thread,
            re-raised here after the stream has been drained.
    """
    model = load_model()
    temperature = float(temperature)

    messages = build_messages(term, deps)
    # return_dict=True yields input_ids together with the attention mask, and
    # the tensors are moved to wherever the model actually lives.
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = dict(
        **inputs,
        max_new_tokens=int(max_new_tokens),
        repetition_penalty=float(repetition_penalty),
        streamer=streamer,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    if temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature, top_p=float(top_p))
    else:
        # The UI slider allows 0.0; fall back to greedy decoding instead of
        # passing an invalid temperature to the sampler.
        gen_kwargs["do_sample"] = False

    failures = []

    def _worker():
        """Run generation; on failure, record the error and close the stream."""
        try:
            model.generate(**gen_kwargs)
        except Exception as exc:  # re-raised in the caller once the stream ends
            failures.append(exc)
            streamer.end()  # unblock the consumer loop below

    thread = Thread(target=_worker)
    thread.start()

    out = ""
    for token in streamer:  # stream tokens to UI
        out += token
        yield f"```rocq\n{out}\n```"

    thread.join()
    if failures:
        raise failures[0]
81
 
82
# Page layout: two text inputs, the generation controls, the Markdown output
# and the button that starts a streaming generate() call.
with gr.Blocks(title="Rocq Annotator (ZeroGPU)") as demo:
    gr.Markdown("# Rocq annotator\nThe model will produce annotated Rocq code.")
    with gr.Row():
        # NOTE(review): the labels say "Prefix" / "To annotate" while the values
        # fill the template's Term / Dependencies slots — confirm the mapping.
        # NOTE(review): lines=100 renders a very tall box — confirm it is intended.
        term = gr.Textbox(label="Prefix", lines=100, placeholder="Paste the prefix to use")
        deps = gr.Textbox(label="To annotate", lines=8, placeholder="The code to annotate")
    with gr.Row():
        temperature = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="Temperature")
        top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p")
        max_new = gr.Slider(256, 8192, value=4096, step=32, label="max_new_tokens")
        # Passed to generate() as its repetition_penalty argument, so the
        # number of inputs matches the handler's parameter list.
        repetition_penalty = gr.Slider(
            1.0, 2.0, value=1.0, step=0.05, label="repetition_penalty"
        )
    out = gr.Markdown(label="Annotated Rocq")
    btn = gr.Button("Annotate")
    btn.click(
        generate,
        inputs=[term, deps, temperature, top_p, max_new, repetition_penalty],
        outputs=out,
        concurrency_limit=1,  # cooperate with ZeroGPU queues
    )

# Bound the waiting queue and run one event at a time by default.
demo.queue(max_size=20, default_concurrency_limit=1)
100
 
101
  if __name__ == "__main__":
102
  demo.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch==2.8.0
2
+ transformers>=4.57.1
3
+ accelerate>=1.10
4
+ gradio>=4.44
5
+ spaces>=0.42