kouki321 committed on
Commit
5e72312
·
verified ·
1 Parent(s): 4eefb34

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -10
app.py CHANGED
@@ -1,26 +1,43 @@
1
  import os
2
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
 
4
  # Model identifier
5
  model_id = "sshleifer/tiny-gpt2"
6
  #"google/flan-t5-small"
7
  #"unsloth/mistral-7b-v0.2-bnb-4bit"
8
  #deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
 
 
9
 
10
- # Ensure cache directory exists and is writable
11
  cache_dir = "/app/cache"
12
  os.makedirs(cache_dir, exist_ok=True)
13
 
14
- # Load tokenizer and model, pointing to custom cache
15
- tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=cache_dir)
16
- model = AutoModelForSeq2SeqLM.from_pretrained(model_id, cache_dir=cache_dir)
 
 
 
 
 
 
 
 
17
 
18
- # Simple inference function
19
  def generate(text: str) -> str:
20
  inputs = tokenizer(text, return_tensors="pt")
21
- outputs = model.generate(**inputs)
22
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
23
 
24
  if __name__ == "__main__":
25
- prompt = "Translate English to French: Hello, how are you?"
26
- print(generate(prompt))
 
 
 
1
  import os
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, BertForMaskedLM
3
 
4
  # Model identifier
5
  model_id = "sshleifer/tiny-gpt2"
6
  #"google/flan-t5-small"
7
  #"unsloth/mistral-7b-v0.2-bnb-4bit"
8
  #deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
9
+
10
+ #model_id = "remi/bertabs-finetuned-extractive-abstractive-summarization"
11
 
12
+ # Ensure cache directory exists
13
  cache_dir = "/app/cache"
14
  os.makedirs(cache_dir, exist_ok=True)
15
 
16
+ # Load using appropriate class
17
+ if model_id.startswith("remi/bertabs"):
18
+ # BERT masked language model for abstractive summarization
19
+ model = BertForMaskedLM.from_pretrained(model_id, cache_dir=cache_dir)
20
+ tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=cache_dir)
21
+ else:
22
+ # Sequence-to-sequence model like Flan-T5
23
+ tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=cache_dir)
24
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_id, cache_dir=cache_dir)
25
+
26
+ # Inference function
27
 
 
28
  def generate(text: str) -> str:
29
  inputs = tokenizer(text, return_tensors="pt")
30
+ if hasattr(model, 'generate'):
31
+ outputs = model.generate(**inputs)
32
+ return tokenizer.decode(outputs[0], skip_special_tokens=True)
33
+ else:
34
+ # For masked LM, demonstrate mask filling
35
+ from transformers import pipeline
36
+ fill = pipeline('fill-mask', model=model, tokenizer=tokenizer, cache_dir=cache_dir)
37
+ return fill(text)
38
 
39
  if __name__ == "__main__":
40
+ # Example usage
41
+ prompt = "The meaning of life is <mask>."
42
+ result = generate(prompt)
43
+ print(result)