Vladislav Krasnov committed
Commit cd00e73 · 1 parent: 6627d48

Update space 11

Files changed (1):
  1. app.py (+46, -22)
app.py CHANGED
@@ -2,36 +2,48 @@ import gradio as gr
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
 
-# Load model and tokenizer
-USERNAME = "sarekuwa"
-SPACE_NAME = "livecoder"
-API_ENDPOINT = f"https://{USERNAME}-{SPACE_NAME}.hf.space/api/predict"
-
-model_name = "microsoft/phi-2"
-tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
-tokenizer.pad_token = tokenizer.eos_token
+# Use lighter model for CPU
+# model_name = "microsoft/phi-2"  # 2.7B - TOO HEAVY
+model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # 1.1B - much lighter
 
-model = AutoModelForCausalLM.from_pretrained(
-    model_name,
-    torch_dtype=torch.float32,
-    device_map="cpu",
-    trust_remote_code=True
-)
+try:
+    print(f"Loading {model_name}...")
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    tokenizer.pad_token = tokenizer.eos_token
+
+    model = AutoModelForCausalLM.from_pretrained(
+        model_name,
+        torch_dtype=torch.float32,
+        device_map="cpu",
+        low_cpu_mem_usage=True  # Critical for CPU
+    )
+    print("Model loaded successfully")
+
+except Exception as e:
+    print(f"Failed to load model: {e}")
+    # Fallback to dummy function
+    model, tokenizer = None, None
 
 def generate_response(message):
     """Process user input and generate response"""
     if not message.strip():
         return "Please enter a question."
 
+    if model is None or tokenizer is None:
+        return f"Model not loaded. Testing UI with: {message}"
+
     try:
-        prompt = f"### Instruction: {message}\n### Response:"
+        # Format for chat model
+        prompt = f"<|user|>\n{message}\n<|assistant|>\n"
 
-        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
+        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=384)
 
+        # Generate with lower token count for CPU
         with torch.no_grad():
             outputs = model.generate(
                 inputs.input_ids,
-                max_new_tokens=256,
+                attention_mask=inputs.attention_mask,  # FIX: Add attention mask
+                max_new_tokens=150,  # Reduced for CPU
                 temperature=0.7,
                 do_sample=True,
                 top_p=0.9,
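
A note on the new prompt: the `<|user|>` / `<|assistant|>` markup matches TinyLlama-Chat's expected format, but the tokenizer ships its own chat template, so the prompt can be built without hard-coding special tokens (the manual string above likely drops the `</s>` end-of-turn marker the template inserts). A minimal sketch, assuming transformers >= 4.34, which added `apply_chat_template`:

    # Sketch: build the prompt from the tokenizer's bundled chat template
    # rather than hard-coding special tokens. Assumes transformers >= 4.34.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
    messages = [{"role": "user", "content": "Write a hello world in Python."}]

    # tokenize=False returns the formatted string; add_generation_prompt=True
    # appends the assistant turn marker so generation starts the reply.
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )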
@@ -43,10 +55,11 @@ def generate_response(message):
         return response.strip()
 
     except Exception as e:
-        return f"Error generating response: {str(e)}"
+        return f"Error: {str(e)[:100]}"
 
+# Create interface
 interface = gr.Interface(
-    fn=generate_response,  # Connect function to interface
+    fn=generate_response,
     inputs=gr.Textbox(label="Input", placeholder="Enter programming question...", lines=3),
     outputs=gr.Textbox(label="Output", lines=10),
     title="LiveCoder API",
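
The decode step sits between these hunks (old lines 38-42) and is not shown, but a common pitfall there is worth flagging: `model.generate` returns the prompt tokens followed by the completion, so decoding `outputs[0]` wholesale echoes the question back into the response. A sketch of slicing the prompt off first, reusing the `inputs` / `outputs` names from the code above:

    # Sketch: decode only the newly generated tokens so the prompt is not
    # echoed back. `inputs` and `outputs` are the tensors from generate_response.
    prompt_len = inputs.input_ids.shape[1]
    new_tokens = outputs[0][prompt_len:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)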
@@ -54,7 +67,18 @@ interface = gr.Interface(
     allow_flagging="never"
 )
 
-# Launch application
-interface.launch(server_name="0.0.0.0", server_port=7860, share=False)
+# API endpoint info
+USERNAME = "sarekuwa"
+SPACE_NAME = "livecoder"
+print(f"API Endpoint: https://{USERNAME}-{SPACE_NAME}.hf.space/api/predict")
 
-print(f"API Endpoint: {API_ENDPOINT}")
+# CRITICAL: Enable queue for request processing
+interface.queue(default_concurrency_limit=1)
+
+# Launch application
+interface.launch(
+    server_name="0.0.0.0",
+    server_port=7860,
+    share=False,
+    debug=False
+)
 
 
 
 
 
 
 
 
 
 
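
On the printed endpoint: the bare `/api/predict` route is the legacy Gradio REST path, and recent Gradio versions expect clients to go through the `gradio_client` package instead. A sketch of calling the Space once it is live, assuming it stays public under `sarekuwa/livecoder`:

    # Sketch: call the deployed Space programmatically
    # (pip install gradio_client). Assumes "sarekuwa/livecoder" is public.
    from gradio_client import Client

    client = Client("sarekuwa/livecoder")
    result = client.predict(
        "Write a Python function that reverses a string.",
        api_name="/predict",
    )
    print(result)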