rahul7star committed on
Commit
ff12e01
·
verified ·
1 Parent(s): 9274c9f

Update app_flash.py

Browse files
Files changed (1) hide show
  1. app_flash.py +23 -22
app_flash.py CHANGED
@@ -1,6 +1,7 @@
1
  import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline as hf_pipeline
3
  from flashpack.integrations.transformers import FlashPackTransformersModelMixin
 
4
 
5
  # ============================================================
6
  # 1️⃣ FlashPack-enabled model class
@@ -9,32 +10,32 @@ class FlashPackGemmaModel(AutoModelForCausalLM, FlashPackTransformersModelMixin)
9
  pass
10
 
11
  # ============================================================
12
- # 2️⃣ Model & tokenizer loading
13
  # ============================================================
14
  MODEL_ID = "gokaygokay/prompt-enhancer-gemma-3-270m-it"
15
- FLASHPACK_REPO = "rahul7star/FlashPack" # Upload target repo
16
 
 
 
 
 
 
 
17
  try:
18
- # Try loading directly from the FlashPack repo
19
  print("📂 Loading model from FlashPack repository...")
20
  model = FlashPackGemmaModel.from_pretrained_flashpack(FLASHPACK_REPO)
21
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
22
- except Exception as e:
23
- print(f"⚠️ Could not load FlashPack model: {e}")
24
- print("⚙️ Loading from HF Hub and saving FlashPack to the repository...")
25
-
26
  # Load from HF Hub
27
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
28
  model = FlashPackGemmaModel.from_pretrained(MODEL_ID)
29
-
30
- # Save directly to the Hugging Face repo
31
  model.save_pretrained_flashpack(FLASHPACK_REPO, push_to_hub=True)
32
- print(f"✅ Model uploaded to Hugging Face Hub: {FLASHPACK_REPO}")
33
 
34
  # ============================================================
35
- # 3️⃣ Text-generation pipeline
36
  # ============================================================
37
- pipe = hf_pipeline(
38
  "text-generation",
39
  model=model,
40
  tokenizer=tokenizer,
@@ -42,18 +43,18 @@ pipe = hf_pipeline(
42
  )
43
 
44
  # ============================================================
45
- # 4️⃣ Prompt enhancement function
46
  # ============================================================
47
  def enhance_prompt(user_prompt, temperature, max_tokens, chat_history):
48
  chat_history = chat_history or []
49
 
50
- # Build chat-template messages
51
  messages = [
52
  {"role": "system", "content": "Enhance and expand the following prompt with more details and context:"},
53
  {"role": "user", "content": user_prompt},
54
  ]
55
 
56
- # Apply tokenizer chat template
57
  prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
58
 
59
  # Generate output
@@ -65,13 +66,13 @@ def enhance_prompt(user_prompt, temperature, max_tokens, chat_history):
65
  )
66
  enhanced = outputs[0]["generated_text"].strip()
67
 
68
- # Append to chat history
69
  chat_history.append({"role": "user", "content": user_prompt})
70
  chat_history.append({"role": "assistant", "content": enhanced})
71
  return chat_history
72
 
73
  # ============================================================
74
- # 5️⃣ Gradio UI
75
  # ============================================================
76
  with gr.Blocks(title="Prompt Enhancer – Gemma 3 270M", theme=gr.themes.Soft()) as demo:
77
  gr.Markdown(
@@ -95,7 +96,7 @@ with gr.Blocks(title="Prompt Enhancer – Gemma 3 270M", theme=gr.themes.Soft())
95
  send_btn = gr.Button("🚀 Enhance Prompt", variant="primary")
96
  clear_btn = gr.Button("🧹 Clear Chat")
97
 
98
- # Bind UI actions
99
  send_btn.click(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
100
  user_prompt.submit(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
101
  clear_btn.click(lambda: [], None, chatbot)
@@ -110,7 +111,7 @@ with gr.Blocks(title="Prompt Enhancer – Gemma 3 270M", theme=gr.themes.Soft())
110
  )
111
 
112
  # ============================================================
113
- # 6️⃣ Launch
114
  # ============================================================
115
  if __name__ == "__main__":
116
  demo.launch(show_error=True)
 
1
  import gradio as gr
2
+ from transformers import AutoTokenizer, pipeline
3
  from flashpack.integrations.transformers import FlashPackTransformersModelMixin
4
+ from transformers import AutoModelForCausalLM
5
 
6
  # ============================================================
7
  # 1️⃣ FlashPack-enabled model class
 
10
  pass
11
 
12
  # ============================================================
13
+ # 2️⃣ Model & tokenizer settings
14
  # ============================================================
15
  MODEL_ID = "gokaygokay/prompt-enhancer-gemma-3-270m-it"
16
+ FLASHPACK_REPO = "rahul7star/FlashPack"
17
 
18
+ # Load tokenizer
19
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
20
+
21
+ # ============================================================
22
+ # 3️⃣ Load or create FlashPack model
23
+ # ============================================================
24
  try:
 
25
  print("📂 Loading model from FlashPack repository...")
26
  model = FlashPackGemmaModel.from_pretrained_flashpack(FLASHPACK_REPO)
27
+ except FileNotFoundError:
28
+ print("⚠️ FlashPack model not found on Hub. Creating and uploading...")
 
 
 
29
  # Load from HF Hub
 
30
  model = FlashPackGemmaModel.from_pretrained(MODEL_ID)
31
+ # Save as FlashPack directly to Hub
 
32
  model.save_pretrained_flashpack(FLASHPACK_REPO, push_to_hub=True)
33
+ print(f"✅ Model uploaded as FlashPack to Hugging Face Hub: {FLASHPACK_REPO}")
34
 
35
  # ============================================================
36
+ # 4️⃣ Text-generation pipeline
37
  # ============================================================
38
+ pipe = pipeline(
39
  "text-generation",
40
  model=model,
41
  tokenizer=tokenizer,
 
43
  )
44
 
45
  # ============================================================
46
+ # 5️⃣ Prompt enhancement function
47
  # ============================================================
48
  def enhance_prompt(user_prompt, temperature, max_tokens, chat_history):
49
  chat_history = chat_history or []
50
 
51
+ # Build messages
52
  messages = [
53
  {"role": "system", "content": "Enhance and expand the following prompt with more details and context:"},
54
  {"role": "user", "content": user_prompt},
55
  ]
56
 
57
+ # Apply chat-template
58
  prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
59
 
60
  # Generate output
 
66
  )
67
  enhanced = outputs[0]["generated_text"].strip()
68
 
69
+ # Update chat history
70
  chat_history.append({"role": "user", "content": user_prompt})
71
  chat_history.append({"role": "assistant", "content": enhanced})
72
  return chat_history
73
 
74
  # ============================================================
75
+ # 6️⃣ Gradio UI
76
  # ============================================================
77
  with gr.Blocks(title="Prompt Enhancer – Gemma 3 270M", theme=gr.themes.Soft()) as demo:
78
  gr.Markdown(
 
96
  send_btn = gr.Button("🚀 Enhance Prompt", variant="primary")
97
  clear_btn = gr.Button("🧹 Clear Chat")
98
 
99
+ # Bind actions
100
  send_btn.click(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
101
  user_prompt.submit(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
102
  clear_btn.click(lambda: [], None, chatbot)
 
111
  )
112
 
113
  # ============================================================
114
+ # 7️⃣ Launch
115
  # ============================================================
116
  if __name__ == "__main__":
117
  demo.launch(show_error=True)