Gogs committed
Commit f709304 · Parent: 39e8823

🔧 Switch to Inference API (no local model loading)

Files changed (2):
  1. app.py            +72 -27
  2. requirements.txt   +5 -4
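
For reference, the Inference API path named in the title is the huggingface_hub call that appears inside generate_code in the diff below. Reduced to a standalone sketch (the generate_remote name, the sampling defaults, and the commented token hint are illustrative assumptions, not part of this commit):

from huggingface_hub import InferenceClient

MODEL_ID = "OpceanAI/Yuuki-best"

def generate_remote(prompt: str, max_new_tokens: int = 256) -> str:
    # Send the prompt to the hosted model instead of loading weights in the Space.
    client = InferenceClient()  # InferenceClient(token="hf_...") for authenticated calls (assumption)
    try:
        return client.text_generation(
            prompt,
            model=MODEL_ID,
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
        )
    except Exception as e:
        # Bubble the API error back to the UI instead of raising.
        return f"Generation error: {e}"
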
app.py CHANGED
@@ -7,6 +7,40 @@ import torch
 # ============================================================================
 
 MODEL_ID = "OpceanAI/Yuuki-best"
+MODEL_LOADED = False
+model = None
+tokenizer = None
+
+
+def load_model():
+    """Load the Yuuki model with proper error handling."""
+    global model, tokenizer, MODEL_LOADED
+
+    if MODEL_LOADED:
+        return True
+
+    try:
+        print(f"Loading Yuuki model from {MODEL_ID}...")
+
+        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+        model = AutoModelForCausalLM.from_pretrained(
+            MODEL_ID,
+            torch_dtype=torch.float32,
+            low_cpu_mem_usage=True,
+            trust_remote_code=True
+        )
+
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
+
+        MODEL_LOADED = True
+        print("Model loaded successfully!")
+        return True
+
+    except Exception as e:
+        print(f"Error loading model: {e}")
+        return False
+
 
 def generate_code(
     prompt: str,
@@ -16,28 +50,43 @@ def generate_code(
     top_k: int = 50,
     repetition_penalty: float = 1.1
 ) -> str:
-    """Generate code using HuggingFace Inference API (no local loading)."""
+    """Generate code completion using Yuuki."""
+
+    if not MODEL_LOADED:
+        if not load_model():
+            return "Error: Model failed to load. Please try refreshing the page."
 
     if not prompt or not prompt.strip():
         return "Please enter a code prompt."
 
     try:
-        from huggingface_hub import InferenceClient
-
-        client = InferenceClient()
-        response = client.text_generation(
-            prompt,
-            model=MODEL_ID,
-            max_new_tokens=max_new_tokens,
-            temperature=temperature,
-            top_p=top_p,
-            do_sample=True
+        inputs = tokenizer(
+            prompt,
+            return_tensors="pt",
+            truncation=True,
+            max_length=512
         )
 
-        return response
+        with torch.no_grad():
+            outputs = model.generate(
+                **inputs,
+                max_new_tokens=max_new_tokens,
+                temperature=temperature,
+                top_p=top_p,
+                top_k=top_k,
+                repetition_penalty=repetition_penalty,
+                do_sample=True,
+                pad_token_id=tokenizer.pad_token_id,
+                eos_token_id=tokenizer.eos_token_id,
+                num_return_sequences=1
+            )
 
+        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        return generated_text
+
     except Exception as e:
-        return f"Generation error: {str(e)}\n\nTry model directly: https://huggingface.co/OpceanAI/Yuuki-best"
+        return f"Generation error: {str(e)}"
+
 
 # ============================================================================
 # Examples
@@ -633,9 +682,15 @@ with gr.Blocks(
     # Examples
     gr.HTML('<div id="examples-label">Try these</div>')
     with gr.Row(elem_id="examples-grid"):
-        for ex in EXAMPLES:
-            btn = gr.Button(ex[0], elem_classes=["example-btn"], size="sm")
-            btn.click(lambda x=ex[0]: x, outputs=prompt_input)
+        ex_btn_1 = gr.Button("module Main where", elem_classes=["example-btn"], size="sm")
+        ex_btn_2 = gr.Button("open import Data.Nat", elem_classes=["example-btn"], size="sm")
+        ex_btn_3 = gr.Button("int main() {", elem_classes=["example-btn"], size="sm")
+        ex_btn_4 = gr.Button("def hello():", elem_classes=["example-btn"], size="sm")
+
+        ex_btn_1.click(lambda: "module Main where", outputs=prompt_input)
+        ex_btn_2.click(lambda: "open import Data.Nat", outputs=prompt_input)
+        ex_btn_3.click(lambda: "int main() {", outputs=prompt_input)
+        ex_btn_4.click(lambda: "def hello():", outputs=prompt_input)
 
     # ===== SETTINGS TAB =====
     with gr.Tab("Settings", id="settings"):
@@ -729,14 +784,4 @@
     gr.HTML("""
     <div class="score-grid">
        <span class="score-badge good">Agda: 55/100</span>
-       <span class="score-badge medium">C: 20/100</span>
-       <span class="score-badge medium">Assembly: 15/100</span>
-       <span class="score-badge weak">Python: 8/100</span>
-   </div>
-   <p style="color: #666; font-size: 0.8rem; margin-top: 16px; line-height: 1.5;">
-       Python scores low due to alphabetical dataset ordering.
-       Average quality: 24.6/100 (+146% from checkpoint 1400).
-   </p>
-   """)
-
-
+       <span class="score-badge medium">C: 20/100</sp
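
Neither hunk shows how generate_code is bound to the interface; the Button/Textbox wiring lives in an unchanged part of app.py. A minimal sketch of that wiring, assuming hypothetical generate_btn and output_box components and leaving the remaining sampling parameters at their defaults (prompt_input and the generate_code signature are the only pieces taken from this diff):

import gradio as gr

# generate_code(...) as defined in app.py above.

with gr.Blocks() as demo:
    prompt_input = gr.Textbox(label="Code prompt", lines=4)
    max_tokens = gr.Slider(16, 512, value=200, step=8, label="Max new tokens")
    temperature = gr.Slider(0.1, 1.5, value=0.7, step=0.05, label="Temperature")
    output_box = gr.Code(label="Completion", language="python")
    generate_btn = gr.Button("Generate")

    # Pass only the UI-exposed knobs; top_p, top_k and repetition_penalty keep their defaults.
    generate_btn.click(
        fn=lambda p, m, t: generate_code(p, max_new_tokens=int(m), temperature=t),
        inputs=[prompt_input, max_tokens, temperature],
        outputs=output_box,
    )

demo.launch()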
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,4 +1,5 @@
-gradio==5.9.1
-transformers==4.46.0
-torch>=2.5.0
-accelerate
+gradio==4.44.1
+transformers==4.45.0
+torch==2.5.0
+accelerate==0.34.0
+huggingface-hub>=0.20.0
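
Because all of these pins changed together, a short startup log of the resolved versions can help when a Space rebuild pulls something unexpected; a purely optional sketch:

import gradio
import transformers
import torch
import accelerate
import huggingface_hub

# Print the versions actually resolved in the Space; compare against requirements.txt.
for name, mod in [
    ("gradio", gradio),
    ("transformers", transformers),
    ("torch", torch),
    ("accelerate", accelerate),
    ("huggingface-hub", huggingface_hub),
]:
    print(f"{name}=={mod.__version__}")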