srisuriyas committed on
Commit
d7b61e4
·
verified ·
1 Parent(s): 5eddeb7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -84
app.py CHANGED
@@ -1,89 +1,27 @@
1
- from transformers import AutoTokenizer, AutoModelForCausalLM
2
- import torch
3
  import gradio as gr
4
- from fastapi import FastAPI, Request
5
-
6
- MODEL_ID = "ibm-granite/granite-3.1-2b-instruct"
7
-
8
- # Load model + tokenizer
9
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
10
- model = AutoModelForCausalLM.from_pretrained(
11
- MODEL_ID,
12
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
13
- device_map="auto"
14
- )
15
-
16
- DEFAULT_PARAMS = dict(
17
- max_new_tokens=512,
18
- temperature=0.2,
19
- top_p=0.95,
20
  )
21
 
22
- def format_instruct_prompt(system_msg, user_msg):
23
- # Works with Granite chat/instruct style
24
- # Adjust if your prompt format differs
25
- return f"<|system|>\n{system_msg}\n<|user|>\n{user_msg}\n<|assistant|>\n"
26
-
27
- def generate_once(system, user, params=None):
28
- if params is None:
29
- params = {}
30
- merged = {**DEFAULT_PARAMS, **params}
31
- prompt = format_instruct_prompt(system, user)
32
-
33
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
34
- output_ids = model.generate(
35
- **inputs,
36
- max_new_tokens=merged["max_new_tokens"],
37
- do_sample=merged["temperature"] > 0,
38
- temperature=merged["temperature"],
39
- top_p=merged["top_p"],
40
- pad_token_id=tokenizer.eos_token_id,
41
- eos_token_id=tokenizer.eos_token_id,
42
- )
43
 
44
- # Only decode the newly generated tokens
45
- gen_ids = output_ids[0][inputs["input_ids"].shape[-1]:]
46
- text = tokenizer.decode(gen_ids, skip_special_tokens=True)
47
- return text.strip()
48
-
49
- # ---------- Gradio UI (manual testing) ----------
50
- with gr.Blocks() as demo:
51
- gr.Markdown("# Granite RAG API (UI)")
52
- sys_in = gr.Textbox(label="System", value="You are Granite, a helpful and concise assistant.")
53
- usr_in = gr.Textbox(label="User", placeholder="Ask something...")
54
- max_new = gr.Slider(64, 1024, value=512, step=16, label="max_new_tokens")
55
- temperature = gr.Slider(0.0, 1.5, value=0.2, step=0.05, label="temperature")
56
- top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.01, label="top_p")
57
- out = gr.Textbox(label="Output")
58
-
59
- def _ui(system, user, max_new_tokens, temperature, top_p):
60
- return generate_once(system, user, {
61
- "max_new_tokens": int(max_new_tokens),
62
- "temperature": float(temperature),
63
- "top_p": float(top_p),
64
- })
65
-
66
- gr.Button("Generate").click(_ui, [sys_in, usr_in, max_new, temperature, top_p], out)
67
-
68
- # ---------- FastAPI JSON endpoint ----------
69
- api = FastAPI()
70
-
71
- @api.post("/generate")
72
- async def generate(req: Request):
73
- """
74
- POST JSON:
75
- {
76
- "prompt": "question with context...",
77
- "system": "system prompt (optional)",
78
- "params": { "max_new_tokens": 300, "temperature": 0.2, "top_p": 0.9 }
79
- }
80
- """
81
- body = await req.json()
82
- prompt = body["prompt"]
83
- system = body.get("system", "You are Granite, a helpful and concise assistant.")
84
- params = body.get("params", {})
85
- text = generate_once(system, prompt, params)
86
- return {"text": text}
87
 
88
- # Mount Gradio on "/"
89
- app = gr.mount_gradio_app(api, demo, path="/")
 
 
 
1
  import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
3
+
4
+ model_id = "ibm-granite/granite-3.1-2b-instruct" # Hugging Face model ID
5
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
6
+ model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
7
+
8
+ pipe = pipeline(
9
+ "text-generation",
10
+ model=model,
11
+ tokenizer=tokenizer,
12
+ max_length=512,
13
+ temperature=0.7
 
 
 
 
14
  )
15
 
16
def generate_answer(prompt):
    """Run the shared text-generation pipeline on *prompt* and return the answer.

    HF "text-generation" pipelines include the prompt at the start of
    "generated_text" unless return_full_text=False is set, so the original
    implementation echoed the question back along with the answer. Strip the
    prompt prefix defensively so callers receive only the new text.
    """
    full_text = pipe(prompt)[0]["generated_text"]
    # NOTE(review): if the pipeline is later configured with
    # return_full_text=False, startswith() is simply False and the text is
    # returned as-is, so this stays correct either way.
    if full_text.startswith(prompt):
        return full_text[len(prompt):].strip()
    return full_text.strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
# Single-turn demo UI: one prompt textbox in, the generated text out.
# ("text" shorthand and an explicit gr.Textbox() are equivalent in Gradio.)
demo = gr.Interface(
    fn=generate_answer,
    title="Granite 3.1 2B Instruct - RAG Answering",
    inputs=gr.Textbox(),
    outputs=gr.Textbox(),
)

# Start the Gradio server; blocks until the process is stopped.
demo.launch()