sedtha committed on
Commit
4b22381
·
verified ·
1 Parent(s): 5b2614b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -38
app.py CHANGED
@@ -1,47 +1,28 @@
1
  import gradio as gr
2
- from transformers import AutoTokenizer, MBartForConditionalGeneration
3
- from peft import PeftModel, PeftConfig
 
 
 
4
  import torch
5
-
6
  # ==========================
7
  # 1. Load model from Hugging Face
8
  # ==========================
9
 
10
- MODEL_NAME = "sedtha/mBart-50-large_LoRa_kh_sumerize"
11
 
12
  print("Loading model and tokenizer...")
 
 
13
 
14
- try:
15
- # First, load the base model configuration to get the base model name
16
- config = PeftConfig.from_pretrained(MODEL_NAME)
17
- base_model_name = config.base_model_name_or_path
18
-
19
- # Load tokenizer from the base model (mbart-large-50)
20
- tokenizer = AutoTokenizer.from_pretrained(base_model_name)
21
-
22
- # Load the base model
23
- base_model = MBartForConditionalGeneration.from_pretrained(base_model_name)
24
-
25
- # Load the LoRA adapter
26
- model = PeftModel.from_pretrained(base_model, MODEL_NAME)
27
-
28
- # Merge LoRA weights with base model for inference (optional but can improve performance)
29
- model = model.merge_and_unload()
30
-
31
- except Exception as e:
32
- print(f"Error loading model: {e}")
33
- # Fallback: try direct loading
34
- try:
35
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
36
- model = MBartForConditionalGeneration.from_pretrained(MODEL_NAME)
37
- except Exception as e2:
38
- print(f"Fallback loading also failed: {e2}")
39
- raise
40
 
41
  # Move to GPU if available
42
  device = "cuda" if torch.cuda.is_available() else "cpu"
43
  model = model.to(device)
44
- model.eval() # Set to evaluation mode
45
 
46
  print(f"βœ… Model loaded successfully on {device}!")
47
 
@@ -59,10 +40,6 @@ def summarize_khmer_text(text, max_length=150, min_length=40):
59
  return "⚠️ αž’αžαŸ’αžαž”αž‘αžαŸ’αž›αžΈαž–αŸαž€ / Text is too short to summarize"
60
 
61
  try:
62
- # Set the source language for mBART (Khmer)
63
- # For mBART-50, Khmer language code is "km_KR"
64
- tokenizer.src_lang = "km_KR"
65
-
66
  # Tokenize input
67
  inputs = tokenizer(
68
  text,
@@ -81,8 +58,7 @@ def summarize_khmer_text(text, max_length=150, min_length=40):
81
  length_penalty=2.0,
82
  num_beams=4,
83
  early_stopping=True,
84
- no_repeat_ngram_size=3,
85
- forced_bos_token_id=tokenizer.lang_code_to_id["km_KR"] # Force Khmer output
86
  )
87
 
88
  # Decode output
@@ -157,4 +133,4 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
157
  # 4. Launch
158
  # ==========================
159
  if __name__ == "__main__":
160
- demo.launch(share=True)
 
1
  import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
+ from transformers import (
4
+ MBartForConditionalGeneration, MBart50Tokenizer,
5
+ MT5ForConditionalGeneration, T5Tokenizer
6
+ )
7
  import torch
8
+ from peft import PeftModel
9
  # ==========================
10
  # 1. Load model from Hugging Face
11
  # ==========================
12
 
13
+ MODEL_NAME = "sedtha/mBart-50-large_LoRa_kh_sumerize" # e.g., "Sedtha-019/khmer-summarization"
14
 
15
  print("Loading model and tokenizer...")
16
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
17
+ # model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
18
 
19
+
20
+ base = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50")
21
+ model = PeftModel.from_pretrained(base, MODEL_NAME)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  # Move to GPU if available
24
  device = "cuda" if torch.cuda.is_available() else "cpu"
25
  model = model.to(device)
 
26
 
27
  print(f"βœ… Model loaded successfully on {device}!")
28
 
 
40
  return "⚠️ αž’αžαŸ’αžαž”αž‘αžαŸ’αž›αžΈαž–αŸαž€ / Text is too short to summarize"
41
 
42
  try:
 
 
 
 
43
  # Tokenize input
44
  inputs = tokenizer(
45
  text,
 
58
  length_penalty=2.0,
59
  num_beams=4,
60
  early_stopping=True,
61
+ no_repeat_ngram_size=3
 
62
  )
63
 
64
  # Decode output
 
133
  # 4. Launch
134
  # ==========================
135
  if __name__ == "__main__":
136
+ demo.launch(share=True)