rahul7star committed on
Commit 9274c9f · verified · 1 Parent(s): f901d63

Update app_flash.py

Files changed (1)
  1. app_flash.py +15 -13
app_flash.py CHANGED
@@ -1,36 +1,38 @@
 import gradio as gr
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline as hf_pipeline
 from flashpack.integrations.transformers import FlashPackTransformersModelMixin
-from transformers import AutoModelForCausalLM, pipeline as hf_pipeline
 
 # ============================================================
-# 1️⃣ Define FlashPack wrapper for Gemma
+# 1️⃣ FlashPack-enabled model class
 # ============================================================
 class FlashPackGemmaModel(AutoModelForCausalLM, FlashPackTransformersModelMixin):
     pass
 
 # ============================================================
-# 2️⃣ Load model & tokenizer via FlashPack pipeline approach
+# 2️⃣ Model & tokenizer loading
 # ============================================================
 MODEL_ID = "gokaygokay/prompt-enhancer-gemma-3-270m-it"
-FLASHPACK_REPO = "rahul7star/FlashPack"
+FLASHPACK_REPO = "rahul7star/FlashPack"  # Upload target repo
 
-# Load model from FlashPack repository (Hub or local path)
 try:
+    # Try loading directly from the FlashPack repo
     print("📂 Loading model from FlashPack repository...")
-    # Load model directly via FlashPack
     model = FlashPackGemmaModel.from_pretrained_flashpack(FLASHPACK_REPO)
     tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 except Exception as e:
     print(f"⚠️ Could not load FlashPack model: {e}")
-    print("⚙️ Falling back to HF model and saving FlashPack...")
+    print("⚙️ Loading from HF Hub and saving FlashPack to the repository...")
+
+    # Load from HF Hub
     tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
     model = FlashPackGemmaModel.from_pretrained(MODEL_ID)
-    model.save_pretrained_flashpack("model_flashpack", push_to_hub=False)
-    print("✅ Saved FlashPack locally for next run")
+
+    # Save directly to the Hugging Face repo
+    model.save_pretrained_flashpack(FLASHPACK_REPO, push_to_hub=True)
+    print(f"✅ Model uploaded to Hugging Face Hub: {FLASHPACK_REPO}")
 
 # ============================================================
-# 3️⃣ Create text-generation pipeline
+# 3️⃣ Text-generation pipeline
 # ============================================================
 pipe = hf_pipeline(
     "text-generation",
@@ -40,12 +42,12 @@ pipe = hf_pipeline(
 )
 
 # ============================================================
-# 4️⃣ Define prompt enhancement function
+# 4️⃣ Prompt enhancement function
 # ============================================================
 def enhance_prompt(user_prompt, temperature, max_tokens, chat_history):
     chat_history = chat_history or []
 
-    # Build messages with chat template
+    # Build chat-template messages
     messages = [
         {"role": "system", "content": "Enhance and expand the following prompt with more details and context:"},
         {"role": "user", "content": user_prompt},