Spaces:
Running
Running
| import spaces | |
| import gradio as gr | |
| from transformers import pipeline, TextIteratorStreamer | |
| import torch | |
| import threading | |
| import os | |
# Load model and tokenizer.
# MODEL_ID names the model (Hub repo id or local path) to serve. Fail fast when
# it is unset: transformers' pipeline() falls back to the task's default model
# when model is None, which would make the screener run on the wrong model
# without any visible error.
model_name = os.getenv("MODEL_ID")
if not model_name:
    raise RuntimeError(
        "MODEL_ID environment variable is not set; cannot load the screening model."
    )
pipe = pipeline("text-generation", model=model_name, device=0)  # device=0: first GPU
tokenizer = pipe.tokenizer
model = pipe.model

# Fixed generation config (not exposed in the UI).
MAX_TOKENS = 3000  # used as max_new_tokens: cap on generated tokens, prompt excluded
TEMPERATURE = 0.1  # low temperature keeps sampling close to greedy decoding
TOP_P = 0.9  # nucleus-sampling probability cutoff
def respond_stream(summary, title, abstract):
    """Stream the model's screening response for one study.

    Builds a chat-formatted prompt from the PICOS summary, title and abstract,
    runs ``model.generate`` on a background thread, and yields the accumulated
    response text after every decoded chunk so the Gradio output box updates
    incrementally.

    Args:
        summary: PICOS summary text (required; None is treated as empty).
        title: Study title (required; None is treated as empty).
        abstract: Study abstract (required; None is treated as empty).

    Yields:
        The response text generated so far. If any required field is empty,
        a single error message is yielded instead and nothing is generated.
    """
    # NOTE(review): `spaces` is imported at the top of the file but never used.
    # If this Space runs on ZeroGPU hardware, this function likely needs the
    # `@spaces.GPU` decorator to be granted a GPU -- confirm the hardware.
    fields = [(text or "").strip() for text in (summary, title, abstract)]
    summary, title, abstract = fields

    # This function is a generator, so `return <message>` would be swallowed
    # (it only sets StopIteration.value) and the user would see nothing.
    # Yield the message so Gradio displays it, then stop.
    if not all(fields):
        yield "❌ Error: PICOS Summary, Title, and Abstract are all required."
        return

    prompt = (
        "Instruction: Use the following PICOS summary to evaluate the abstract.\n"
        f"\nPICOS Summary: {summary}"
        f"\n\nTitle: {title}\nAbstract: {abstract}"
    )
    messages = [{"role": "user", "content": prompt}]
    prompt_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # The rendered chat template already contains its special tokens, so the
    # tokenizer must not add them again (e.g. a duplicated BOS). Put tensors on
    # whatever device the model actually lives on rather than assuming "cuda".
    inputs = tokenizer(
        prompt_text, return_tensors="pt", add_special_tokens=False
    ).to(model.device)
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs = dict(
        input_ids=inputs["input_ids"],
        # Explicit mask: pad_token_id is set to eos below, so generate() cannot
        # reliably infer which positions are padding on its own.
        attention_mask=inputs.get("attention_mask"),
        streamer=streamer,
        max_new_tokens=MAX_TOKENS,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )

    # generate() blocks until completion, so run it on a worker thread and
    # consume decoded text from the streamer as it is produced.
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    partial_text = ""
    for chunk in streamer:
        partial_text += chunk
        yield partial_text
    # The streamer is exhausted only after generation ends; reap the thread.
    thread.join()
# Gradio UI: three required text inputs, a read-only response box, and a
# button that streams respond_stream's output into that box.
with gr.Blocks() as demo:
    gr.Markdown("## Study design screener")
    with gr.Column():
        summary = gr.Textbox(
            label="PICOS Summary",
            placeholder="Required",
            lines=4,
        )
        title = gr.Textbox(
            label="Title",
            placeholder="Required",
            lines=2,
        )
        abstract = gr.Textbox(
            label="Abstract",
            placeholder="Required",
            lines=10,
        )
        output_box = gr.Textbox(
            label="Model Response",
            lines=15,
            interactive=False,
        )
        generate_btn = gr.Button("Generate")
    generate_btn.click(
        outputs=[output_box],
        inputs=[summary, title, abstract],
        fn=respond_stream,
    )

# Start the server only when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()