tniranjan committed on
Commit
ac4bdc5
·
verified ·
1 Parent(s): b74fa2b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -15
app.py CHANGED
@@ -1,31 +1,45 @@
1
  from operator import ge
2
  from xml.dom.expatbuilder import theDOMImplementation
3
  import gradio as gr
4
- from huggingface_hub import InferenceClient
5
  import os
 
 
6
 
7
- def generate(
8
- model_name,
9
- text,
10
- max_new_tokens,
11
- top_k
12
- ):
13
  if model_name == "Medium-GPTNeo":
14
- model = "tniranjan/finetuned_gptneo-base-tinystories-ta_v3"
15
  elif model_name == "Small-GPTNeo":
16
- model = "tniranjan/finetuned_tinystories_33M_tinystories_ta"
17
  elif model_name == "Small-LLaMA":
18
- model = "tniranjan/finetuned_Llama_tinystories_tinystories_ta"
19
- client = InferenceClient(provider="hf-inference",
20
- api_key=os.environ["HUGGINGFACEHUB_API_TOKEN"])
 
 
 
 
 
 
21
 
22
- return client.text_generation(
23
- model=model,
24
- prompt = text,
 
 
25
  max_new_tokens=max_new_tokens,
26
  top_k=top_k,
 
 
27
  )
28
 
 
 
 
 
 
29
  demo = gr.Interface(
30
  generate,
31
  title="Kurunkathai: Tinystories in Tamil",
 
1
  from operator import ge
2
  from xml.dom.expatbuilder import theDOMImplementation
3
  import gradio as gr
 
4
  import os
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM
6
+ import torch
7
 
8
+ # Optional: cache loaded models to avoid reloading every time
9
+ model_cache = {}
10
+
11
+ def generate(model_name, text, max_new_tokens, top_k):
 
 
12
  if model_name == "Medium-GPTNeo":
13
+ model_id = "tniranjan/finetuned_gptneo-base-tinystories-ta_v3"
14
  elif model_name == "Small-GPTNeo":
15
+ model_id = "tniranjan/finetuned_tinystories_33M_tinystories_ta"
16
  elif model_name == "Small-LLaMA":
17
+ model_id = "tniranjan/finetuned_Llama_tinystories_tinystories_ta"
18
+
19
+ # Load model and tokenizer (from cache if available)
20
+ if model_id not in model_cache:
21
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
22
+ model = AutoModelForCausalLM.from_pretrained(model_id)
23
+ model_cache[model_id] = (tokenizer, model)
24
+ else:
25
+ tokenizer, model = model_cache[model_id]
26
 
27
+ inputs = tokenizer(text, return_tensors="pt")
28
+
29
+ # Generate text
30
+ output = model.generate(
31
+ **inputs,
32
  max_new_tokens=max_new_tokens,
33
  top_k=top_k,
34
+ do_sample=True,
35
+ pad_token_id=tokenizer.eos_token_id,
36
  )
37
 
38
+ # Decode generated tokens
39
+ generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
40
+
41
+ return generated_text
42
+
43
  demo = gr.Interface(
44
  generate,
45
  title="Kurunkathai: Tinystories in Tamil",