Spaces:
No application file
No application file
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
import time
import os

# ---- Configuration ---------------------------------------------------------
MODEL_ID = "14maddy/Agri_llm"          # fine-tuned model hosted on the HF Hub
MAX_NEW_TOKENS = 512                   # default generation budget for the UI slider
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
HF_TOKEN = os.environ.get("HF_TOKEN")  # injected from Space Secrets
# ----------------------------------------------------------------------------

print(f"Loading model on: {DEVICE}")

# Load tokenizer and model once at startup; half precision on GPU only,
# since fp16 on CPU is slow/unsupported for many ops.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    token=HF_TOKEN,
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
    device_map="auto",
    trust_remote_code=True,
    token=HF_TOKEN,
)

# Single shared text-generation pipeline reused by every request.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)

print("Model loaded β")
def run_inference(system_prompt: str, user_prompt: str, max_tokens: int, temperature: float):
    """Run a single chat prompt through the model and return (response, metrics).

    Args:
        system_prompt: System message; falls back to a generic assistant
            persona when empty.
        user_prompt: The user's question; a blank/whitespace-only prompt
            short-circuits with a warning message.
        max_tokens: Upper bound on newly generated tokens.
        temperature: Sampling temperature; 0 selects greedy decoding.

    Returns:
        A 2-tuple of (response_text, metrics_string).
    """
    if not user_prompt.strip():
        # BUG FIX: the original returned 3 values here but 2 on the normal
        # path, and only 2 output components are wired to this callback —
        # Gradio raised an error whenever the prompt was empty.
        return "β οΈ Please enter a prompt.", ""

    messages = [
        {"role": "system", "content": system_prompt or "You are a helpful AI assistant."},
        {"role": "user", "content": user_prompt},
    ]

    # The model uses a special chat template; render it to a plain string.
    formatted = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # Only pass `temperature` when actually sampling: transformers warns (and
    # newer versions error) on temperature=0 combined with do_sample=False.
    gen_kwargs = {
        "max_new_tokens": max_tokens,
        "return_full_text": False,
    }
    if temperature > 0:
        gen_kwargs["do_sample"] = True
        gen_kwargs["temperature"] = temperature
    else:
        gen_kwargs["do_sample"] = False  # greedy decoding

    start = time.time()
    output = pipe(formatted, **gen_kwargs)
    elapsed = time.time() - start

    response_text = output[0]["generated_text"].strip()

    # Token counts for the performance readout.
    input_tokens = len(tokenizer.encode(formatted))
    output_tokens = len(tokenizer.encode(response_text))
    tokens_per_sec = round(output_tokens / elapsed, 1) if elapsed > 0 else 0

    metrics = (
        f"β± Latency: {elapsed:.2f}s | "
        f"π₯ Input tokens: {input_tokens} | "
        f"π€ Output tokens: {output_tokens} | "
        f"β‘ Speed: {tokens_per_sec} tok/s | "
        f"π₯ Device: {DEVICE.upper()}"
    )

    return response_text, metrics
# ---- UI --------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft(), title="Agri LLM Tester") as demo:
    gr.Markdown(
        """
# πΎ Agri LLM Β· Inference Tester
**Model:** `14maddy/Agri_llm` β Fine-tuned Phi-3-mini for Agriculture
Test prompts, measure latency, and evaluate model quality.
"""
    )

    with gr.Row():
        # Left column: prompt inputs and generation controls.
        with gr.Column(scale=2):
            system_box = gr.Textbox(
                label="System Prompt",
                placeholder="You are a helpful AI assistant.",
                lines=2,
            )
            user_box = gr.Textbox(
                label="User Prompt",
                placeholder="Ask anything...",
                lines=5,
            )
            with gr.Row():
                max_tokens_slider = gr.Slider(
                    minimum=64,
                    maximum=1024,
                    value=MAX_NEW_TOKENS,
                    step=32,
                    label="Max New Tokens",
                )
                temp_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.5,
                    value=0.7,
                    step=0.05,
                    label="Temperature (0 = greedy)",
                )
            run_btn = gr.Button("βΆ Run", variant="primary")
            clear_btn = gr.Button("π Clear")

        # Right column: model output and timing/throughput readout.
        with gr.Column(scale=3):
            output_box = gr.Textbox(
                label="Model Response",
                lines=14,
                interactive=False,
            )
            metrics_box = gr.Textbox(
                label="Performance Metrics",
                interactive=False,
                lines=2,
            )

    # One-click example prompts: (system, user, max tokens, temperature).
    gr.Examples(
        examples=[
            ["You are an agriculture AI assistant.", "What are the signs of nitrogen deficiency in soil?", 256, 0.7],
            ["You are a helpful coding assistant.", "Write a Python function to calculate soil pH from titration data.", 400, 0.5],
            ["", "Explain transformer attention mechanism in simple terms.", 300, 0.8],
        ],
        inputs=[system_box, user_box, max_tokens_slider, temp_slider],
        label="Quick Test Examples",
    )

    # Event wiring: Run triggers inference; Clear blanks all four text fields.
    run_btn.click(
        fn=run_inference,
        inputs=[system_box, user_box, max_tokens_slider, temp_slider],
        outputs=[output_box, metrics_box],
    )
    clear_btn.click(
        fn=lambda: ("", "", "", ""),
        inputs=[],
        outputs=[system_box, user_box, output_box, metrics_box],
    )

demo.launch()