hari7261 committed on
Commit
2185ead
·
verified ·
1 Parent(s): 361b22f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -11
app.py CHANGED
@@ -1,18 +1,41 @@
1
- from transformers import AutoTokenizer, AutoModelForCausalLM
 
2
  import torch
 
 
 
 
3
 
4
- model_name = "hari7261/TechChat"
5
 
6
- # If the repo is private, add your token:
7
- token = "your_huggingface_token_here" # or set HF_TOKEN env var
 
8
 
9
- tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=token)
10
- model = AutoModelForCausalLM.from_pretrained(model_name, use_auth_token=token)
11
 
12
- prompt = "Hello, how can I help you today?"
13
- inputs = tokenizer(prompt, return_tensors="pt")
 
 
 
 
 
 
 
 
 
 
14
 
15
- with torch.no_grad():
16
- outputs = model.generate(**inputs, max_length=50)
 
 
 
 
 
 
17
 
18
- print(tokenizer.decode(outputs[0], skip_special_tokens=True))
 
 
1
+ import os
2
+ import gradio as gr
3
  import torch
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM
5
+
6
# Read the Hugging Face access token from the environment so that no secret
# is hard-coded in the source; os.getenv returns None when HF_TOKEN is unset.
HF_TOKEN = os.getenv("HF_TOKEN")

# Hub repository that provides both the tokenizer and the model weights.
MODEL_NAME = "hari7261/TechChat"

print("Loading tokenizer and model...")
# `token` is the supported keyword for Hub authentication; the older
# `use_auth_token` spelling is deprecated in Transformers.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, token=HF_TOKEN)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
17
 
18
def generate_text(prompt):
    """Generate a sampled completion for *prompt* with the loaded model.

    Args:
        prompt: Text entered by the user. ``None``, empty, or
            whitespace-only input returns ``""`` without calling the model.

    Returns:
        The first generated sequence decoded to a string, with special
        tokens removed.
    """
    # Nothing to condition on: skip tokenization and an expensive generate().
    if not prompt or not prompt.strip():
        return ""

    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    outputs = model.generate(
        **inputs,
        # Budget only the newly generated tokens. `max_length` also counts
        # the prompt, so a long prompt would leave no room for any output.
        max_new_tokens=150,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        eos_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
30
 
31
# Gradio UI: a five-line prompt box feeding generate_text, with the result
# shown as plain text.
_prompt_box = gr.Textbox(lines=5, placeholder="Enter your prompt here...")

iface = gr.Interface(
    fn=generate_text,
    inputs=_prompt_box,
    outputs="text",
    title="TechChat - Mistral 7B",
    description="Generate text with hari7261/TechChat model hosted on Hugging Face.",
)

# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()