nvhuynh16 commited on
Commit
add12a3
·
verified ·
1 Parent(s): 1e5532c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -32
app.py CHANGED
@@ -1,26 +1,63 @@
1
  """
2
- Gradio demo for Gemma Code Generator using HuggingFace Inference API.
3
- This runs serverless on HF infrastructure - no GPU costs!
4
  """
5
 
6
  import gradio as gr
7
- from huggingface_hub import InferenceClient
 
 
 
8
 
9
  # Model configuration
10
- MODEL_NAME = "nvhuynh16/gemma-2b-code-alpaca-best" # Best checkpoint (step 2000)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- # Initialize Inference client (no model specified here - we'll pass it per request)
13
- client = InferenceClient()
14
 
15
 
16
  def generate_code(instruction: str, max_tokens: int = 256, temperature: float = 0.7):
17
- """Generate code from instruction using HF Inference API"""
18
 
19
  if not instruction.strip():
20
  return "Please enter an instruction."
21
 
22
- # Format prompt in Alpaca style
23
- prompt = f"""### Instruction:
 
 
 
 
24
  {instruction}
25
 
26
  ### Input:
@@ -29,32 +66,33 @@ def generate_code(instruction: str, max_tokens: int = 256, temperature: float =
29
  ### Response:
30
  """
31
 
32
- try:
33
- # Generate using HF Inference API
34
- response = client.text_generation(
35
- prompt,
36
- model=MODEL_NAME,
37
- max_new_tokens=max_tokens,
38
- temperature=temperature,
39
- top_p=0.9,
40
- do_sample=True,
41
- return_full_text=False,
42
- )
 
 
43
 
44
- return response.strip()
 
45
 
46
- except Exception as e:
47
- error_msg = str(e)
48
- if "410" in error_msg or "Gone" in error_msg:
49
- return "⚠️ API endpoint error. This usually means the Inference API is updating. Please try again in a moment."
50
- elif "Model too large" in error_msg or "not currently loaded" in error_msg or "loading" in error_msg.lower():
51
- return "⏳ Model is loading (first request takes 1-2 minutes). Please try again in a moment."
52
- elif "rate limit" in error_msg.lower():
53
- return "⚠️ Rate limit reached. Please wait a few minutes and try again."
54
- elif "404" in error_msg or "not found" in error_msg.lower():
55
- return "⚠️ Model not found or not enabled for Inference API. Please check the model settings on HuggingFace."
56
  else:
57
- return f"Error: {error_msg}\n\nPlease try again. If the issue persists, the model may be loading for the first time."
 
 
 
 
 
58
 
59
 
60
  # Custom CSS for better appearance
 
1
  """
2
+ Gradio demo for Gemma Code Generator.
3
+ Loads the fine-tuned model directly using PEFT.
4
  """
5
 
6
  import gradio as gr
7
+ import torch
8
+ import os
9
+ from transformers import AutoTokenizer, AutoModelForCausalLM
10
+ from peft import PeftModel
11
 
12
# Model configuration
BASE_MODEL = "google/gemma-2-2b-it"
ADAPTER_MODEL = "nvhuynh16/gemma-2b-code-alpaca-best"
# Optional auth token (required if the base model is gated on the Hub).
HF_TOKEN = os.environ.get("HF_TOKEN", None)

# Globals for lazy loading: stay None until the first generate request so the
# Space starts quickly and only pays the load cost when actually used.
tokenizer = None
model = None

def load_model():
    """Lazily load the tokenizer, 4-bit quantized base model, and LoRA adapter.

    The first call downloads and initializes everything (slow); subsequent
    calls return the cached objects immediately.

    Returns:
        tuple: ``(tokenizer, model)`` — the base-model tokenizer and the
        PEFT-wrapped model in eval mode.
    """
    global tokenizer, model

    if model is None:
        print("Loading model for the first time...")

        # Tokenizer comes from the base model, not the adapter repo.
        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, token=HF_TOKEN)

        # Passing load_in_4bit=True directly to from_pretrained is deprecated
        # (and removed in recent transformers); the supported path is a
        # BitsAndBytesConfig passed via quantization_config.
        from transformers import BitsAndBytesConfig

        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
        )

        base_model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            device_map="auto",
            torch_dtype=torch.float16,
            quantization_config=quant_config,
            token=HF_TOKEN,
        )

        # Attach the fine-tuned LoRA weights on top of the quantized base.
        model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL, token=HF_TOKEN)
        model.eval()

        print("Model loaded successfully!")

    return tokenizer, model
 
47
 
48
 
49
def generate_code(instruction: str, max_tokens: int = 256, temperature: float = 0.7):
    """Generate code for a natural-language instruction with the fine-tuned model.

    Args:
        instruction: Description of the code to generate; blank input is rejected.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature. Values <= 0 fall back to greedy
            decoding, because transformers rejects ``temperature=0`` when
            ``do_sample=True``.

    Returns:
        str: The generated code, or a human-readable error message.
    """
    if not instruction.strip():
        return "Please enter an instruction."

    try:
        # Load model (cached after the first call).
        tok, mdl = load_model()

        # Format prompt in Alpaca style — must match the fine-tuning format.
        prompt = f"""### Instruction:
{instruction}

### Input:


### Response:
"""

        inputs = tok(prompt, return_tensors="pt").to(mdl.device)

        # temperature=0 raises a validation error in transformers when
        # do_sample=True, so use greedy decoding for non-positive temperatures.
        do_sample = temperature > 0
        gen_kwargs = {
            "max_new_tokens": max_tokens,
            "do_sample": do_sample,
            "pad_token_id": tok.eos_token_id,
        }
        if do_sample:
            gen_kwargs["temperature"] = temperature
            gen_kwargs["top_p"] = 0.9

        with torch.no_grad():
            outputs = mdl.generate(**inputs, **gen_kwargs)

        generated = tok.decode(outputs[0], skip_special_tokens=True)

        # The decoded text includes the prompt; keep only what follows the
        # final "### Response:" marker.
        if "### Response:" in generated:
            code = generated.split("### Response:")[-1].strip()
        else:
            code = generated.strip()

        return code

    except Exception as e:
        # Surface errors as a message in the Gradio UI instead of crashing.
        return f"Error: {str(e)}\n\nPlease try again."
96
 
97
 
98
  # Custom CSS for better appearance